# Source code for DashAI.back.converters.scikit_learn.truncated_svd

import pyarrow as pa
from sklearn.decomposition import TruncatedSVD as TruncatedSVDOperation

from DashAI.back.api.utils import create_random_state
from DashAI.back.converters.category.dimensionality_reduction import (
    DimensionalityReductionConverter,
)
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    enum_field,
    float_field,
    int_field,
    none_type,
    schema_field,
    union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.core.utils import MultilingualString
from DashAI.back.types.dashai_data_type import DashAIDataType
from DashAI.back.types.value_types import Float


class TruncatedSVDSchema(BaseSchema):
    # Parameter schema for the TruncatedSVD converter.
    #
    # Every field is declared with ``schema_field(validator, default,
    # description=...)``; descriptions are bilingual (English / Spanish)
    # ``MultilingualString`` instances. The field names mirror the keyword
    # arguments of scikit-learn's ``TruncatedSVD``.

    # Number of output dimensions; must be strictly positive.
    n_components: schema_field(
        int_field(gt=0),
        2,
        description=MultilingualString(
            en="Desired dimensionality of output data.",
            es="Dimensionalidad deseada de los datos de salida.",
        ),
    )  # type: ignore
    # Which SVD solver to run.
    algorithm: schema_field(
        enum_field(["arpack", "randomized"]),
        "randomized",
        description=MultilingualString(
            en="SVD solver to use.",
            es="Método SVD a utilizar.",
        ),
    )  # type: ignore
    # Iteration count for the randomized solver.
    n_iter: schema_field(
        int_field(gt=0),
        5,
        description=MultilingualString(
            en="Number of iterations for randomized SVD solver.",
            es="Número de iteraciones para el método SVD aleatorizado.",
        ),
    )  # type: ignore
    # Oversampling amount for the randomized solver. The description used to
    # say "power iterations", which is what ``n_iter`` controls, not this
    # parameter.
    n_oversamples: schema_field(
        int_field(gt=0),
        10,
        description=MultilingualString(
            en="Number of oversamples for randomized SVD solver.",
            es=(
                "Número de muestras adicionales (sobremuestreo) para el "
                "método SVD aleatorizado."
            ),
        ),
    )  # type: ignore
    power_iteration_normalizer: schema_field(
        enum_field(["auto", "QR", "LU", "none"]),
        "auto",
        description=MultilingualString(
            en="Method to normalize the eigenvectors.",
            es="Método para normalizar los eigenvectores.",
        ),
    )  # type: ignore
    # Accepts ``None``, an int seed, or the literal string "RandomState".
    random_state: schema_field(
        none_type(union_type(int_field(), enum_field(["RandomState"]))),
        None,
        description=MultilingualString(
            en=(
                "Used during randomized svd. Pass an int for reproducible "
                "results across multiple function calls."
            ),
            es=(
                "Usado durante SVD aleatorizado. Pasa un entero para obtener "
                "resultados reproducibles en múltiples ejecuciones."
            ),
        ),
    )  # type: ignore
    # Non-negative convergence tolerance (ARPACK solver).
    tol: schema_field(
        float_field(ge=0),
        0.0,
        description=MultilingualString(
            en="Tolerance for ARPACK.",
            es="Tolerancia para ARPACK.",
        ),
    )  # type: ignore


class TruncatedSVD(
    DimensionalityReductionConverter, SklearnWrapper, TruncatedSVDOperation
):
    """Scikit-learn's TruncatedSVD wrapper for DashAI."""

    # Schema describing the accepted constructor parameters.
    SCHEMA = TruncatedSVDSchema
    DESCRIPTION = MultilingualString(
        en=(
            "This transformer performs linear dimensionality reduction by means "
            "of truncated singular value decomposition (SVD). Contrary to PCA, "
            "this estimator does not center the data before computing the "
            "singular value decomposition. This means it can work with sparse "
            "matrices efficiently."
        ),
        es=(
            "Este transformador realiza reducción lineal de dimensionalidad por "
            "medio de la descomposición en valores singulares truncada (SVD). "
            "A diferencia de PCA, este estimador no centra los datos antes de "
            "calcular la descomposición, lo que permite trabajar eficientemente "
            "con matrices dispersas."
        ),
    )
    SHORT_DESCRIPTION = MultilingualString(
        en="Dimensionality reduction using truncated SVD.",
        es="Reducción de dimensionalidad utilizando SVD truncado.",
    )
    DISPLAY_NAME = MultilingualString(en="Truncated SVD", es="SVD Truncado")
    IMAGE_PREVIEW = "truncated_svd.png"

    metadata = {}

    def __init__(self, **kwargs):
        """Initialize the converter, resolving ``random_state`` first.

        Parameters
        ----------
        **kwargs
            Keyword arguments forwarded to the parent initializers. When
            ``random_state`` equals the string ``"RandomState"`` it is
            replaced by the value returned from ``create_random_state()``;
            any other value (or its absence, treated as ``None``) is passed
            through unchanged.
        """
        random_state = kwargs.pop("random_state", None)
        if random_state == "RandomState":
            # Swap the string placeholder for a generated random state.
            random_state = create_random_state()

        # Set on the instance before delegating, then hand the resolved value
        # to the parent initializers together with the remaining kwargs.
        self.random_state = random_state
        kwargs["random_state"] = random_state
        super().__init__(**kwargs)

    def get_output_type(self, column_name: str = None) -> DashAIDataType:
        """Returns Float64 as the output type for transformed data.

        Parameters
        ----------
        column_name : str, optional
            Accepted but not used by this implementation; the same type is
            returned regardless of its value.

        Returns
        -------
        DashAIDataType
            A ``Float`` whose arrow type is ``pa.float64()``.
        """
        return Float(arrow_type=pa.float64())