# Source code for DashAI.back.explainability.explainers.partial_dependence

from typing import List, Tuple

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
from datasets import DatasetDict
from sklearn.inspection import partial_dependence

from DashAI.back.core.schema_fields import (
    BaseSchema,
    float_field,
    int_field,
    schema_field,
)
from DashAI.back.explainability.global_explainer import BaseGlobalExplainer
from DashAI.back.models import BaseModel


class PartialDependenceSchema(BaseSchema):
    """Configuration schema for the PartialDependence explainer.

    PartialDependence of a feature shows the average prediction of a machine
    learning model for each possible value of the feature.

    The three fields below carry the same names as the keyword arguments of
    ``PartialDependence.__init__`` (``grid_resolution``, ``lower_percentile``
    and ``upper_percentile``) and their placeholders match that constructor's
    defaults.

    Notes
    -----
    The bounds declared here constrain each field on its own; nothing in this
    schema relates ``lower_percentile`` to ``upper_percentile``. The
    ``upper_percentile > lower_percentile`` requirement is checked in
    ``PartialDependence.__init__``.
    """

    # Number of grid points used for the target feature. Declared with
    # ``ge=1`` (presumably "greater than or equal to", as in pydantic --
    # confirm in DashAI.back.core.schema_fields).
    grid_resolution: schema_field(
        int_field(ge=1),
        placeholder=100,
        description="The number of equidistant points to split the range of "
        "the target feature",
    )  # type: ignore

    # Lower percentile, expressed as a fraction; declared with bounds
    # ``ge=0`` and ``le=0.99``.
    lower_percentile: schema_field(
        float_field(ge=0, le=0.99),
        placeholder=0.05,
        description="The lower percentile used to limit the feature values.",
    )  # type: ignore

    # Upper percentile, expressed as a fraction; declared with bounds
    # ``ge=0.01`` and ``le=1``.
    upper_percentile: schema_field(
        float_field(ge=0.01, le=1),
        placeholder=0.95,
        description="The upper percentile used to limit the feature values.",
    )  # type: ignore


