"""Source code for DashAI.back.models.scikit_learn.decision_tree_classifier."""

from sklearn.tree import DecisionTreeClassifier as _DecisionTreeClassifier

from DashAI.back.core.schema_fields import (
    BaseSchema,
    enum_field,
    float_field,
    none_type,
    optimizer_int_field,
    schema_field,
    union_type,
)
from DashAI.back.core.utils import MultilingualString
from DashAI.back.models.scikit_learn.sklearn_like_classifier import (
    SklearnLikeClassifier,
)
from DashAI.back.models.tabular_classification_model import TabularClassificationModel


class DecisionTreeClassifierSchema(BaseSchema):
    """Decision Trees are a non-parametric supervised learning method that
    learns simple decision rules (structured as a tree) inferred from the data features.
    """

    # Split-quality function, chosen from a fixed set of names.
    criterion: schema_field(
        enum_field(enum=["entropy", "gini", "log_loss"]),
        placeholder="entropy",
        description=MultilingualString(
            en=(
                "The function to measure the quality of a split. Supported criteria "
                "are “gini” for the Gini impurity and “log_loss” and “entropy” both "
                "for the Shannon information gain."
            ),
            es=(
                "La función para medir la calidad de una división. Los criterios "
                "soportados son “gini” para la impureza de Gini y “log_loss” y "
                "“entropy” para la ganancia de información de Shannon."
            ),
        ),
        alias=MultilingualString(en="Criterion", es="Criterio"),
    )  # type: ignore
    # Maximum tree depth. The placeholder holds either a fixed value or an
    # optimization range (lower/upper bound), selected by the "optimize" flag.
    # NOTE(review): the description below mentions None, but
    # optimizer_int_field(ge=1) presumably does not accept None — confirm.
    max_depth: schema_field(
        optimizer_int_field(ge=1),
        placeholder={
            "optimize": False,
            "fixed_value": 1,
            "lower_bound": 1,
            "upper_bound": 10,
        },
        description=MultilingualString(
            en=(
                "The maximum depth of the tree. If None, then nodes are expanded "
                "until all leaves are pure or until all leaves contain less than "
                "min_samples_split samples."
            ),
            es=(
                "La profundidad máxima del árbol. Si es None, los nodos se expanden "
                "hasta que todas las hojas sean puras o hasta que todas las hojas "
                "contengan menos de min_samples_split muestras."
            ),
        ),
        alias=MultilingualString(en="Max depth", es="Profundidad máxima"),
    )  # type: ignore
    # scikit-learn only accepts integer min_samples_split values >= 2 (an
    # integer 1 is rejected when the estimator is fitted), so the lower bound
    # and the default fixed/lower values start at 2 rather than 1.
    min_samples_split: schema_field(
        optimizer_int_field(ge=2),
        placeholder={
            "optimize": False,
            "fixed_value": 2,
            "lower_bound": 2,
            "upper_bound": 5,
        },
        description=MultilingualString(
            en="The minimum number of samples required to split an internal node.",
            es="El número mínimo de muestras requeridas para dividir un nodo interno.",
        ),
        alias=MultilingualString(
            en="Min samples split", es="Mínimas muestras de división"
        ),
    )  # type: ignore
    # Minimum number of samples per leaf; same fixed/optimize placeholder shape.
    min_samples_leaf: schema_field(
        optimizer_int_field(ge=1),
        placeholder={
            "optimize": False,
            "fixed_value": 1,
            "lower_bound": 1,
            "upper_bound": 5,
        },
        description=MultilingualString(
            en="The minimum number of samples required to be at a leaf node.",
            es="El número mínimo de muestras requeridas para estar en una hoja.",
        ),
        alias=MultilingualString(
            en="Min samples leaf", es="Mínimas muestras para hoja"
        ),
    )  # type: ignore
    # Optional: None, one of the named strategies ("sqrt", "log2"), or a
    # fraction in (0, 1] of the total number of features.
    max_features: schema_field(
        none_type(
            union_type(enum_field(enum=["sqrt", "log2"]), float_field(gt=0.0, le=1.0))
        ),
        placeholder=None,
        description=MultilingualString(
            en=(
                "The number of features to consider when looking for the best split. "
                "If float, then max_features is a percentage of the total number of "
                "features."
            ),
            es=(
                "El número de características a considerar al buscar la mejor "
                "división. Si es float, entonces max_features es un porcentaje del "
                "total de características."
            ),
        ),
        alias=MultilingualString(en="Max features", es="Máximas características"),
    )  # type: ignore


class DecisionTreeClassifier(
    TabularClassificationModel, SklearnLikeClassifier, _DecisionTreeClassifier
):
    """Scikit-learn's Decision Tree Classifier wrapper for DashAI."""

    # Schema class describing the hyperparameters this model exposes.
    SCHEMA = DecisionTreeClassifierSchema

    # Localized (English/Spanish) name and short description of the model.
    DISPLAY_NAME: str = MultilingualString(
        en="Decision Tree",
        es="Árbol de Decisión",
    )
    DESCRIPTION: str = MultilingualString(
        en="Decision tree classifier using CART algorithm.",
        es="Clasificador de árbol de decisión usando el algoritmo CART.",
    )
    # Presumably UI presentation hints (accent color and icon name) — confirm
    # against the frontend that consumes them.
    COLOR: str = "#4CAF50"
    ICON: str = "AccountTree"

    def __init__(self, **kwargs) -> None:
        """Forward all keyword arguments to the parent initializers.

        Resolution follows the MRO, starting at TabularClassificationModel.
        Presumably the kwargs are the hyperparameters described by SCHEMA and
        eventually reach scikit-learn's DecisionTreeClassifier — verify.
        """
        super().__init__(**kwargs)