# Source code for DashAI.back.converters.sklearn_wrapper

from abc import ABCMeta
from typing import Type, Union

from DashAI.back.converters.base_converter import BaseConverter
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    to_dashai_dataset,
)


class SklearnWrapper(BaseConverter, metaclass=ABCMeta):
    """Abstract class to define generic rules for sklearn transformers.

    ``fit`` and ``transform`` accept ``DashAIDataset`` objects, convert them
    with ``to_pandas()`` before delegating to scikit-learn, and ``transform``
    wraps the result back into a ``DashAIDataset`` via ``to_dashai_dataset``.

    The sklearn work is reached through ``super(BaseConverter, self)``, which
    skips ``BaseConverter`` in the MRO and resolves to the next class after
    it. This presumes concrete subclasses also inherit from a scikit-learn
    transformer (NOTE(review): confirm against the concrete converters).
    """

    def __init__(self, **kwargs):
        """Initialize ``BaseConverter`` and the underlying sklearn operation.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to the ``__init__`` that follows
            ``BaseConverter`` in the MRO (the sklearn operation).
        """
        # Initialize BaseConverter (next class after SklearnWrapper in MRO).
        super(SklearnWrapper, self).__init__()
        # Initialize the sklearn operation with the provided parameters.
        super(BaseConverter, self).__init__(**kwargs)
        # Not all scikit-learn transformers support the set_output API.
        if hasattr(self, "set_output"):
            # Cast the output from numpy ndarray to pandas DataFrame.
            self.set_output(transform="pandas")

    def _requires_y_for_fit(self) -> bool:
        """Return whether the wrapped transformer needs ``y`` to be fitted.

        scikit-learn >= 1.6 exposes estimator tags through
        ``__sklearn_tags__()`` (``tags.target_tags.required``). The older
        ``_get_tags()`` dict API (``"requires_y"`` key) was deprecated in 1.6.
        The new API is checked first, with a fallback to the legacy one, so
        supervised transformers keep receiving ``y`` across versions.
        Returns ``False`` when neither tag API is available.
        """
        sklearn_tags = getattr(self, "__sklearn_tags__", None)
        if callable(sklearn_tags):
            target_tags = getattr(sklearn_tags(), "target_tags", None)
            if target_tags is not None:
                return bool(getattr(target_tags, "required", False))

        legacy_get_tags = getattr(self, "_get_tags", None)
        if callable(legacy_get_tags):
            return bool(legacy_get_tags().get("requires_y", False))

        return False

    def fit(
        self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
    ) -> "SklearnWrapper":
        """Generic fit method for sklearn transformers.

        Parameters
        ----------
        x : DashAIDataset
            Input features; converted with ``to_pandas()`` before fitting.
        y : DashAIDataset or None, default None
            Targets. They are converted and forwarded only when the wrapped
            transformer declares that it requires ``y``; otherwise they are
            ignored.

        Returns
        -------
        SklearnWrapper
            The fitted converter itself (``self``).

        Raises
        ------
        ValueError
            If the transformer requires ``y`` but ``y`` is None.
        """
        x_pandas = x.to_pandas()

        # Supervised transformers (e.g. feature selectors) must receive y.
        if self._requires_y_for_fit():
            if y is None:
                raise ValueError("This transformer requires y for fitting")
            super(BaseConverter, self).fit(x_pandas, y.to_pandas())
        else:
            super(BaseConverter, self).fit(x_pandas)
        return self

    def transform(
        self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
    ) -> DashAIDataset:
        """Generic transform method for sklearn transformers.

        Parameters
        ----------
        x : DashAIDataset
            Input features; converted with ``to_pandas()`` before transforming.
        y : DashAIDataset or None, default None
            Not used; accepted for interface compatibility.

        Returns
        -------
        DashAIDataset
            The sklearn output wrapped with ``to_dashai_dataset``.
        """
        x_pandas = x.to_pandas()
        x_new_pandas = super(BaseConverter, self).transform(x_pandas)
        return to_dashai_dataset(x_new_pandas)