PermutationFeatureImportance
Global explainer that ranks features by the drop in model performance when each feature is permuted.
Permutation Feature Importance (PFI) measures the importance of a feature by
randomly shuffling its values across the test set and recording the resulting
decrease in a scoring metric. A large decrease indicates that the model relies
heavily on that feature; a small (or zero) decrease indicates the feature
contributes little. The process is repeated n_repeats times to produce a
mean importance and standard deviation, which quantify both rank order and
uncertainty.
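The procedure described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the component's implementation: `permutation_importance_sketch`, the toy `predict` function and the `accuracy` helper are hypothetical names introduced here. The toy model only reads column 0. Permuting column 1 therefore cannot change its predictions, and that feature's importance comes out as exactly zero.

```python
import numpy as np


def permutation_importance_sketch(predict, X, y, score, n_repeats=20, random_state=0):
    """Return the mean and std of the score drop for each column of X."""
    rng = np.random.default_rng(random_state)
    baseline = score(y, predict(X))  # score on the unshuffled data
    drops = np.empty((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            # Shuffle column j only, breaking its link with y
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j, r] = baseline - score(y, predict(X_perm))
    return drops.mean(axis=1), drops.std(axis=1)


def accuracy(y_true, y_pred):
    return float(np.mean(y_true == y_pred))


rng = np.random.default_rng(42)
X = rng.normal(size=(200, 2))
y = (X[:, 0] > 0).astype(int)


def predict(A):
    # Toy "model": uses column 0 and ignores column 1
    return (A[:, 0] > 0).astype(int)


mean, std = permutation_importance_sketch(predict, X, y, accuracy, n_repeats=5)
print(mean.round(3), std.round(3))
```

Each repeat draws a fresh permutation, so the standard deviation shows how sensitive the score drop is to the particular shuffle.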
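Unlike impurity-based importance from decision trees, which is derived from training-set statistics, PFI is computed on held-out data and is therefore not biased towards high-cardinality features. It is model-agnostic and captures interaction effects. It does, however, assume that permuting a feature does not break important correlations in the data. If two features are strongly correlated, permuting one produces unrealistic samples, and the model can still recover the information from the other, so both may appear less important than they really are.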
Parameters
- scoring : string, default="accuracy" - Metric used to evaluate how the model's performance changes when a particular feature is shuffled.
- n_repeats : integer, default=20 - Number of times to permute a feature.
- random_state : integer, default=0 - Seed for the random number generator to control permutations of each feature.
- max_samples_fraction : number, default=1.0 - Fraction of samples to draw from the test set to calculate feature importance at each repetition.
Methods
explain(self, dataset)
Defined in PermutationFeatureImportance. Compute permutation feature importance for the fitted model.
Parameters
- dataset : tuple of (DatasetDict, DatasetDict)
- An (x, y) pair where each element is a DatasetDict with at least a "test" split.
Returns
- dict
- Dictionary with keys "features" (list of str), "importances_mean" (list of float, rounded to 3 decimal places), and "importances_std" (list of float, rounded to 3 decimal places).
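The three lists in the returned dictionary are index-aligned, so ranking features only takes a `zip` and a sort. The values below are made up for illustration and are not real output.

```python
# Made-up values with the same shape as the dict returned by explain()
explanation = {
    "features": ["age", "income", "tenure"],
    "importances_mean": [0.012, 0.154, 0.047],
    "importances_std": [0.004, 0.021, 0.009],
}

# Pair each feature with its mean and std, then rank by mean importance
ranked = sorted(
    zip(
        explanation["features"],
        explanation["importances_mean"],
        explanation["importances_std"],
    ),
    key=lambda item: item[1],
    reverse=True,
)
for name, mean, std in ranked:
    print(f"{name:>8}: {mean:.3f} +/- {std:.3f}")
```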
plot(self, explanation: dict) -> List[dict]
Defined in PermutationFeatureImportance. Create a Plotly bar chart from a feature importance explanation dict.
Parameters
- explanation : dict
- Output of explain(); must contain the "features", "importances_mean", and "importances_std" lists.
Returns
- list of str
- A single-element list containing the Plotly figure serialised to JSON (passed through _create_plot()).
get_schema(cls) -> dict
Defined in ConfigObject. Generates the JSON Schema for the component.
Returns
- dict
- Dictionary representing the JSON Schema of the component.
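For orientation, a JSON Schema describing the four documented parameters might look roughly like the dictionary below. The exact structure returned by get_schema() is not documented here, so the keys and the constraints (`minimum`, `exclusiveMinimum`, `maximum`) are assumptions. Only the types and defaults come from the parameter list. `PFI_SCHEMA` is a hypothetical name.

```python
import json

# Hypothetical schema: the real get_schema() output may use other keys or extra metadata
PFI_SCHEMA = {
    "type": "object",
    "properties": {
        "scoring": {"type": "string", "default": "accuracy"},
        "n_repeats": {"type": "integer", "minimum": 1, "default": 20},
        "random_state": {"type": "integer", "default": 0},
        "max_samples_fraction": {
            "type": "number",
            "exclusiveMinimum": 0,
            "maximum": 1,
            "default": 1.0,
        },
    },
}

# Collect the documented defaults straight from the schema
defaults = {name: spec["default"] for name, spec in PFI_SCHEMA["properties"].items()}
print(json.dumps(defaults))
```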
validate_and_transform(self, raw_data: dict) -> dict
Defined in ConfigObject. Takes the data provided by the user to initialize the model, validates it, and returns it together with all the objects the model needs to work.
Parameters
- raw_data : dict
- A dictionary with the data provided by the user to initialize the model.
Returns
- dict
- A validated dictionary with the necessary objects.
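This is a hypothetical sketch of the kind of work validate_and_transform might do. It fills in defaults, rejects unknown or out-of-range values, and turns the `scoring` string into an object the model can call. `validate_and_transform_sketch`, `DEFAULTS` and the added `scorer` key are illustrative names. Resolving the metric through scikit-learn's `get_scorer` is an assumed design choice, not confirmed behaviour.

```python
from sklearn.metrics import get_scorer

# Hypothetical defaults mirroring the documented parameter defaults
DEFAULTS = {"scoring": "accuracy", "n_repeats": 20, "random_state": 0, "max_samples_fraction": 1.0}


def validate_and_transform_sketch(raw_data: dict) -> dict:
    """Fill defaults, check types and ranges, and resolve the scoring name to a scorer."""
    unknown = set(raw_data) - set(DEFAULTS)
    if unknown:
        raise ValueError(f"unknown parameters: {sorted(unknown)}")
    params = {**DEFAULTS, **raw_data}
    if not isinstance(params["n_repeats"], int) or params["n_repeats"] < 1:
        raise ValueError("n_repeats must be a positive integer")
    if not 0 < params["max_samples_fraction"] <= 1:
        raise ValueError("max_samples_fraction must be in (0, 1]")
    # get_scorer raises ValueError for metric names scikit-learn does not know
    params["scorer"] = get_scorer(params["scoring"])
    return params


config = validate_and_transform_sketch({"n_repeats": 5})
```

Keeping the raw `scoring` string next to the resolved `scorer` means the validated dictionary still records what the user asked for.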