import logging
from abc import ABCMeta, abstractmethod
from typing import Type, Union

import numpy as np
import pandas as pd
import pyarrow as pa

from DashAI.back.converters.base_converter import BaseConverter
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    to_dashai_dataset,
)
from DashAI.back.types.categorical import Categorical
from DashAI.back.types.dashai_data_type import DashAIDataType
# [docs]  (Sphinx viewcode link marker left over from HTML extraction; not code)
class SklearnWrapper(BaseConverter, metaclass=ABCMeta):
    """Abstract class to define generic rules for sklearn transformers.

    The ``fit`` and ``transform`` methods convert DashAI datasets to pandas,
    delegate to the next ``fit`` / ``transform`` in the MRO via ``super()``
    (presumably the scikit-learn transformer the concrete subclass also
    inherits from -- confirm against the subclasses), and convert the result
    back into a ``DashAIDataset`` with per-column DashAI types.
    """

    def __init__(self, **kwargs):
        """Initialise the parent classes and request pandas output if possible.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``super().__init__``.
        """
        super().__init__(**kwargs)
        # Not all scikit-learn transformers support the set_output API.
        if hasattr(self, "set_output"):
            # Cast the output from numpy ndarray to pandas DataFrame.
            self.set_output(transform="pandas")

    @abstractmethod
    def get_output_type(
        self, column_name: Union[str, None] = None
    ) -> DashAIDataType:
        """Return the DashAI type of an output column.

        Each sklearn converter must implement this method to specify its
        output type.

        Parameters
        ----------
        column_name : str or None, optional
            Name of the output column whose type is requested.

        Returns
        -------
        DashAIDataType
            The type to register for ``column_name`` after ``transform``.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def _requires_y(self) -> bool:
        """Tell whether the wrapped transformer declares ``y`` as mandatory."""
        # Newer scikit-learn releases expose estimator tags through the
        # public ``__sklearn_tags__`` hook; prefer it when it is available.
        tags_fn = getattr(self, "__sklearn_tags__", None)
        if callable(tags_fn):
            target_tags = getattr(tags_fn(), "target_tags", None)
            if target_tags is not None:
                return bool(target_tags.required)
        # Older releases only offer the private dict-based ``_get_tags``.
        legacy_tags_fn = getattr(self, "_get_tags", None)
        if callable(legacy_tags_fn):
            return bool(legacy_tags_fn().get("requires_y", False))
        return False

    def _ndarray_column_names(self, values: np.ndarray, input_columns) -> list:
        """Choose column labels for a raw ndarray returned by ``transform``."""
        n_out = 1 if values.ndim == 1 else values.shape[1]
        # Same width as the input: keep the original labels.
        if input_columns is not None and len(input_columns) == n_out:
            return list(input_columns)
        # Width changed: reusing the input labels would make pandas raise,
        # so ask the transformer for its own output names when it has them.
        names_fn = getattr(self, "get_feature_names_out", None)
        if callable(names_fn):
            try:
                names = [str(name) for name in names_fn()]
            except (AttributeError, ValueError):
                names = []
            if len(names) == n_out:
                return names
        # Last resort: positional labels, as strings.
        return [str(position) for position in range(n_out)]

    def fit(
        self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
    ) -> "SklearnWrapper":
        """Generic fit method for sklearn transformers.

        Parameters
        ----------
        x : DashAIDataset
            Dataset to fit on; converted with ``to_pandas()`` before fitting.
        y : DashAIDataset or None, optional
            Target dataset. It is only passed to the underlying ``fit`` when
            the transformer's tags declare that ``y`` is required.

        Returns
        -------
        SklearnWrapper
            The fitted converter itself.

        Raises
        ------
        ValueError
            If the transformer requires ``y`` and ``y`` is ``None``.
        """
        requires_y = self._requires_y()
        # Check for supervised transformers that require y.
        if requires_y and y is None:
            raise ValueError("This transformer requires y for fitting")

        x_pandas = x.to_pandas()
        if requires_y:
            # y is converted only here, where it is actually consumed.
            super().fit(x_pandas, y.to_pandas())
        else:
            super().fit(x_pandas)
        return self

    def transform(
        self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
    ) -> DashAIDataset:
        """Generic transform method for sklearn transformers.

        Parameters
        ----------
        x : DashAIDataset
            Dataset to transform; converted with ``to_pandas()`` first.
        y : DashAIDataset or None, optional
            Accepted for interface compatibility; not used here.

        Returns
        -------
        DashAIDataset
            The transformed data, with ``types[col]`` set from
            ``get_output_type`` for every output column that provides one.
        """
        x_pandas = x.to_pandas()
        x_new = super().transform(x_pandas)

        # Transformers without pandas output hand back a bare ndarray.
        if isinstance(x_new, np.ndarray):
            column_names = self._ndarray_column_names(
                x_new, getattr(x_pandas, "columns", None)
            )
            x_new = pd.DataFrame(x_new, columns=column_names)

        converted_dataset = to_dashai_dataset(x_new)

        # The class list does not depend on the column, so read it once and
        # use native Python objects for both the values and the encoding keys.
        class_values = None
        if hasattr(self, "classes_"):
            class_values = np.asarray(self.classes_).tolist()

        for col in converted_dataset.column_names:
            try:
                output_type = self.get_output_type(col)
            except NotImplementedError:
                logging.getLogger(__name__).warning(
                    "Converter %s does not implement get_output_type. "
                    "Column %s type may not be properly set.",
                    self.__class__.__name__,
                    col,
                )
                continue

            # Special handling for categorical types that need class
            # information.
            if isinstance(output_type, Categorical) and class_values is not None:
                converted_dataset.types[col] = Categorical(
                    values=pa.array(class_values),
                    # Fresh dict per column so the types never share state.
                    encoding={
                        value: code for code, value in enumerate(class_values)
                    },
                    converted=True,
                )
            else:
                converted_dataset.types[col] = output_type

        return converted_dataset