Source code for DashAI.back.converters.scikit_learn.pca

from sklearn.decomposition import PCA as PCAOPERATION

from DashAI.back.api.utils import create_random_state
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    bool_field,
    enum_field,
    float_field,
    int_field,
    none_type,
    schema_field,
    union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema


class PCASchema(BaseSchema):
    n_components: schema_field(
        none_type(
            union_type(
                union_type(int_field(ge=1), float_field(gt=0.0, lt=1.0)),
                enum_field(["mle"]),
            ),
        ),
        2,
        "Number of components to keep. If None, all components are kept.",
    )  # type: ignore
    use_copy: schema_field(
        bool_field(),
        True,
        (
            "If False, data passed to fit are overwritten and running "
            "fit(X).transform(X) will not yield the expected results, "
            "use fit_transform(X) instead."
        ),
        alias="copy",
    )  # type: ignore
    whiten: schema_field(
        bool_field(),
        False,
        (
            "When True (False by default) the components_ vectors are multiplied "
            "by the square root of n_samples and then divided by the singular values "
            "to ensure uncorrelated outputs with unit component-wise variances. "
            "Whitening will remove some information from the transformed signal "
            "(the relative variance scales of the components) but can sometime "
            "improve the predictive accuracy of the downstream estimators by "
            "making their data respect some hard-wired assumptions."
        ),
    )  # type: ignore
    svd_solver: schema_field(
        enum_field(["auto", "full", "covariance_eigh", "arpack", "randomized"]),
        "auto",
        (
            "The solver to use for the eigendecomposition. If 'auto', it will "
            "choose the most appropriate solver based on the type of data passed."
        ),
    )  # type: ignore
    tol: schema_field(
        float_field(ge=0.0),
        0.0,
        ("Tolerance for singular values computed by " "svd_solver == 'arpack'."),
    )  # type: ignore
    iterated_power: schema_field(
        union_type(int_field(ge=1), enum_field(["auto"])),
        "auto",
        (
            "Number of iterations for the power method computed by "
            "svd_solver == 'randomized'."
        ),
    )  # type: ignore
    n_oversamples: schema_field(
        int_field(ge=1),
        10,
        "Number of power iterations used when svd_solver == 'randomized'.",
    )  # type: ignore
    power_iteration_normalizer: schema_field(
        enum_field(["auto", "QR", "LU", "none"]),
        "auto",
        (
            "How the power iterations of the randomized SVD solver are "
            "normalized: step-by-step QR factorization ('QR', the slowest "
            "but most accurate), LU factorization ('LU', numerically stable "
            "but slightly less accurate), or no normalization ('none', the "
            "fastest but numerically unstable for large iteration counts). "
            "'auto' applies no normalization when iterated_power is small "
            "and switches to LU otherwise. Not used by ARPACK."
        ),
    )  # type: ignore
    random_state: schema_field(
        none_type(
            union_type(int_field(), enum_field(["RandomState"]))
        ),  # int, RandomState instance or None
        None,
        (
            "Used when the ‘arpack’ or ‘randomized’ solvers are used. "
            "Pass an int for reproducible results across multiple function calls."
        ),
    )  # type: ignore

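# Illustration (not part of the original module): the n_components union in
# PCASchema mirrors the values scikit-learn's PCA itself accepts. A minimal
# sketch against plain sklearn, for reference:
#
#     import numpy as np
#     from sklearn.decomposition import PCA
#
#     X = np.random.default_rng(0).normal(size=(100, 5))
#     PCA(n_components=3).fit(X)      # int >= 1: keep exactly 3 components
#     PCA(n_components=0.95).fit(X)   # float in (0, 1): keep enough components
#                                     # to explain 95% of the variance
#     PCA(n_components="mle").fit(X)  # Minka's MLE chooses the dimensionality
#     PCA(n_components=None).fit(X)   # None: keep all components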

class PCA(SklearnWrapper, PCAOPERATION):
    """Scikit-learn's PCA wrapper for DashAI."""

    SCHEMA = PCASchema
    DESCRIPTION = "Principal component analysis (PCA)."

    def __init__(self, **kwargs):
        self.random_state = kwargs.pop("random_state", None)
        if self.random_state == "RandomState":
            self.random_state = create_random_state()
        kwargs["random_state"] = self.random_state
        super().__init__(**kwargs)
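

if __name__ == "__main__":
    # Usage sketch (not part of the original module), assuming SklearnWrapper
    # forwards the validated keyword arguments to scikit-learn's PCA.
    import numpy as np

    X = np.random.default_rng(0).normal(size=(100, 5))

    # An int seed gives reproducible results across calls.
    pca = PCA(n_components=2, random_state=42)
    print(pca.fit_transform(X).shape)  # (100, 2)

    # The "RandomState" sentinel is swapped for a fresh RandomState instance
    # by __init__ via create_random_state().
    pca_rs = PCA(n_components="mle", random_state="RandomState")
    print(pca_rs.fit(X).n_components_)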