# Source code for DashAI.back.models.scikit_learn.ridge_regression

from sklearn.linear_model import Ridge as _Ridge

from DashAI.back.core.schema_fields import (
    BaseSchema,
    bool_field,
    enum_field,
    none_type,
    optimizer_float_field,
    optimizer_int_field,
    schema_field,
    union_type,
)
from DashAI.back.core.utils import MultilingualString
from DashAI.back.models.regression_model import RegressionModel
from DashAI.back.models.scikit_learn.sklearn_like_model import (
    CategoricalEncodingStrategy,
)
from DashAI.back.models.scikit_learn.sklearn_like_regressor import SklearnLikeRegressor


class RidgeRegressionSchema(BaseSchema):
    """Ridge regression is a linear model that includes L2 regularization."""

    # Each attribute below declares one configurable hyperparameter. The
    # attribute names mirror the keyword arguments of sklearn's ``Ridge``.
    # For the optimizer_* fields the placeholder is a dict with the keys
    # "optimize", "fixed_value", "lower_bound" and "upper_bound"; for the
    # remaining fields the placeholder is a plain value.

    # Regularization strength. Declared as a float field: the description
    # states the value is a float and sklearn's Ridge accepts any float
    # alpha >= 0, so an integer-only field would reject fractional strengths
    # such as 0.5. All previously valid values (integers >= 1) still validate.
    alpha: schema_field(
        optimizer_float_field(ge=0.0),
        placeholder={
            "optimize": False,
            "fixed_value": 1.0,
            "lower_bound": 1.0,
            "upper_bound": 10.0,
        },
        description=MultilingualString(
            en=(
                "Regularization strength; must be a positive float. "
                "Larger values specify stronger regularization."
            ),
            es=(
                "Fuerza de regularización; debe ser un float positivo. "
                "Valores más grandes especifican una regularización más fuerte."
            ),
        ),
        alias=MultilingualString(en="Alpha", es="Alfa"),
    )  # type: ignore

    # Whether an intercept term is fitted.
    fit_intercept: schema_field(
        bool_field(),
        placeholder=True,
        description=MultilingualString(
            en=(
                "Whether to calculate the intercept for this model. "
                "If set to False, no intercept will be used in calculations "
                "(e.g., data is expected to be centered)."
            ),
            es=(
                "Si se debe calcular el intercepto para este modelo. "
                "Si se establece en False, no se usará intercepto en los cálculos "
                "(ej., se espera que los datos estén centrados)."
            ),
        ),
        alias=MultilingualString(en="Fit intercept", es="Ajustar intercepto"),
    )  # type: ignore

    # Mixed-case name is kept on purpose (hence the noqa) so it matches the
    # sklearn keyword argument ``copy_X``.
    copy_X: schema_field(  # noqa: N815
        bool_field(),
        placeholder=True,
        description=MultilingualString(
            en="If True, X will be copied; else, it may be overwritten.",
            es="Si es True, X será copiado; si no, puede ser sobrescrito.",
        ),
        alias=MultilingualString(en="Copy X", es="Copiar X"),
    )  # type: ignore

    # Iteration cap for the iterative solvers.
    max_iter: schema_field(
        optimizer_int_field(ge=10),
        placeholder={
            "optimize": False,
            "fixed_value": 100,
            "lower_bound": 10,
            "upper_bound": 10000,
        },
        description=MultilingualString(
            en="Maximum number of iterations for conjugate gradient solver.",
            es=(
                "Número máximo de iteraciones para el "
                "solucionador de gradiente conjugado."
            ),
        ),
        alias=MultilingualString(en="Max iterations", es="Máximas iteraciones"),
    )  # type: ignore

    # Convergence tolerance of the solver.
    tol: schema_field(
        optimizer_float_field(ge=0.0),
        placeholder={
            "optimize": False,
            "fixed_value": 0.001,
            "lower_bound": 1e-5,
            "upper_bound": 1e-1,
        },
        description=MultilingualString(
            en="Precision of the solution.",
            es="Precisión de la solución.",
        ),
        alias=MultilingualString(en="Tolerance", es="Tolerancia"),
    )  # type: ignore

    # Numerical routine used to fit the model; restricted to the listed names.
    solver: schema_field(
        enum_field(
            enum=["auto", "svd", "cholesky", "lsqr", "sparse_cg", "sag", "saga"]
        ),
        placeholder="auto",
        description=MultilingualString(
            en=(
                "Solver to use in the computation. 'auto' chooses the "
                "solver automatically based on the type of data."
            ),
            es=(
                "Solucionador a usar en el cálculo. 'auto' elige el "
                "solucionador automáticamente basado en el tipo de datos."
            ),
        ),
        alias=MultilingualString(en="Solver", es="Solucionador"),
    )  # type: ignore

    # Constrain the fitted coefficients to be positive.
    positive: schema_field(
        bool_field(),
        placeholder=False,
        description=MultilingualString(
            en="When set to True, forces the coefficients to be positive.",
            es="Cuando se establece en True, fuerza los coeficientes a ser positivos.",
        ),
        alias=MultilingualString(en="Positive", es="Positivo"),
    )  # type: ignore

    # Optional seed: either a non-negative integer or None (no fixed seed).
    random_state: schema_field(
        union_type(optimizer_int_field(ge=0), none_type(int)),
        placeholder=None,
        description=MultilingualString(
            en=(
                "The seed of the pseudo random number generator to use "
                "when shuffling the data. Pass an int for reproducible output across "
                "multiple function calls, or None to not set a specific seed."
            ),
            es=(
                "La semilla del generador de números pseudoaleatorios a usar "
                "al mezclar los datos. Pase un int para salida reproducible entre "
                "múltiples llamadas, o None para no establecer una semilla específica."
            ),
        ),
        alias=MultilingualString(en="Random state", es="Estado aleatorio"),
    )  # type: ignore


class RidgeRegression(RegressionModel, SklearnLikeRegressor, _Ridge):
    """Scikit-learn's Ridge regression wrapper for DashAI."""

    # Schema class describing the configurable hyperparameters of this model.
    SCHEMA = RidgeRegressionSchema

    # Display metadata (localized name/description, color and icon name).
    DISPLAY_NAME: str = MultilingualString(
        en="Ridge Regression",
        es="Regresión Ridge",
    )
    DESCRIPTION: str = MultilingualString(
        en="Linear regression with L2 regularization.",
        es="Regresión lineal con regularización L2.",
    )
    COLOR: str = "#2196F3"
    ICON: str = "ShowChart"

    # Categorical encoding strategy declared for this model: one-hot.
    CATEGORICAL_ENCODING = CategoricalEncodingStrategy.ONE_HOT

    def __init__(self, **kwargs) -> None:
        """Initialize the model, forwarding every keyword argument to the parents.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed unchanged to ``super().__init__``.
        """
        super().__init__(**kwargs)