Source code for DashAI.back.converters.scikit_learn.one_hot_encoder

from sklearn.preprocessing import OneHotEncoder as OneHotEncoderOperation

from DashAI.back.api.utils import cast_string_to_type, parse_string_to_list
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    enum_field,
    float_field,
    int_field,
    none_type,
    schema_field,
    string_field,
    union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema


class OneHotEncoderSchema(BaseSchema):
    # Parameter schema for the OneHotEncoder converter defined in this module.
    # Every field is declared through schema_field(...); judging by the calls
    # below, its positional arguments are (field type, default, description).
    # NOTE(review): the field names appear to mirror the constructor
    # parameters of scikit-learn's OneHotEncoder -- confirm against the
    # scikit-learn version the project pins.

    # Declared as a plain string: anything other than "auto" is parsed with
    # parse_string_to_list in OneHotEncoder.__init__.
    categories: schema_field(
        string_field(),  # 'auto' or a list of array-like
        "auto",
        "The categories of each feature.",
    )  # type: ignore
    # Optional string: None, "first" and "if_binary" are passed through
    # unchanged by OneHotEncoder.__init__; any other value is parsed with
    # parse_string_to_list.
    drop: schema_field(
        none_type(
            string_field()
        ),  # {'first', 'if_binary'} or an array-like of shape (n_features,)
        None,
        "Specifies a methodology to use to drop one of the categories per feature.",
    )  # type: ignore
    # sparse_output is intentionally not exposed here: sparse output is not
    # supported in pandas.
    # The dtype is selected by name; OneHotEncoder.__init__ converts the string
    # with cast_string_to_type.
    dtype: schema_field(
        enum_field(["int", "np.float32", "np.float64"]),  # number type
        "np.float64",
        "Desired dtype of output.",
    )  # type: ignore
    handle_unknown: schema_field(
        enum_field(["error", "ignore", "infrequent_if_exist"]),
        "error",
        (
            "Whether to raise an error or ignore if an unknown categorical feature "
            "is present during transform."
        ),
    )  # type: ignore
    # Optional: either a non-negative integer or a float within [0.0, 1.0].
    min_frequency: schema_field(
        none_type(union_type(int_field(ge=0), float_field(ge=0.0, le=1.0))),
        None,
        "Minimum frequency of a category to be considered as frequent.",
    )  # type: ignore
    # Optional: an integer greater than or equal to 1.
    max_categories: schema_field(
        none_type(int_field(ge=1)),
        None,
        "Maximum number of categories to encode.",
    )  # type: ignore
    # Added in version 1.3 (presumably of scikit-learn -- confirm).
    # Only the "concat" option is offered; the callable alternative mentioned
    # below is not selectable through this schema.
    feature_name_combiner: schema_field(
        enum_field(
            [
                "concat",
            ]
        ),  # "concat" or callable
        "concat",
        "Method used to combine feature names.",
    )  # type: ignore


class OneHotEncoder(SklearnWrapper, OneHotEncoderOperation):
    """DashAI converter wrapping scikit-learn's ``OneHotEncoder``.

    String-valued parameters are converted in ``__init__`` before they are
    forwarded to the parent constructors.
    """

    SCHEMA = OneHotEncoderSchema
    DESCRIPTION = "Encode categorical integer features as a one-hot numeric array."

    def __init__(self, **kwargs):
        """Normalise the keyword arguments and initialise the parent classes.

        Parameters
        ----------
        **kwargs
            Encoder parameters. ``categories``, ``drop``, ``dtype`` and
            ``sparse_output`` are popped, converted where needed, stored on
            the instance and written back into ``kwargs``; every other entry
            is forwarded untouched.
        """
        # "auto" is passed through as-is; any other value is parsed into a
        # list and wrapped in an outer list.
        categories = kwargs.pop("categories", "auto")
        if categories != "auto":
            categories = [parse_string_to_list(categories)]
        self.categories = categories
        kwargs["categories"] = categories

        # None and the two keyword strategies are passed through as-is; any
        # other value is parsed into a list and wrapped in an outer list.
        # NOTE(review): scikit-learn documents an array-like ``drop`` as having
        # shape (n_features,); confirm that the extra nesting is intended.
        drop = kwargs.pop("drop", None)
        if drop is not None and drop not in ("first", "if_binary"):
            drop = [parse_string_to_list(drop)]
        self.drop = drop
        kwargs["drop"] = drop

        # The dtype arrives as a string name and is converted by the helper.
        dtype = cast_string_to_type(kwargs.pop("dtype", "np.float64"))
        self.dtype = dtype
        kwargs["dtype"] = dtype

        # Pandas output does not support sparse data, so dense output is the
        # default whenever the caller does not specify sparse_output.
        sparse_output = kwargs.pop("sparse_output", False)
        self.sparse_output = sparse_output
        kwargs["sparse_output"] = sparse_output

        super().__init__(**kwargs)