class PartialDependence(BaseGlobalExplainer):
    """Model-agnostic global explainer based on partial dependence.

    Partial dependence shows the average prediction of a machine learning
    model for each possible value of a feature.
    """

    COMPATIBLE_COMPONENTS = ["TabularClassificationTask"]
    SCHEMA = PartialDependenceSchema

    def __init__(
        self,
        model: BaseModel,
        lower_percentile: float = 0.05,
        upper_percentile: float = 0.95,
        grid_resolution: int = 100,
    ):
        """Initialize a new instance of a PartialDependence explainer.

        Parameters
        ----------
        model: BaseModel
            Model to be explained.
        lower_percentile: float
            The lower percentile used to limit the feature values.
            Defaults to 0.05.
        upper_percentile: float
            The upper percentile used to limit the feature values.
            Defaults to 0.95.
        grid_resolution: int
            The number of equidistant points to split the range of the target
            feature. Defaults to 100.

        Raises
        ------
        AssertionError
            If ``upper_percentile`` is not strictly greater than
            ``lower_percentile``.
        """
        # Kept as an assert so callers relying on AssertionError are unaffected.
        assert (
            upper_percentile > lower_percentile
        ), "upper_percentile value must be greater than lower_percentile"

        super().__init__(model)

        self.percentiles = (lower_percentile, upper_percentile)
        self.grid_resolution = grid_resolution
        self.explanation = None

    def explain(self, dataset: Tuple[DatasetDict, DatasetDict]):
        """Generate the partial dependence explanation.

        Parameters
        ----------
        dataset: Tuple[DatasetDict, DatasetDict]
            Tuple with (input_samples, targets). The ``"test"`` split of the
            input samples is used to evaluate the partial dependence of each
            feature; the ``"test"`` split of the targets provides the class
            names.

        Returns
        -------
        dict
            Dictionary with a ``"metadata"`` entry (holding ``target_names``)
            plus one entry per feature name containing ``"grid_values"`` and
            ``"average"``, both rounded to 3 decimals and converted to lists.
        """
        x, y = dataset
        x_test = x["test"].to_pandas()
        features = x["test"].features
        features_names = list(features)

        # This must be a *boolean* mask: scikit-learn interprets an integer
        # array as feature indices, so a 0/1 list would mark features 0 and 1
        # as categorical no matter what their real type is.
        categorical_features = [
            features[name]._type == "ClassLabel" for name in features_names
        ]

        output_column = list(y["test"].features.keys())[0]
        target_names = y["test"].features[output_column].names

        explanation = {"metadata": {"target_names": target_names}}

        for idx, feature_name in enumerate(features_names):
            pdp_result = partial_dependence(
                estimator=self.model,
                X=x_test,
                features=idx,
                categorical_features=categorical_features,
                feature_names=features_names,
                percentiles=self.percentiles,
                grid_resolution=self.grid_resolution,
                kind="average",
            )

            # scikit-learn 1.3 renamed the "values" key to "grid_values" and
            # later removed the old one; support both.
            if "grid_values" in pdp_result:
                grid = pdp_result["grid_values"]
            else:
                grid = pdp_result["values"]

            explanation[feature_name] = {
                "grid_values": np.round(grid[0], 3).tolist(),
                "average": np.round(pdp_result["average"], 3).tolist(),
            }

        return explanation

    def _create_plot(self, data: List[pd.DataFrame]) -> List[dict]:
        """Helper method to create the explanation plot using plotly.

        Parameters
        ----------
        data: List[pd.DataFrame]
            Non-empty list of data frames as built by ``plot``. Each frame has
            the partial dependence values in its first column (whose name is
            used as the dropdown label) and a ``"grid_values"`` column.

        Returns
        -------
        List[dict]
            Single-element list with the figure serialized by
            ``plotly.io.to_json``.
            NOTE(review): that function returns a JSON string rather than a
            dict; the annotation is left untouched to keep the signature
            stable.
        """
        first = data[0]
        fig = px.line(
            first,
            x=first["grid_values"],
            y=first.iloc[:, 0],
            labels={"grid_values": "Feature value"},
        )

        # One dropdown entry per (feature, class) frame; selecting an entry
        # swaps the x/y data of the single trace.
        buttons = [
            {
                "label": frame.columns[0],
                "method": "restyle",
                "args": [
                    {
                        "x": [frame["grid_values"]],
                        "y": [frame.iloc[:, 0]],
                    },
                ],
            }
            for frame in data
        ]

        fig.update_layout(
            yaxis_title="Partial Dependence",
            updatemenus=[
                {
                    "x": 0,
                    "xanchor": "left",
                    "y": 1.2,
                    "yanchor": "top",
                    "buttons": buttons,
                }
            ],
        )

        plot_note = (
            "This graph shows the marginal effect of the selected feature "
            "on the <br> probability predicted by the model for the selected "
            "class"
        )

        fig.add_annotation(
            align="center",
            arrowsize=0.3,
            arrowwidth=0.1,
            borderwidth=2,
            font={"size": 12},
            showarrow=False,
            text=plot_note,
            xanchor="center",
            yanchor="bottom",
            xref="paper",
            yref="paper",
            y=-0.35,
        )

        return [plotly.io.to_json(fig)]

    def plot(self, explanation: dict) -> List[dict]:
        """Create the explanation plot.

        Parameters
        ----------
        explanation: dict
            Dictionary with the explanation generated by ``explain``. It is
            shallow-copied, so the caller's dictionary keeps its
            ``"metadata"`` entry.

        Returns
        -------
        List[dict]
            List of JSONs containing the information of the explanation plot
            to be rendered.
        """
        explanation = explanation.copy()
        metadata = explanation.pop("metadata")
        target_names = metadata["target_names"]

        # Binary-classification case: only the second class is plotted
        # (presumably because a single curve is returned for it -- confirm
        # against the model wrapper). This is resolved once, before the loop,
        # and kept as a list so that zip() pairs whole class names with curves
        # instead of iterating over the characters of a string.
        if len(target_names) == 2:
            target_names = [target_names[1]]

        dfs = []
        for feature, feature_data in explanation.items():
            average = feature_data["average"]
            grid_values = feature_data["grid_values"]

            for target, values in zip(target_names, average):  # noqa B905
                column_name = f"Feature: {feature} - Class: {target}"
                frame = pd.DataFrame({column_name: values})
                frame["grid_values"] = grid_values
                dfs.append(frame)

        return self._create_plot(dfs)