PermutationFeatureImportance
Global explainer that ranks features by the drop in model performance when each feature's values are permuted.
Permutation Feature Importance (PFI) measures the importance of a feature by
randomly shuffling its values across the test set and recording the resulting
decrease in a scoring metric. A large decrease indicates that the model relies
heavily on that feature; a small (or zero) decrease indicates the feature
contributes little. The process is repeated n_repeats times to produce a
mean importance and standard deviation, which quantify both rank order and
uncertainty.
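Unlike impurity-based importance (from decision trees), PFI is computed on held-out data and is therefore not biased towards high-cardinality features. It is model-agnostic and captures interaction effects, but it assumes that a feature can be permuted independently of the others: when features are strongly correlated, shuffling one of them produces unrealistic rows, and the model can still recover the signal from the correlated feature, so the reported importances can be misleading.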
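The procedure described above can be sketched in a few lines of standard-library Python. This is a minimal illustration, not the component's implementation: the helper names `accuracy` and `permutation_importance`, the toy model, and the synthetic data are all assumptions made for the example, and the population standard deviation is an arbitrary choice.

```python
import random
import statistics


def accuracy(model, X, y):
    """Fraction of rows where the model's prediction equals the label."""
    return sum(model(row) == label for row, label in zip(X, y)) / len(y)


def permutation_importance(model, X, y, n_repeats=20, seed=0):
    """Return (means, stds) of the score drop for each feature column."""
    rng = random.Random(seed)
    baseline = accuracy(model, X, y)
    means, stds = [], []
    for j in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            column = [row[j] for row in X]
            rng.shuffle(column)  # break the link between feature j and y
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            drops.append(baseline - accuracy(model, X_perm, y))
        means.append(statistics.mean(drops))
        stds.append(statistics.pstdev(drops))
    return means, stds


# Hypothetical toy model that only looks at feature 0 and ignores feature 1.
model = lambda row: int(row[0] > 0.5)
data_rng = random.Random(1)
X = [[data_rng.random(), data_rng.random()] for _ in range(200)]
y = [int(row[0] > 0.5) for row in X]

means, stds = permutation_importance(model, X, y)
```

Because the toy model never reads feature 1, shuffling that column leaves every prediction unchanged and its importance is exactly zero, while shuffling feature 0 costs a substantial share of the accuracy.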
References
- [1] Breiman, L. (2001). "Random Forests." Machine Learning, 45(1), 5-32.
- [2] Fisher, A. et al. (2019). "All Models are Wrong, but Many are Useful." JMLR, 20(177), 1-81. https://arxiv.org/abs/1801.01489
- [3] https://scikit-learn.org/stable/modules/permutation_importance.html
Parameters
- scoring : string, default=accuracy - Metric used to evaluate how the model's performance changes when a particular feature is shuffled.
- n_repeats : integer, default=20 - Number of times to permute a feature.
- random_state : integer, default=0 - Seed for the random number generator to control permutations of each feature.
- max_samples_fraction : number, default=1.0 - Fraction of samples to draw from the test set to calculate feature importance at each repetition.
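To make the roles of random_state and max_samples_fraction concrete, the sketch below shows one way a seeded subsample of test rows could be drawn for a single repetition. The helper `draw_subsample`, its rounding rule, and sampling without replacement are assumptions for illustration; the component's actual sampling logic may differ.

```python
import random


def draw_subsample(n_rows, max_samples_fraction, rng):
    """Pick row indices, without replacement, for one repetition (hypothetical helper)."""
    if not 0.0 < max_samples_fraction <= 1.0:
        raise ValueError("max_samples_fraction must be in (0, 1]")
    k = max(1, int(n_rows * max_samples_fraction))
    return sorted(rng.sample(range(n_rows), k))


# Two generators built from the same seed yield the same subsample,
# which is what makes a fixed random_state reproducible.
idx_a = draw_subsample(1000, 0.25, random.Random(0))
idx_b = draw_subsample(1000, 0.25, random.Random(0))
```

A fraction below 1.0 trades some precision in the importance estimates for less scoring work per repetition, which can matter on large test sets.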
Methods
explain(self, dataset)
PermutationFeatureImportance: Compute permutation feature importance for the fitted model.
Parameters
- dataset : tuple of (DatasetDict, DatasetDict)
- An (x, y) pair where each element is a DatasetDict with at least a "test" split.
Returns
- dict
- Dictionary with keys "features" (list of str), "importances_mean" (list of float, rounded to 3 decimal places), and "importances_std" (list of float, rounded to 3 decimal places).
plot(self, explanation: dict) -> List[dict]
PermutationFeatureImportance: Create a Plotly bar chart from a feature importance explanation dict.
Parameters
- explanation : dict
- Output of explain; must contain "features", "importances_mean", and "importances_std" lists.
Returns
- list of str
- A single-element list containing the Plotly figure serialised to JSON (passed through _create_plot).
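As a rough picture of the data flow, the sketch below turns an explanation dict into a single-element list holding a JSON string that follows Plotly's figure layout (a "data" list of traces plus a "layout" object). The function `importance_bar_json`, the axis titles, and the example values are hypothetical; what _create_plot actually does is not shown in this page, so it is not reproduced here.

```python
import json


def importance_bar_json(explanation):
    """Build a bar-chart figure as a plain dict and serialise it to JSON."""
    figure = {
        "data": [
            {
                "type": "bar",
                "x": explanation["features"],
                "y": explanation["importances_mean"],
                # Standard deviations drawn as error bars on each bar.
                "error_y": {"type": "data", "array": explanation["importances_std"]},
            }
        ],
        "layout": {
            "title": {"text": "Permutation feature importance"},
            "xaxis": {"title": {"text": "Feature"}},
            "yaxis": {"title": {"text": "Mean decrease in score"}},
        },
    }
    return [json.dumps(figure)]


# Illustrative values only; real ones come from explain.
example = {
    "features": ["age", "income"],
    "importances_mean": [0.012, 0.184],
    "importances_std": [0.009, 0.021],
}
serialised = importance_bar_json(example)
```

Building the figure as a plain dict keeps the sketch free of third-party dependencies; with Plotly installed, the same JSON string can typically be loaded back into a figure object for display.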
get_schema(cls) -> dict
ConfigObject: Generates the component's JSON Schema.
Returns
- dict
- Dictionary representing the JSON Schema of the component.
validate_and_transform(self, raw_data: dict) -> dict
ConfigObject: Takes the data provided by the user to initialize the model and returns it together with all the objects the model needs to work.
Parameters
- raw_data : dict
- A dictionary with the data provided by the user to initialize the model.
Returns
- dict
- A validated dictionary with the necessary objects.