# Source code for DashAI.back.models.scikit_learn.logistic_regression

from sklearn.linear_model import LogisticRegression as _LogisticRegression

from DashAI.back.core.schema_fields import (
    BaseSchema,
    enum_field,
    optimizer_float_field,
    optimizer_int_field,
    schema_field,
)
from DashAI.back.core.utils import MultilingualString
from DashAI.back.models.scikit_learn.sklearn_like_classifier import (
    SklearnLikeClassifier,
)
from DashAI.back.models.scikit_learn.sklearn_like_model import (
    CategoricalEncodingStrategy,
)
from DashAI.back.models.tabular_classification_model import TabularClassificationModel


class LogisticRegressionSchema(BaseSchema):
    """Logistic Regression is a supervised classification method that uses a linear
    model plus a logistic function to predict binary outcomes (it can be configured
    as multiclass via the one-vs-rest strategy).
    """

    # Norm of the regularization term, chosen from a fixed set of options.
    # NOTE(review): scikit-learn's default solver ("lbfgs") accepts only an "l2"
    # (or no) penalty; "l1" and "elasticnet" require a solver such as "saga",
    # which this schema does not expose — confirm the model selects one.
    penalty: schema_field(
        enum_field(enum=["l2", "l1", "elasticnet"]),
        placeholder="l2",
        description=MultilingualString(
            en="Specify the norm of the penalty",
            es="Especifica la norma de la penalización",
        ),
        alias=MultilingualString(en="Penalty", es="Penalización"),
    )  # type: ignore
    # Stopping tolerance; may be fixed or searched by the optimizer within
    # [lower_bound, upper_bound].
    # NOTE(review): scikit-learn's own default is tol=1e-4, whereas this
    # placeholder fixes it at 0.0 — confirm this is intended.
    tol: schema_field(
        optimizer_float_field(ge=0.0),
        placeholder={
            "optimize": False,
            "fixed_value": 0.0,
            "lower_bound": 0.0,
            "upper_bound": 5.0,
        },
        description=MultilingualString(
            en="Tolerance for stopping criteria.",
            es="Tolerancia para el criterio de detención.",
        ),
        alias=MultilingualString(en="Tolerance", es="Tolerancia"),
    )  # type: ignore
    # Inverse regularization strength; gt=0.0 enforces the "must be a positive
    # number" requirement stated in the description below.
    C: schema_field(
        optimizer_float_field(gt=0.0),
        placeholder={
            "optimize": False,
            "fixed_value": 1.0,
            "lower_bound": 1.0,
            "upper_bound": 7.0,
        },
        description=MultilingualString(
            en=(
                "Inverse of regularization strength, smaller values specify stronger "
                "regularization. Must be a positive number."
            ),
            es=(
                "Inverso de la fuerza de regularización, valores más pequeños "
                "especifican una regularización más fuerte. Debe ser un número "
                "positivo."
            ),
        ),
        alias=MultilingualString(en="C", es="C"),
    )  # type: ignore
    # Iteration budget for the solver; ge=50 enforces a minimum of 50.
    max_iter: schema_field(
        optimizer_int_field(ge=50),
        placeholder={
            "optimize": False,
            "fixed_value": 100,
            "lower_bound": 50,
            "upper_bound": 250,
        },
        description=MultilingualString(
            en=("Maximum number of iterations taken for the solvers to converge."),
            es=("Número máximo de iteraciones para que los solucionadores converjan."),
        ),
        alias=MultilingualString(en="Max iterations", es="Máximas iteraciones"),
    )  # type: ignore


class LogisticRegression(
    TabularClassificationModel, SklearnLikeClassifier, _LogisticRegression
):
    """Scikit-learn's Logistic Regression wrapper for DashAI.

    Hyperparameters are described by ``LogisticRegressionSchema``; the keyword
    arguments received by ``__init__`` are passed on to ``super().__init__``.
    """

    SCHEMA = LogisticRegressionSchema
    DISPLAY_NAME: str = MultilingualString(
        en="Logistic Regression",
        es="Regresión Logística",
    )
    DESCRIPTION: str = MultilingualString(
        en="Linear model for classification using logistic function.",
        es="Modelo lineal para clasificación usando la función logística.",
    )
    COLOR: str = "#64B5F6"
    ICON: str = "TrendingUp"
    CATEGORICAL_ENCODING = CategoricalEncodingStrategy.ONE_HOT

    def __init__(self, **kwargs) -> None:
        """Build the model, making the chosen penalty compatible with the solver.

        The schema lets users pick ``penalty`` in {"l2", "l1", "elasticnet"} but
        exposes no ``solver``. scikit-learn's default solver ("lbfgs") only
        supports an "l2" (or no) penalty, so "l1"/"elasticnet" would fail at fit
        time. For those penalties "saga" (which supports both) is selected unless
        the caller passed a solver explicitly. "elasticnet" also requires
        ``l1_ratio``; a balanced 0.5 mix is used when none is given.

        Parameters
        ----------
        **kwargs
            Hyperparameters (e.g. ``penalty``, ``tol``, ``C``, ``max_iter``)
            forwarded to ``super().__init__``.
        """
        penalty = kwargs.get("penalty")
        if penalty in ("l1", "elasticnet"):
            # Respect an explicit solver; otherwise pick one that accepts the penalty.
            kwargs.setdefault("solver", "saga")
        if penalty == "elasticnet" and kwargs.get("l1_ratio") is None:
            # scikit-learn rejects penalty="elasticnet" without an l1_ratio.
            kwargs["l1_ratio"] = 0.5
        super().__init__(**kwargs)