from typing import Dict, List, Tuple, Union
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
from datasets import DatasetDict
from sklearn.inspection import permutation_importance
from sklearn.metrics import accuracy_score, balanced_accuracy_score, make_scorer
from sklearn.preprocessing import LabelEncoder
from DashAI.back.core.schema_fields import (
BaseSchema,
enum_field,
float_field,
int_field,
schema_field,
)
from DashAI.back.explainability.global_explainer import BaseGlobalExplainer
from DashAI.back.models import BaseModel
class PermutationFeatureImportanceSchema(BaseSchema):
"""
    Permutation Feature Importance is an explanation method that assesses the
    importance of each feature in a model by evaluating how much the model's
    performance decreases when the values of a specific feature are randomly
    shuffled.
"""
scoring: schema_field(
enum_field(enum=["accuracy", "balanced_accuracy"]),
placeholder="accuracy",
description="Scorer to evaluate how the perfomance of the model "
"changes when a particular feature is shuffled.",
) # type: ignore
n_repeats: schema_field(
int_field(ge=1),
placeholder=20,
description="Number of times to permute a feature.",
) # type: ignore
random_state: schema_field(
int_field(),
placeholder=0,
description="Seed for the random number generator to control the "
"permutations of each feature.",
) # type: ignore
max_samples_fraction: schema_field(
float_field(ge=0.0, le=1.0),
placeholder=1.0,
description="The fraction of samples to draw from the test set to "
"calculate feature importance at each repetition.",
) # type: ignore
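
# Illustrative sketch (not part of the DashAI API): the schema fields above
# map one-to-one onto the constructor arguments of the explainer below.
#
#     explainer = PermutationFeatureImportance(
#         model=trained_model,  # hypothetical: any fitted DashAI BaseModel
#         scoring="accuracy",
#         n_repeats=20,
#         random_state=0,
#         max_samples_fraction=1.0,
#     )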
class PermutationFeatureImportance(BaseGlobalExplainer):
"""Permutation Feature Importance is a explanation method to asses the importance
of each feature in a model by evaluating how much the model's performance
decreases when the values of a specific feature are randomly shuffled.
"""
COMPATIBLE_COMPONENTS = ["TabularClassificationTask"]
DISPLAY_NAME = "Permutation Feature Importance"
COLOR = "#800080"
SCHEMA = PermutationFeatureImportanceSchema
def __init__(
self,
model: BaseModel,
        scoring: Union[str, None] = None,
n_repeats: int = 5,
random_state: Union[int, None] = None,
max_samples_fraction: float = 0.5,
):
super().__init__(model)
metrics = {
"accuracy": accuracy_score,
"balanced_accuracy": balanced_accuracy_score,
}
        # Fall back to accuracy when no scorer is specified, so the default
        # ``scoring=None`` does not raise a KeyError.
        self.scoring = metrics[scoring] if scoring is not None else accuracy_score
self.n_repeats = n_repeats
self.random_state = random_state
self.max_samples_fraction = max_samples_fraction
def _get_feature_groups(self, columns: List[str]) -> Dict[str, List[int]]:
"""Group one-hot encoded columns back to their original feature."""
feature_groups = {}
if (
hasattr(self.model, "one_hot_encoder")
and self.model.one_hot_encoder is not None
and hasattr(self.model, "categorical_columns")
and self.model.categorical_columns
):
encoder = self.model.one_hot_encoder
original_cat_cols = self.model.categorical_columns
encoded_feature_names = list(
encoder.get_feature_names_out(original_cat_cols)
)
for orig_col in original_cat_cols:
prefix = f"{orig_col}_"
indices = [
columns.index(enc_col)
for enc_col in encoded_feature_names
if enc_col.startswith(prefix) and enc_col in columns
]
if indices:
feature_groups[orig_col] = indices
# Add non-categorical columns
for idx, col in enumerate(columns):
if col not in encoded_feature_names:
feature_groups[col] = [idx]
else:
for idx, col in enumerate(columns):
feature_groups[col] = [idx]
return feature_groups
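
    # Illustrative example (assumed column layout): if the encoder expanded a
    # categorical column "color" into "color_red" and "color_blue", and the
    # dataset also contains a numeric column "age", then for
    # columns = ["color_red", "color_blue", "age"] this method returns
    #
    #     {"color": [0, 1], "age": [2]}
    #
    # so that both one-hot columns are permuted together as a single feature.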
def _calculate_grouped_importance(
self,
x_data: pd.DataFrame,
y: pd.DataFrame,
feature_groups: Dict[str, List[int]],
max_samples: int,
) -> Dict[str, Dict[str, np.ndarray]]:
"""Calculate permutation importance for grouped features."""
rng = np.random.RandomState(self.random_state)
n_samples = min(max_samples, len(x_data))
sample_indices = rng.choice(len(x_data), size=n_samples, replace=False)
x_sample = x_data.iloc[sample_indices].copy().reset_index(drop=True)
y_sample = y.iloc[sample_indices].copy().reset_index(drop=True)
y_array = y_sample.to_numpy().ravel()
column_names = list(x_sample.columns)
        # The DashAI model wrapper exposes ``predict_proba`` directly.
        sklearn_model = self.model
def get_predictions(data):
# Keep as DataFrame to preserve column names
return sklearn_model.predict_proba(data)
def calc_score(y_true, y_pred_probas):
y_pred = np.argmax(y_pred_probas, axis=1)
return self.scoring(y_true, y_pred)
baseline_predictions = get_predictions(x_sample)
baseline_score = calc_score(y_array, baseline_predictions)
results = {"features": [], "importances_mean": [], "importances_std": []}
for feature_name, col_indices in feature_groups.items():
importances = []
# Get column names for this group
group_cols = [column_names[i] for i in col_indices]
for _ in range(self.n_repeats):
# Work with DataFrame to preserve column names
x_permuted = x_sample.copy()
# Permute rows for this group of columns
permutation = rng.permutation(n_samples)
# Get the block of columns, permute rows, put back
original_block = x_sample[group_cols].to_numpy()
permuted_block = original_block[permutation, :]
x_permuted[group_cols] = permuted_block
permuted_predictions = get_predictions(x_permuted)
permuted_score = calc_score(y_array, permuted_predictions)
importance = baseline_score - permuted_score
importances.append(importance)
results["features"].append(feature_name)
results["importances_mean"].append(np.mean(importances))
results["importances_std"].append(np.std(importances))
return results
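
    # Minimal sketch of the computation above (illustrative; ``x`` is a toy
    # ndarray, ``group`` a list of column indices, ``score`` a metric):
    #
    #     baseline = score(y_true, predict(x))
    #     x_perm = x.copy()
    #     x_perm[:, group] = x[rng.permutation(len(x))][:, group]
    #     importance = baseline - score(y_true, predict(x_perm))
    #
    # A large positive drop indicates the model relied heavily on that group.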
def explain(self, dataset: Tuple[DatasetDict, DatasetDict]):
"""Method for calculating the importance of features in the model."""
x, y = dataset
x_test = x["test"]
y_test = y["test"]
X_df = x_test.to_pandas()
y_df = y_test.to_pandas()
y_values = y_df.to_numpy().ravel()
if y_values.dtype == object or y_values.dtype.kind in ("U", "S"):
if (
hasattr(self.model, "label_encoder")
and self.model.label_encoder is not None
):
y_encoded = self.model.label_encoder.transform(y_values)
else:
le = LabelEncoder()
y_encoded = le.fit_transform(y_values)
y_df = pd.DataFrame(y_encoded, columns=y_df.columns)
input_columns = list(X_df.columns)
feature_groups = self._get_feature_groups(input_columns)
max_samples = max(int(len(x_test) * self.max_samples_fraction), 1)
has_grouped_features = any(
len(indices) > 1 for indices in feature_groups.values()
)
if has_grouped_features:
results = self._calculate_grouped_importance(
X_df, y_df, feature_groups, max_samples
)
return {
"features": results["features"],
"importances_mean": np.round(results["importances_mean"], 3).tolist(),
"importances_std": np.round(results["importances_std"], 3).tolist(),
}
else:
            def patched_metric(y_true, y_pred_probas):
                # The wrapped model returns class probabilities, so take the
                # argmax to recover class labels before scoring.
                return self.scoring(y_true, np.argmax(y_pred_probas, axis=1))
pfi = permutation_importance(
estimator=self.model,
X=X_df,
y=y_df,
scoring=make_scorer(patched_metric),
n_repeats=self.n_repeats,
random_state=self.random_state,
max_samples=max_samples,
)
return {
"features": input_columns,
"importances_mean": np.round(pfi["importances_mean"], 3).tolist(),
"importances_std": np.round(pfi["importances_std"], 3).tolist(),
}
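
    # Illustrative return value of ``explain`` (feature names and numbers are
    # made up; real values depend on the model and data):
    #
    #     {
    #         "features": ["age", "color"],
    #         "importances_mean": [0.12, 0.034],
    #         "importances_std": [0.01, 0.006],
    #     }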
    def _create_plot(self, data: pd.DataFrame, n_features: int):
        """Helper method to create the explanation plot using plotly."""
        subset = data.iloc[-n_features:]
        fig = px.bar(
            subset,
            x="importances_mean",
            y="features",
            error_x="importances_std",
        )
fig.update_layout(
xaxis_title="Importance",
yaxis_title=None,
annotations=[
{
"text": "",
"showarrow": False,
"x": 0,
"y": 1.15,
"xanchor": "left",
"xref": "paper",
"yref": "paper",
"yanchor": "top",
}
],
updatemenus=[
{
"x": 0,
"xanchor": "left",
"y": 1.2,
"yanchor": "top",
"buttons": [
{
"label": f"N° features: {len(data.iloc[-c:,])}",
"method": "restyle",
"args": [
{
"x": [data.iloc[-c:]["importances_mean"]],
"y": [data.iloc[-c:]["features"]],
"error_x": [data.iloc[-c:]["importances_std"]],
},
],
}
for c in range(len(data))
],
}
],
)
return [plotly.io.to_json(fig)]
def plot(self, explanation: dict) -> List[dict]:
"""Method to create the explanation plot."""
n_features = 10
data = pd.DataFrame.from_dict(explanation)
data = data.sort_values(by=["importances_mean"], ascending=True)
if n_features > len(data):
n_features = len(data)
return self._create_plot(data, n_features)
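
# Hypothetical end-to-end usage (variable names are assumptions, not part of
# this module):
#
#     explainer = PermutationFeatureImportance(model, scoring="accuracy")
#     explanation = explainer.explain((x_splits, y_splits))
#     figures = explainer.plot(explanation)  # one Plotly figure as JSON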