Source code for DashAI.back.converters.scikit_learn.knn_imputer

from typing import Optional

import pyarrow as pa
from sklearn.impute import KNNImputer as KNNImputerOperation

from DashAI.back.converters.category.basic_preprocessing import (
    BasicPreprocessingConverter,
)
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    bool_field,
    enum_field,
    int_field,
    schema_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.core.utils import MultilingualString
from DashAI.back.types.dashai_data_type import DashAIDataType
from DashAI.back.types.value_types import Float


# Hyperparameter schema for the KNNImputer converter.
# Each field is declared as schema_field(<validator>, <default>, description=...),
# with English/Spanish descriptions supplied through MultilingualString.
# NOTE(review): documented with comments rather than a class docstring because
# BaseSchema may surface ``__doc__`` in its generated schema — confirm.
class KNNImputerSchema(BaseSchema):
    # Number of neighbors used for imputation; validated as an int >= 1,
    # default 5.
    n_neighbors: schema_field(
        int_field(ge=1),
        5,
        description=MultilingualString(
            en="The number of nearest neighbors to use for imputation.",
            es="Número de vecinos más cercanos a usar para la imputación.",
        ),
    )  # type: ignore
    # Neighbor weighting: one of "uniform" or "distance"; default "uniform".
    weights: schema_field(
        enum_field(["uniform", "distance"]),
        "uniform",
        description=MultilingualString(
            en="The weight function to use for imputation.",
            es="La función de peso a usar para la imputación.",
        ),
    )  # type: ignore
    # Distance metric; "nan_euclidean" is the only option exposed here.
    metric: schema_field(
        enum_field(["nan_euclidean"]),
        "nan_euclidean",
        description=MultilingualString(
            en="The metric to use for imputation.",
            es="La métrica a usar para la imputación.",
        ),
    )  # type: ignore
    # Boolean, default True. Declared under the attribute name ``use_copy``
    # with alias "copy" / "copiar".
    # NOTE(review): presumably named ``use_copy`` to avoid shadowing a
    # base-class ``copy`` attribute and mapped back to scikit-learn's ``copy``
    # parameter elsewhere — verify against SklearnWrapper / BaseSchema.
    use_copy: schema_field(
        bool_field(),
        True,
        description=MultilingualString(
            en="If True, a copy of X will be created.",
            es="Si es True, se creará una copia de X.",
        ),
        alias=MultilingualString(en="copy", es="copiar"),
    )  # type: ignore
    # Boolean, default False: whether a missing-value indicator is stacked
    # onto the output.
    add_indicator: schema_field(
        bool_field(),
        False,
        description=MultilingualString(
            en="If True, a MissingIndicator transform will stack onto output.",
            es="Si es True, se apilará un MissingIndicator sobre la salida.",
        ),
    )  # type: ignore
    # Boolean, default False: whether features that are entirely empty are kept.
    keep_empty_features: schema_field(
        bool_field(),
        False,
        description=MultilingualString(
            en="If True, empty features will be kept.",
            es="Si es True, se mantendrán las características vacías.",
        ),
    )  # type: ignore


class KNNImputer(BasicPreprocessingConverter, SklearnWrapper, KNNImputerOperation):
    """Scikit-learn's KNNImputer wrapper for DashAI.

    Completes missing values using k-Nearest Neighbors. Configurable
    hyperparameters are declared in ``KNNImputerSchema``; the imputation logic
    itself is inherited from the scikit-learn base class
    (``KNNImputerOperation``).
    """

    SCHEMA = KNNImputerSchema
    DESCRIPTION = MultilingualString(
        en="Imputation for completing missing values using k-Nearest Neighbors.",
        es=(
            "Imputación para completar valores faltantes utilizando "
            "k-Vecinos Más Cercanos."
        ),
    )
    DISPLAY_NAME = MultilingualString(en="KNN Imputer", es="Imputador KNN")
    IMAGE_PREVIEW = "knn_imputer.png"

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the next ``__init__`` in the MRO.

        NOTE(review): this explicit ``**kwargs`` constructor is kept on
        purpose. scikit-learn's ``get_params`` introspects the class's
        ``__init__`` signature, so removing it may change which parameters
        are reported — confirm before simplifying.
        """
        super().__init__(**kwargs)

    def get_output_type(self, column_name: Optional[str] = None) -> DashAIDataType:
        """Return the DashAI data type of imputed output columns.

        Parameters
        ----------
        column_name : str, optional
            Name of the output column. Unused: every column produced by this
            converter is reported with the same type.

        Returns
        -------
        DashAIDataType
            A ``Float`` value type backed by ``pyarrow.float64()``.
        """
        # One type for all columns, hence ``column_name`` is ignored.
        return Float(arrow_type=pa.float64())