Source code for DashAI.back.models.scikit_learn.hist_gradient_boosting_classifier

from sklearn.ensemble import (
    HistGradientBoostingClassifier as _HistGradientBoostingClassifier,
)

from DashAI.back.core.schema_fields import (
    BaseSchema,
    optimizer_float_field,
    optimizer_int_field,
    schema_field,
)
from DashAI.back.core.utils import MultilingualString
from DashAI.back.models.scikit_learn.sklearn_like_classifier import (
    SklearnLikeClassifier,
)
from DashAI.back.models.tabular_classification_model import TabularClassificationModel


class HistGradientBoostingClassifierSchema(BaseSchema):
    """A gradient boosting classifier is a machine learning algorithm that combines
    multiple weak prediction models (typically decision trees) to create a strong
    predictive model by training the models sequentially, in which each new model is
    focused on correcting the errors made by the previous ones.
    """

    # Each field below pairs a validated numeric type with a placeholder dict
    # ("optimize", "fixed_value", "lower_bound", "upper_bound") plus a bilingual
    # description and alias.

    # NOTE(review): scikit-learn requires learning_rate to be strictly positive,
    # while ``ge=0.0`` still admits 0. Tighten to an exclusive bound if
    # ``optimizer_float_field`` supports one -- confirm against its signature.
    learning_rate: schema_field(
        optimizer_float_field(ge=0.0),
        placeholder={
            "optimize": False,
            "fixed_value": 0.1,
            "lower_bound": 0.1,
            "upper_bound": 1,
        },
        description=MultilingualString(
            en=(
                "The learning rate, also known as shrinkage. This is used as a "
                "multiplicative factor for the leaves values. Use 1 for no shrinkage."
            ),
            es=(
                "La tasa de aprendizaje, también conocida como shrinkage. Se utiliza "
                "como factor multiplicativo para los valores de las hojas. Use 1 para "
                "no aplicar shrinkage."
            ),
        ),
        alias=MultilingualString(en="Learning rate", es="Tasa de aprendizaje"),
    )  # type: ignore
    # scikit-learn rejects max_iter < 1, so the lower bound is 1 (not 0) to fail
    # at schema validation instead of at fit time.
    max_iter: schema_field(
        optimizer_int_field(ge=1),
        placeholder={
            "optimize": False,
            "fixed_value": 100,
            "lower_bound": 100,
            "upper_bound": 250,
        },
        description=MultilingualString(
            en=(
                "The maximum number of iterations of the boosting process, i.e. the "
                "maximum number of trees for binary classification."
            ),
            es=(
                "El número máximo de iteraciones del proceso de boosting, es decir, "
                "el número máximo de árboles para clasificación binaria."
            ),
        ),
        alias=MultilingualString(en="Max iterations", es="Máximas iteraciones"),
    )  # type: ignore
    # scikit-learn rejects an integer max_depth < 1, so the lower bound is 1
    # (not 0). This field is int-only, so the unconstrained ``None`` setting
    # mentioned in the description is not selectable here.
    max_depth: schema_field(
        optimizer_int_field(ge=1),
        placeholder={
            "optimize": False,
            "fixed_value": 1,
            "lower_bound": 1,
            "upper_bound": 10,
        },
        description=MultilingualString(
            en=(
                "The maximum depth of each tree. The depth of a tree is the number "
                "of edges to go from the root to the deepest leaf. Depth isn't "
                "constrained by default."
            ),
            es=(
                "La profundidad máxima de cada árbol. La profundidad es el número de "
                "aristas desde la raíz hasta la hoja más profunda. Por defecto, la "
                "profundidad no está restringida."
            ),
        ),
        alias=MultilingualString(en="Max depth", es="Profundidad máxima"),
    )  # type: ignore
    # ``ge=2`` encodes the "strictly greater than 1" requirement from the
    # description.
    max_leaf_nodes: schema_field(
        optimizer_int_field(ge=2),
        placeholder={
            "optimize": False,
            "fixed_value": 31,
            "lower_bound": 10,
            "upper_bound": 40,
        },
        description=MultilingualString(
            en=(
                "The maximum number of leaves for each tree. Must be strictly "
                "greater than 1. If None, there is no maximum limit."
            ),
            es=(
                "El número máximo de hojas para cada árbol. Debe ser estrictamente "
                "mayor que 1. Si es None, no hay límite máximo."
            ),
        ),
        alias=MultilingualString(en="Max leaf nodes", es="Nodos de hoja máximos"),
    )  # type: ignore
    min_samples_leaf: schema_field(
        optimizer_int_field(ge=1),
        placeholder={
            "optimize": False,
            "fixed_value": 20,
            "lower_bound": 2,
            "upper_bound": 25,
        },
        description=MultilingualString(
            en="The minimum number of samples required to be at a leaf node.",
            es="El número mínimo de muestras requeridas para estar en una hoja.",
        ),
        alias=MultilingualString(en="Min samples leaf", es="Muestras de hoja mínimas"),
    )  # type: ignore
    # 0 is a valid value here: it disables L2 regularization.
    l2_regularization: schema_field(
        optimizer_float_field(ge=0.0),
        placeholder={
            "optimize": False,
            "fixed_value": 0.0,
            "lower_bound": 0.0,
            "upper_bound": 1.0,
        },
        description=MultilingualString(
            en="The L2 regularization parameter. Use 0 for no regularization.",
            es=(
                "El parámetro de regularización L2. "
                "Use 0 para no aplicar regularización."
            ),
        ),
        alias=MultilingualString(en="L2 regularization", es="Regularización L2"),
    )  # type: ignore


class HistGradientBoostingClassifier(
    TabularClassificationModel, SklearnLikeClassifier, _HistGradientBoostingClassifier
):
    """Scikit-learn's HistGradientBoostingClassifier wrapper for DashAI."""

    # Schema class describing the configurable hyperparameters of this model.
    SCHEMA = HistGradientBoostingClassifierSchema
    # Bilingual (en/es) display metadata; presumably rendered by the DashAI
    # front end -- confirm against the consumers of these attributes.
    DISPLAY_NAME: str = MultilingualString(
        en="Histogram-based Gradient Boosting",
        es="Gradient Boosting basado en histogramas",
    )
    DESCRIPTION: str = MultilingualString(
        en="Fast gradient boosting using histogram-based algorithms.",
        es=("Gradient boosting rápido usando algoritmos basados en histogramas."),
    )
    COLOR: str = "#9575CD"
    ICON: str = "RocketLaunch"

    def __init__(self, **kwargs) -> None:
        """Initialize the model, forwarding all keyword arguments up the MRO.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed unchanged to ``super().__init__``.
        """
        super().__init__(**kwargs)