from typing import Optional

import pyarrow as pa
from sklearn.decomposition import PCA as PCAOPERATION
from DashAI.back.api.utils import create_random_state
from DashAI.back.converters.category.dimensionality_reduction import (
DimensionalityReductionConverter,
)
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
bool_field,
enum_field,
float_field,
int_field,
none_type,
schema_field,
union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.core.utils import MultilingualString
from DashAI.back.types.dashai_data_type import DashAIDataType
from DashAI.back.types.value_types import Float


class PCASchema(BaseSchema):
n_components: schema_field(
none_type(
union_type(
union_type(int_field(ge=1), float_field(gt=0.0, lt=1.0)),
enum_field(["mle"]),
),
),
2,
description=MultilingualString(
en="Number of components to keep. If None, all components are kept.",
es=(
"Número de componentes a conservar. Si es None, se conservan "
"todas las componentes."
),
),
) # type: ignore
use_copy: schema_field(
bool_field(),
True,
description=MultilingualString(
en=(
"If False, data passed to fit are overwritten. Use "
"fit_transform(X) instead of fit(X).transform(X)."
),
es=(
"Si es False, los datos pasados a fit se sobrescriben. Usa "
"fit_transform(X) en lugar de fit(X).transform(X)."
),
),
alias=MultilingualString(en="copy", es="copiar"),
) # type: ignore
whiten: schema_field(
bool_field(),
False,
description=MultilingualString(
            en=(
                "When True, the components_ vectors are scaled to ensure "
                "uncorrelated outputs with unit component-wise variances. "
                "Whitening may improve the accuracy of downstream estimators."
            ),
            es=(
                "Cuando es True, las componentes se escalan para asegurar "
                "salidas no correlacionadas con varianzas unitarias. El "
                "blanqueo puede mejorar la precisión de estimadores "
                "posteriores."
            ),
),
) # type: ignore
svd_solver: schema_field(
enum_field(["auto", "full", "covariance_eigh", "arpack", "randomized"]),
"auto",
description=MultilingualString(
            en=(
                "Solver to use for the eigendecomposition. 'auto' selects "
                "the most appropriate solver based on the data."
            ),
es=(
"Método para la descomposición propia. 'auto' elige el más "
"apropiado según los datos."
),
),
) # type: ignore
tol: schema_field(
float_field(ge=0.0),
0.0,
description=MultilingualString(
en="Tolerance for singular values when svd_solver == 'arpack'.",
es="Tolerancia para valores singulares cuando svd_solver == 'arpack'.",
),
) # type: ignore
iterated_power: schema_field(
union_type(int_field(ge=1), enum_field(["auto"])),
"auto",
description=MultilingualString(
en=(
"Number of iterations for the power method when "
"svd_solver == 'randomized'."
),
es=(
"Número de iteraciones para el método de potencia cuando "
"svd_solver == 'randomized'."
),
),
) # type: ignore
n_oversamples: schema_field(
int_field(ge=1),
10,
description=MultilingualString(
en="Number of power iterations used when svd_solver == 'randomized'.",
es="Número de iteraciones de potencia cuando svd_solver == 'randomized'.",
),
) # type: ignore
power_iteration_normalizer: schema_field(
none_type(enum_field(["auto", "QR", "LU"])),
"auto",
description=MultilingualString(
            en=(
                "How the power iteration normalizer should be computed: "
                "'auto', 'QR' or 'LU'. Not used by ARPACK."
            ),
es=(
"Cómo se calcula el normalizador de iteración de potencia: "
"'auto', QR o LU. No se usa con ARPACK."
),
),
) # type: ignore
random_state: schema_field(
none_type(union_type(int_field(), enum_field(["RandomState"]))),
None,
description=MultilingualString(
en=(
"Used when 'arpack' or 'randomized' solvers are used. Pass an int "
"for reproducible results."
),
es=(
"Usado con los métodos 'arpack' o 'randomized'. Pasa un entero "
"para resultados reproducibles."
),
),
) # type: ignore
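
# For illustration (not in the original module), the n_components union above
# admits, e.g.:
#     PCA(n_components=3)       # keep 3 components
#     PCA(n_components=0.95)    # keep enough components for 95% of variance
#     PCA(n_components="mle")   # Minka's MLE estimate of the dimension
#     PCA(n_components=None)    # keep all components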


class PCA(DimensionalityReductionConverter, SklearnWrapper, PCAOPERATION):
"""Scikit-learn's PCA wrapper for DashAI."""
SCHEMA = PCASchema
DESCRIPTION = MultilingualString(
en=(
"Principal Component Analysis (PCA) is a dimensionality reduction "
"technique used to simplify complex datasets while retaining as much "
"variability as possible."
),
es=(
"El Análisis de Componentes Principales (PCA) es una técnica de "
"reducción de dimensionalidad usada para simplificar conjuntos de "
"datos complejos conservando tanta variabilidad como sea posible."
),
)
SHORT_DESCRIPTION = MultilingualString(
en="Dimensionality reduction using PCA.",
es="Reducción de dimensionalidad usando PCA.",
)
DISPLAY_NAME = MultilingualString(
en="Principal Component Analysis (PCA)",
es="Análisis de Componentes Principales (PCA)",
)
IMAGE_PREVIEW = "pca.png"
metadata = {}

    def __init__(self, **kwargs):
self.random_state = kwargs.pop("random_state", None)
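        # The schema's "RandomState" sentinel is resolved to a concrete random
        # state via create_random_state(); ints and None are passed through
        # unchanged to scikit-learn.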
if self.random_state == "RandomState":
self.random_state = create_random_state()
kwargs["random_state"] = self.random_state
super().__init__(**kwargs)

    def get_output_type(
        self, column_name: Optional[str] = None
    ) -> DashAIDataType:
"""Returns Float64 as the output type for PCA components."""
return Float(arrow_type=pa.float64())
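

# A minimal usage sketch, assuming SklearnWrapper keeps scikit-learn's
# array-based fit_transform available on this converter (DashAI may instead
# feed it dataset objects through the converter pipeline). The toy matrix
# below is hypothetical.
if __name__ == "__main__":
    import numpy as np

    converter = PCA(n_components=2, whiten=False, random_state=42)
    X = np.random.default_rng(0).normal(size=(100, 5))  # 100 samples, 5 features
    reduced = converter.fit_transform(X)  # expected shape: (100, 2)
    print(reduced.shape)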