Source code for DashAI.back.converters.scikit_learn.knn_imputer

from sklearn.impute import KNNImputer as KNNImputerOperation

from DashAI.back.api.utils import cast_string_to_type
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    bool_field,
    enum_field,
    float_field,
    int_field,
    none_type,
    schema_field,
    string_field,
    union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema


class KNNImputerSchema(BaseSchema):
    # Parameter schema for the KNNImputer converter: each annotation below
    # declares one configurable parameter as
    # schema_field(<type/validation>, <default>, <description>).
    # The field names mirror the keyword arguments of scikit-learn's
    # KNNImputer, except `use_copy`, which is exposed under the alias "copy".

    # Value that marks an entry as missing. Accepts an int, a float, a string
    # or None.
    missing_values: schema_field(
        none_type(
            union_type(int_field(), union_type(float_field(), string_field()))
        ),  # int, float, str, np.nan or None
        # NOTE(review): scikit-learn documents np.nan as its default; None is
        # used here instead -- confirm how None is interpreted downstream.
        None,  # np.nan,
        "The placeholder for the missing values.",
    )  # type: ignore
    # Number of neighbours used to impute a value; validated as >= 1.
    n_neighbors: schema_field(
        int_field(ge=1),
        5,
        "The number of nearest neighbors to use for imputation.",
    )  # type: ignore
    # Only the two named weighting schemes are offered; the callable form
    # mentioned in the trailing comment is not selectable through this schema.
    weights: schema_field(
        enum_field(["uniform", "distance"]),  # {'uniform', 'distance'} or callable
        "uniform",
        "The weight function to use for imputation.",
    )  # type: ignore
    # Single-choice enum: "nan_euclidean" is the only metric offered; the
    # callable form is not selectable through this schema.
    metric: schema_field(
        enum_field(["nan_euclidean"]),  # {'nan_euclidean'} or callable
        "nan_euclidean",
        "The metric to use for imputation.",
    )  # type: ignore
    # Declared as `use_copy` but aliased to "copy".
    # NOTE(review): presumably renamed to avoid clashing with a `copy`
    # attribute on the base schema class -- confirm.
    use_copy: schema_field(
        bool_field(),
        True,
        "If True, a copy of X will be created.",
        alias="copy",
    )  # type: ignore
    # Whether to append missing-value indicator columns to the output.
    add_indicator: schema_field(
        bool_field(),
        False,
        "If True, a MissingIndicator transform will stack onto output.",
    )  # type: ignore
    # Whether features that are entirely missing are kept in the output.
    keep_empty_features: schema_field(
        bool_field(),
        False,
        "If True, empty features will be kept.",
    )  # type: ignore


class KNNImputer(SklearnWrapper, KNNImputerOperation):
    """Scikit-learn's KNNImputer wrapper for DashAI.

    Combines the DashAI ``SklearnWrapper`` base with scikit-learn's
    ``KNNImputer`` and publishes ``KNNImputerSchema`` as its parameter schema.
    """

    SCHEMA = KNNImputerSchema
    DESCRIPTION = (
        "Imputation for completing missing values using k-Nearest Neighbors."
    )

    def __init__(self, **kwargs):
        """Create the imputer from keyword parameters.

        The ``missing_values`` entry (``None`` when absent) is passed through
        ``cast_string_to_type`` before use; the converted value is stored on
        the instance and forwarded, with all remaining keyword arguments, to
        the parent constructors.

        Parameters
        ----------
        **kwargs
            Imputer parameters. ``missing_values`` is handled as described
            above; everything else is forwarded unchanged.
        """
        # NOTE(review): cast_string_to_type presumably turns a textual value
        # into the Python object it denotes -- confirm against its definition.
        raw_missing_values = kwargs.pop("missing_values", None)
        self.missing_values = cast_string_to_type(raw_missing_values)

        # Re-insert the converted value so the parent initialisers receive it.
        kwargs["missing_values"] = self.missing_values
        super().__init__(**kwargs)