Source code for DashAI.back.converters.scikit_learn.select_k_best

from sklearn.feature_selection import SelectKBest as SelectKBestOperation

from DashAI.back.converters.category.feature_selection import FeatureSelectionConverter
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    enum_field,
    int_field,
    schema_field,
    union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema


# Parameter schema referenced by the SelectKBest converter (via its SCHEMA
# attribute); it declares a single configurable field, ``k``.
# NOTE(review): documented with ``#`` comments rather than a class docstring
# because BaseSchema may expose ``__doc__`` in generated schemas — confirm
# before converting this into a docstring.
class SelectKBestSchema(BaseSchema):
    # ``k`` accepts either the enum value "all" or an int_field(ge=1), i.e.
    # presumably any integer >= 1.
    # The positional arguments after the type look like the default value (10)
    # and the user-facing description — confirm against schema_field's signature.
    k: schema_field(
        union_type(enum_field(["all"]), int_field(ge=1)),
        10,
        "Number of top features to select.",
    )  # type: ignore


class SelectKBest(FeatureSelectionConverter, SklearnWrapper, SelectKBestOperation):
    """SciKit-Learn's SelectKBest wrapper for DashAI."""

    # Schema class describing the user-configurable parameters (``k``).
    SCHEMA = SelectKBestSchema
    # Static descriptive attributes; presumably consumed by the DashAI
    # registry/UI — confirm against the converter base classes.
    DESCRIPTION = "Select features according to the k highest scores."
    SUPERVISED = True
    DISPLAY_NAME = "Select K Best"
    IMAGE_PREVIEW = "select_k_best.png"
    # NOTE(review): class-level mutable dict, shared by every instance unless
    # rebound on the instance — confirm this sharing is intended.
    metadata = {}

    def __init__(self, **kwargs):
        """Force the ``requires_y`` tag on this instance, then run base init.

        ``_get_tags`` is overridden on the instance *before* delegating to
        ``super().__init__`` so that the tags reported for this object always
        contain ``"requires_y": True``.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``super().__init__``.
        """
        if callable(self._get_tags):
            # Wrap the existing tag getter: keep whatever it returns and
            # overlay ``requires_y=True`` on top of it.
            original_get_tags = self._get_tags
            self._get_tags = lambda *a, **k: {
                **original_get_tags(*a, **k),
                "requires_y": True,
            }
        else:
            # ``_get_tags`` is a plain mapping here: store an updated copy on
            # the instance instead of mutating the shared object.
            self._get_tags = {**self._get_tags, "requires_y": True}
        super().__init__(**kwargs)