# Source code for DashAI.back.converters.scikit_learn.additive_chi_2_sampler

import pyarrow as pa
from sklearn.kernel_approximation import (
    AdditiveChi2Sampler as AdditiveChi2SamplerOperation,
)

from DashAI.back.converters.category.polynomial_kernel import PolynomialKernelConverter
from DashAI.back.converters.sklearn_wrapper import SklearnWrapper
from DashAI.back.core.schema_fields import (
    float_field,
    int_field,
    none_type,
    schema_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.core.utils import MultilingualString
from DashAI.back.types.dashai_data_type import DashAIDataType
from DashAI.back.types.value_types import Float


class AdditiveChi2SamplerSchema(BaseSchema):
    """Hyperparameters for scikit-learn's AdditiveChi2Sampler.

    Field names mirror the constructor parameters of
    ``sklearn.kernel_approximation.AdditiveChi2Sampler``.
    """

    # sklearn requires an integer >= 1. Each input feature expands into
    # (2 * sample_steps - 1) output features.
    sample_steps: schema_field(
        int_field(ge=1),
        2,
        description=MultilingualString(
            en=(
                "Number of sampling points of the kernel's Fourier transform. "
                "Larger values give a finer approximation and more output "
                "features."
            ),
            es=(
                "Número de puntos de muestreo de la transformada de Fourier "
                "del kernel. Valores mayores dan una aproximación más fina y "
                "más características de salida."
            ),
        ),
    )  # type: ignore
    # sklearn only requires a strictly positive interval. Its own defaults
    # for sample_steps = 1, 2 and 3 are 0.8, 0.5 and 0.4, so the previous
    # lower bound of 1.0 rejected every typical value.
    sample_interval: schema_field(
        none_type(float_field(gt=0.0)),
        None,
        description=MultilingualString(
            en=(
                "Spacing between sampling points. If empty, a default is "
                "derived from sample_steps (only 1, 2 or 3 are supported); "
                "it must be set for any other number of sample steps."
            ),
            es=(
                "Separación entre los puntos de muestreo. Si está vacío, se "
                "deriva un valor por defecto de sample_steps (solo se admiten "
                "1, 2 o 3); debe indicarse para cualquier otro número de "
                "pasos de muestreo."
            ),
        ),
    )  # type: ignore


class AdditiveChi2Sampler(
    PolynomialKernelConverter, SklearnWrapper, AdditiveChi2SamplerOperation
):
    """Scikit-learn's AdditiveChi2Sampler wrapper for DashAI.

    Combines, through multiple inheritance:

    - the DashAI converter category base ``PolynomialKernelConverter``;
    - the generic scikit-learn adapter ``SklearnWrapper``;
    - scikit-learn's ``AdditiveChi2Sampler`` estimator.

    Its configurable hyperparameters are declared in
    ``AdditiveChi2SamplerSchema``.
    """

    SCHEMA = AdditiveChi2SamplerSchema
    DESCRIPTION = MultilingualString(
        en=(
            "Uses sampling the Fourier transform of the kernel characteristic "
            "at regular intervals."
        ),
        es=(
            "Utiliza muestreo de la transformada de Fourier de la función kernel "
            "a intervalos regulares."
        ),
    )
    DISPLAY_NAME = MultilingualString(
        en="Additive Chi² Sampler", es="Muestreador Chi²"
    )
    IMAGE_PREVIEW = "additive_chi2_sampler.png"

    def get_output_type(self, column_name: str = None) -> DashAIDataType:
        """Return the DashAI data type of the transformed output columns.

        Parameters
        ----------
        column_name : str, optional
            Name of an output column. Unused: every output column receives
            the same type.

        Returns
        -------
        DashAIDataType
            A ``Float`` value type backed by Arrow ``float64``.
        """
        return Float(arrow_type=pa.float64())