# Source code for DashAI.back.converters.scikit_learn.truncated_svd

from sklearn.decomposition import TruncatedSVD as TruncatedSVDOperation

from DashAI.back.api.utils import create_random_state
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    enum_field,
    float_field,
    int_field,
    none_type,
    schema_field,
    union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema


class TruncatedSVDSchema(BaseSchema):
    """Parameter schema for the scikit-learn TruncatedSVD converter.

    Each field mirrors a constructor parameter of
    ``sklearn.decomposition.TruncatedSVD``; defaults match sklearn's.
    """

    n_components: schema_field(
        int_field(gt=0),
        2,
        "Desired dimensionality of output data.",
    )  # type: ignore
    algorithm: schema_field(
        enum_field(["arpack", "randomized"]),
        "randomized",
        "SVD solver to use.",
    )  # type: ignore
    n_iter: schema_field(
        int_field(gt=0),
        5,
        "Number of iterations for randomized SVD solver.",
    )  # type: ignore
    # Fixed: previous text wrongly described power iterations (that is
    # n_iter); per sklearn docs this is the oversampling count.
    n_oversamples: schema_field(
        int_field(gt=0),
        10,
        "Number of oversamples for randomized SVD solver.",
    )  # type: ignore
    power_iteration_normalizer: schema_field(
        enum_field(["auto", "QR", "LU", "none"]),
        "auto",
        "Power iteration normalizer for randomized SVD solver.",
    )  # type: ignore
    random_state: schema_field(
        none_type(
            union_type(int_field(), enum_field(["RandomState"]))
        ),  # int, RandomState instance or None
        None,
        (
            "Used during randomized svd. Pass an int for reproducible results "
            "across multiple function calls."
        ),
    )  # type: ignore
    tol: schema_field(
        float_field(ge=0),
        0.0,
        "Tolerance for ARPACK.",
    )  # type: ignore


class TruncatedSVD(SklearnWrapper, TruncatedSVDOperation):
    """Scikit-learn's TruncatedSVD wrapper for DashAI."""

    SCHEMA = TruncatedSVDSchema
    DESCRIPTION = (
        "This transformer performs linear dimensionality reduction by means of "
        "truncated singular value decomposition (SVD). "
        "Contrary to PCA, this estimator does not center the data before "
        "computing the singular value decomposition. "
        "This means it can work with sparse matrices efficiently."
    )

    def __init__(self, **kwargs):
        """Initialize the converter from schema kwargs.

        ``random_state`` is popped from kwargs; the sentinel string
        "RandomState" (from the schema's enum branch) is replaced by an
        actual random-state object via ``create_random_state()`` before
        being written back into kwargs.
        """
        self.random_state = kwargs.pop("random_state", None)
        if self.random_state == "RandomState":
            self.random_state = create_random_state()
        kwargs["random_state"] = self.random_state
        # NOTE(review): as captured here, the updated kwargs are never
        # forwarded to the parent constructors — a
        # `super().__init__(**kwargs)` call appears to be missing
        # (possibly truncated during extraction). Confirm against the
        # upstream DashAI source before relying on this initializer.