from typing import List, Tuple, Union
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
import shap
from datasets import DatasetDict, concatenate_datasets
from DashAI.back.core.schema_fields import (
BaseSchema,
bool_field,
enum_field,
int_field,
schema_field,
)
from DashAI.back.explainability.local_explainer import BaseLocalExplainer
from DashAI.back.models import BaseModel
class KernelShapSchema(BaseSchema):
    """Kernel SHAP is a model-agnostic explainability method for approximating SHAP
    values to explain the output of machine learning model by attributing contributions
    of each feature to the model's prediction.
    """

    # Configuration schema referenced by ``KernelShap.SCHEMA``.
    # NOTE(review): the ``fit_parameter_`` prefix presumably marks options that are
    # forwarded to ``KernelShap.fit`` with the prefix stripped -- confirm against
    # the code that consumes this schema.

    # Link function name; ``KernelShap`` stores it and passes it to
    # ``shap.KernelExplainer(link=...)``.
    link: schema_field(
        enum_field(enum=["identity", "logit"]),
        placeholder="identity",
        description="Link function to connect the feature importance values to the "
        "model's outputs. Options are 'identity' to use identity function or 'logit' "
        "to use log-odds function.",
    )  # type: ignore

    # Whether the background data should be sub-sampled before fitting.
    fit_parameter_sample_background_data: schema_field(
        bool_field(),
        placeholder=False,
        description="Parameter to fit the explainer. 'true' if the background "
        "data must be sampled, otherwise the entire train data set is used. "
        "Smaller datasets speed up the algorithm run time.",
    )  # type: ignore

    # Number of background samples to draw; constrained to be >= 1.
    fit_parameter_n_background_samples: schema_field(
        int_field(ge=1),
        placeholder=1,
        description="Parameter to fit the explainer. If the parameter "
        "'sample_background_data' is 'true', the number of background "
        "data samples to be drawn.",
    )  # type: ignore

    # Sampling strategy name; the two options match the sampler keys used in
    # ``KernelShap._sample_background_data``.
    fit_parameter_sampling_method: schema_field(
        enum_field(enum=["shuffle", "kmeans"]),
        placeholder="shuffle",
        description="Parameter to fit the explainer. If the parameter "
        "'sample_background_data' is 'true', whether to sample random "
        "samples with 'shuffle' option or summarize the data set with "
        "'kmeans' option. If 'categorical_features' is 'true', 'shuffle' "
        "options used by default.",
    )  # type: ignore
