# Source code for DashAI.back.converters.scikit_learn.one_hot_encoder

import pyarrow as pa
from sklearn.preprocessing import OneHotEncoder as OneHotEncoderOperation

from DashAI.back.api.utils import cast_string_to_type, parse_string_to_list
from DashAI.back.converters.category.encoding import EncodingConverter
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    enum_field,
    float_field,
    int_field,
    none_type,
    schema_field,
    string_field,
    union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.core.utils import MultilingualString
from DashAI.back.types.dashai_data_type import DashAIDataType
from DashAI.back.types.value_types import Integer


class OneHotEncoderSchema(BaseSchema):
    # Parameter schema for the OneHotEncoder converter defined in this module
    # (it is referenced there as ``SCHEMA``).
    #
    # Each field is declared as an annotation of the form
    # ``schema_field(<field type>, <default>, description=...)`` with a
    # bilingual (en/es) description. The defaults for ``categories``, ``drop``
    # and ``dtype`` match the fallbacks used by ``OneHotEncoder.__init__``.
    # ``sparse_output`` is not exposed here; ``OneHotEncoder.__init__``
    # defaults it to False.
    #
    # The trailing ``# type: ignore`` markers presumably silence type checkers
    # that reject a call expression used as an annotation.

    # Free-form string, default "auto". Any value other than "auto" is parsed
    # with ``parse_string_to_list`` and wrapped in a one-element list by
    # ``OneHotEncoder.__init__``.
    categories: schema_field(
        string_field(),
        "auto",
        description=MultilingualString(
            en="The categories of each feature.",
            es="Las categorías de cada característica.",
        ),
    )  # type: ignore
    # Optional string, default None. ``OneHotEncoder.__init__`` forwards None,
    # "first" and "if_binary" unchanged; any other value is parsed with
    # ``parse_string_to_list`` and wrapped in a one-element list.
    drop: schema_field(
        none_type(string_field()),
        None,
        description=MultilingualString(
            en=("Specifies a methodology to drop one of the categories per feature."),
            es=(
                "Especifica una metodología para eliminar una categoría por "
                "característica."
            ),
        ),
    )  # type: ignore
    # Output dtype given by name; ``OneHotEncoder.__init__`` converts the
    # string with ``cast_string_to_type`` before forwarding it.
    dtype: schema_field(
        enum_field(["int", "np.float32", "np.float64"]),
        "np.float64",
        description=MultilingualString(
            en="Desired dtype of output.",
            es="Tipo de dato de salida deseado.",
        ),
    )  # type: ignore
    # One of "error", "ignore" or "infrequent_if_exist"; default "error".
    handle_unknown: schema_field(
        enum_field(["error", "ignore", "infrequent_if_exist"]),
        "error",
        description=MultilingualString(
            en=("How to handle unknown categories during transform."),
            es=("Cómo manejar categorías desconocidas durante la transformación."),
        ),
    )  # type: ignore
    # None, an integer >= 0, or a float in [0.0, 1.0]; default None.
    # NOTE(review): scikit-learn appears to require an int >= 1 or a float
    # strictly between 0 and 1 -- confirm these bounds against the installed
    # scikit-learn version.
    min_frequency: schema_field(
        none_type(union_type(int_field(ge=0), float_field(ge=0.0, le=1.0))),
        None,
        description=MultilingualString(
            en="Minimum frequency of a category to be considered as frequent.",
            es="Frecuencia mínima para considerar una categoría como frecuente.",
        ),
    )  # type: ignore
    # None or an integer >= 1; default None.
    max_categories: schema_field(
        none_type(int_field(ge=1)),
        None,
        description=MultilingualString(
            en="Maximum number of categories to encode.",
            es="Número máximo de categorías a codificar.",
        ),
    )  # type: ignore
    # Only "concat" is offered, and it is also the default.
    feature_name_combiner: schema_field(
        enum_field(["concat"]),
        "concat",
        description=MultilingualString(
            en="Method used to combine feature names.",
            es="Método usado para combinar nombres de características.",
        ),
    )  # type: ignore


class OneHotEncoder(EncodingConverter, SklearnWrapper, OneHotEncoderOperation):
    """Scikit-learn's OneHotEncoder wrapper for DashAI.

    The constructor accepts the string-based parameters declared in
    ``OneHotEncoderSchema`` and converts them into the objects that are
    forwarded to the parent constructors.
    """

    SCHEMA = OneHotEncoderSchema
    DESCRIPTION = MultilingualString(
        en="Encode categorical integer features as a one-hot numeric array.",
        es=(
            "Codifica características categóricas enteras como un arreglo "
            "numérico one-hot."
        ),
    )
    DISPLAY_NAME = MultilingualString(en="One-Hot Encoder", es="Codificador One-Hot")
    IMAGE_PREVIEW = "one_hot_encoder.png"

    def __init__(self, **kwargs):
        """Normalize the encoder parameters and initialize the parent classes.

        Parameters
        ----------
        **kwargs
            Encoder parameters. ``categories``, ``drop``, ``dtype`` and
            ``sparse_output`` are popped, normalized, stored on the instance
            and written back into ``kwargs``; every remaining key is forwarded
            untouched to ``super().__init__``.
        """
        categories = kwargs.pop("categories", "auto")
        if categories != "auto":
            # Anything other than "auto" is parsed into a list and wrapped in
            # an outer list before being forwarded.
            categories = [parse_string_to_list(categories)]
        self.categories = categories
        kwargs["categories"] = categories

        drop = kwargs.pop("drop", None)
        # None, "first" and "if_binary" are forwarded unchanged; any other
        # value is parsed into a list and wrapped, like ``categories``.
        if drop is not None and drop != "first" and drop != "if_binary":
            drop = [parse_string_to_list(drop)]
        self.drop = drop
        kwargs["drop"] = drop

        # The dtype arrives by name (e.g. "np.float64") and is converted with
        # ``cast_string_to_type`` before being forwarded.
        dtype = cast_string_to_type(kwargs.pop("dtype", "np.float64"))
        self.dtype = dtype
        kwargs["dtype"] = dtype

        # Pandas output does not support sparse data. Set sparse_output=False
        sparse_output = kwargs.pop("sparse_output", False)
        self.sparse_output = sparse_output
        kwargs["sparse_output"] = sparse_output

        super().__init__(**kwargs)

    def get_output_type(self, column_name: str = None) -> DashAIDataType:
        """Return the DashAI type declared for the encoded output columns.

        Parameters
        ----------
        column_name : str, optional
            Accepted but not used; the same type is returned for every column.

        Returns
        -------
        DashAIDataType
            An ``Integer`` type backed by the Arrow ``int64`` type.
        """
        # NOTE(review): the declared type is always int64, while the default
        # ``dtype`` parameter is "np.float64" -- confirm this is intended.
        return Integer(arrow_type=pa.int64())