class KernelShap(BaseLocalExplainer):
    """Kernel SHAP is a model-agnostic explainability method for approximating SHAP
    values to explain the output of machine learning model by attributing contributions
    of each feature to the model's prediction.
    """

    COMPATIBLE_COMPONENTS = ["TabularClassificationTask"]
    SCHEMA = KernelShapSchema

    def __init__(
        self,
        model: BaseModel,
        link: str = "identity",
    ):
        """Initialize a new instance of a KernelShap explainer.

        Parameters
        ----------
        model: BaseModel
            Model to be explained.
        link: str
            String indicating the link function to connect the feature importance
            values to the model's outputs. Options are 'identity' to use identity
            function or 'logit' to use log-odds function.
        """
        super().__init__(model)
        self.link = link

    def _sample_background_data(
        self,
        background_data: np.ndarray,
        n_background_samples: int,
        sampling_method: str = "shuffle",
        categorical_features: bool = False,
    ):
        """Sample the background dataset used to fit the explainer.

        Parameters
        ----------
        background_data: np.ndarray
            Data used to estimate feature attributions and establish a baseline for
            the calculation of SHAP values.
        n_background_samples: int
            Number of background data samples to keep. Limiting the samples
            reduces run times.
        sampling_method: str
            Sampling method used to select the background samples. Options are
            'shuffle' to select random samples or 'kmeans' to summarise the data
            set. 'kmeans' is only used if there are no categorical features.
        categorical_features: bool
            Bool indicating whether some features are categorical. When True the
            'shuffle' sampler is used regardless of ``sampling_method``.

        Returns
        -------
        object
            Whatever the selected shap sampler (``shap.sample`` or
            ``shap.kmeans``) returns for the given data.

        Raises
        ------
        KeyError
            If ``sampling_method`` is not 'shuffle' or 'kmeans' and there are no
            categorical features.
        """
        samplers = {"shuffle": shap.sample, "kmeans": shap.kmeans}
        # Categorical columns force the 'shuffle' sampler, as documented above.
        method = "shuffle" if categorical_features else sampling_method
        return samplers[method](background_data, n_background_samples)

    def fit(
        self,
        background_dataset: Tuple[DatasetDict, DatasetDict],
        sample_background_data: Union[bool, str] = "false",
        n_background_samples: Union[int, None] = None,
        sampling_method: Union[str, None] = None,
    ):
        """Train the KernelShap explainer.

        Parameters
        ----------
        background_dataset: Tuple[DatasetDict, DatasetDict]
            Tuple with (input_samples, targets). Input samples are used to estimate
            feature attributions and establish a baseline for the calculation of
            SHAP values. Only the "train" split of each element is read.
        sample_background_data: bool or str
            True (or one of the strings 'true', '1', 'yes', case-insensitive) if
            the background data must be sampled. Smaller data sets speed up the
            algorithm run time. Disabled by default.
        n_background_samples: int
            Number of background data samples used to estimate SHAP values if
            sampling is enabled. Required in that case.
        sampling_method: str
            Sampling method used to select the background samples if sampling is
            enabled. Options are 'shuffle' to select random samples or 'kmeans'
            to summarise the data set. 'kmeans' is only used if there are no
            categorical features. Defaults to 'shuffle' when not given.

        Returns
        -------
        KernelShap
            The fitted explainer (``self``).

        Raises
        ------
        ValueError
            If sampling is enabled but ``n_background_samples`` is not provided.
        """
        # ``bool("false")`` is True, so string flags must be parsed explicitly;
        # otherwise the default value would silently enable sampling.
        if isinstance(sample_background_data, str):
            flag = sample_background_data.strip().lower()
            sample_background_data = flag in {"true", "1", "yes"}
        else:
            sample_background_data = bool(sample_background_data)

        x, y = background_dataset
        background_data = x["train"].to_pandas()
        features = x["train"].features
        feature_names = list(features)

        categorical_features = any(
            features[name]._type == "ClassLabel" for name in features
        )

        if sample_background_data:
            if n_background_samples is None:
                raise ValueError(
                    "n_background_samples must be provided when "
                    "sample_background_data is enabled."
                )
            if sampling_method is None:
                # Same default as ``_sample_background_data``.
                sampling_method = "shuffle"
            background_data = self._sample_background_data(
                background_data.to_numpy(),
                n_background_samples,
                sampling_method,
                categorical_features,
            )

        # TODO: consider the case where the predictor is not a Sklearn model
        self.explainer = shap.KernelExplainer(
            model=self.model.predict,
            data=background_data,
            feature_names=feature_names,
            link=self.link,
        )

        # Metadata
        output_column = list(y["train"].features)[0]
        target_names = y["train"].features[output_column].names
        self.metadata = {"feature_names": feature_names, "target_names": target_names}

        return self

    def explain_instance(
        self,
        instances: DatasetDict,
    ):
        """Explain the model prediction of each instance using the Kernel SHAP
        method.

        Parameters
        ----------
        instances: DatasetDict
            Instances to be explained. All splits are concatenated, in key order,
            before being explained.

        Returns
        -------
        dict
            Dictionary with the keys "metadata" and "base_values" plus one
            integer key per instance holding its "instance_values",
            "model_prediction" and "shap_values".
        """
        splits = list(instances.keys())
        X = instances[splits[0]]
        for split in splits[1:]:
            X = concatenate_datasets([X, instances[split]])
        X = X.to_pandas()

        predictions = self.model.predict(x_pred=X)

        # TODO: evaluate args nsamples y l1_reg
        shap_values = self.explainer.shap_values(X=X)

        # shap_values is expected to have size (n_classes, n_instances, n_features).
        # Reorder shap values: (n_instances, n_classes, n_features)
        # NOTE(review): this layout depends on the installed shap version -- confirm.
        shap_values = np.array(shap_values).swapaxes(1, 0)

        explanation = {
            "metadata": self.metadata,
            "base_values": np.round(self.explainer.expected_value, 3).tolist(),
        }

        for i, (instance, prediction, contribution_values) in enumerate(
            zip(X.to_numpy(), predictions, shap_values)  # noqa B905
        ):
            explanation[i] = {
                "instance_values": instance.tolist(),
                "model_prediction": prediction.tolist(),
                "shap_values": np.round(contribution_values, 3).tolist(),
            }

        return explanation

    def _create_plot(
        self, data: pd.DataFrame, base_value: float, y_pred_pbb: float, y_pred_name: str
    ):
        """Helper method to create the explanation plot using plotly.

        Parameters
        ----------
        data: pd.DataFrame
            Dataframe containing the data to be plotted. The columns
            "shap_values" and "label" are read.
        base_value: float
            Value to set where the bar base is drawn.
        y_pred_pbb: float
            Predicted probability.
        y_pred_name: str
            Name of the predicted class.

        Returns
        -------
        str
            JSON containing the information of the explanation plot
            to be rendered.
        """
        x = data["shap_values"].to_numpy()
        y = data["label"].to_numpy()
        measure = np.repeat("relative", len(y))
        # The bar text shows the same SHAP values that set the bar lengths.
        texts = data["shap_values"].to_numpy()

        fig = go.Figure(
            go.Waterfall(
                x=x,
                y=y,
                base=base_value,
                name="20",
                orientation="h",
                measure=measure,
                text=texts,
                textposition="auto",
                constraintext="inside",
                decreasing={"marker": {"color": "rgb(47,138,196)"}},
                increasing={"marker": {"color": "rgb(231,63,116)"}},
            )
        )

        fig.update_layout(
            margin={"pad": 20, "l": 100, "r": 130, "t": 60, "b": 10},
            xaxis={
                "tickangle": -90,
                "tickwidth": 100,
                "title_text": "",
            },
            yaxis={"showgrid": True, "tickwidth": 150},
        )

        # Only two ticks are drawn: the base value and the predicted probability.
        fig.update_xaxes(
            gridcolor="#1B2631",
            gridwidth=1,
            tickmode="array",
            nticks=2,
            tickvals=[base_value, y_pred_pbb],
            ticktext=[f"E[f(x)]={base_value}", f"f(x)={y_pred_pbb}"],
            tickangle=0,
            showgrid=True,
        )

        plot_note = (
            f"The predicted class was {y_pred_name} with probability f(x)={y_pred_pbb}."
        )

        fig.add_annotation(
            align="center",
            arrowsize=0.3,
            arrowwidth=0.1,
            font={"size": 12},
            showarrow=False,
            text=plot_note,
            xanchor="center",
            yanchor="bottom",
            xref="paper",
            yref="paper",
            y=-0.3,
        )

        return plotly.io.to_json(fig)

    def plot(self, explanation: dict) -> List[str]:
        """Create one explanation plot per explained instance using plotly.

        Parameters
        ----------
        explanation: dict
            Dictionary with the explanation generated by the explainer, as
            returned by ``explain_instance``. It is not modified.

        Returns
        -------
        List[str]
            List of JSONs containing the information of the explanation plots
            to be rendered.
        """
        # Work on a shallow copy so popping the bookkeeping keys does not
        # mutate the caller's dictionary.
        explanation = explanation.copy()
        max_features = 8

        metadata = explanation.pop("metadata")
        base_values = explanation.pop("base_values")
        feature_names = metadata["feature_names"]
        target_names = metadata["target_names"]

        plots = []
        # After the pops above, every remaining key is one explained instance.
        for i in explanation:
            instance_values = explanation[i]["instance_values"]
            model_prediction = explanation[i]["model_prediction"]
            y_pred_class = np.argmax(model_prediction)
            y_pred_name = target_names[y_pred_class]
            y_pred_pbb = np.round(model_prediction[y_pred_class], 2)
            shap_values = explanation[i]["shap_values"][y_pred_class]

            data = pd.DataFrame(
                {
                    "values": instance_values,
                    "shap_values": shap_values,
                    "features": feature_names,
                }
            )
            data["shap_values_abs"] = np.abs(data["shap_values"])
            data = data.sort_values(by="shap_values_abs", ascending=True)

            if len(data) > max_features:
                # Keep the ``max_features`` largest contributions and fold the
                # remaining ones into a single "Others" row.
                data_1 = data.iloc[-max_features:, :]
                data_2 = data.iloc[:-max_features, :]
                others = pd.DataFrame.from_dict(
                    data={
                        "values": [None],
                        "shap_values": np.round(data_2["shap_values"].sum(), 3),
                        "shap_values_abs": [None],
                        "features": ["Others"],
                    }
                )
                data = pd.concat([others, data_1])

            data["label"] = data["features"] + "=" + data["values"].map(str)
            base_value = base_values[y_pred_class]

            plot = self._create_plot(data, base_value, y_pred_pbb, y_pred_name)
            plots.append(plot)

        return plots