Source code for DashAI.back.converters.hugging_face_wrapper

from abc import ABCMeta, abstractmethod
from typing import Type

from datasets import concatenate_datasets

from DashAI.back.converters.base_converter import BaseConverter
from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
from DashAI.back.types.dashai_data_type import DashAIDataType


class HuggingFaceWrapper(BaseConverter, metaclass=ABCMeta):
    """Abstract base wrapper for HuggingFace transformers.

    Subclasses provide model loading (``_load_model``), per-batch processing
    (``_process_batch``) and the output column type (``get_output_type``);
    this class implements the shared ``fit``/``transform`` flow on top of them.

    Note
    ----
    ``transform`` reads ``self.batch_size``, which this class never assigns.
    Concrete converters are expected to set it (presumably from their
    constructor parameters -- confirm against the subclasses).
    """

    def __init__(self, **kwargs):
        """Initialize the underlying ``BaseConverter``.

        ``**kwargs`` is accepted so subclasses can forward their parameters,
        but it is not passed on to ``BaseConverter.__init__`` and is not
        stored here.
        """
        super().__init__()

    @abstractmethod
    def _load_model(self):
        """Load the HuggingFace model and tokenizer."""
        raise NotImplementedError

    @abstractmethod
    def _process_batch(self, batch: DashAIDataset) -> DashAIDataset:
        """Process a batch of data through the model."""
        raise NotImplementedError

    @abstractmethod
    def get_output_type(self, column_name: str = None) -> DashAIDataType:
        """Return the DashAI data type produced for ``column_name``.

        Each HuggingFace converter must implement this method to specify its
        output type.
        """
        raise NotImplementedError

    def fit(self, x: DashAIDataset, y: DashAIDataset = None) -> "HuggingFaceWrapper":
        """Validate the input dataset and load the model.

        Parameters
        ----------
        x : DashAIDataset
            Dataset to validate. Every column must hold string data; only the
            first row is inspected.
        y : DashAIDataset, optional
            Unused by this method.

        Returns
        -------
        HuggingFaceWrapper
            This converter instance, to allow call chaining.

        Raises
        ------
        ValueError
            If ``x`` is empty, or if the first row holds a non-string value
            in any column.
        """
        if len(x) == 0:
            raise ValueError("Input dataset is empty")

        # Only the first row is checked, so fetch it once instead of
        # re-reading it for every column.
        first_row = x[0]
        for column in x.column_names:
            if not isinstance(first_row[column], str):
                raise ValueError(f"Column {column} must contain string data")

        # Called on every fit; skipping an already-loaded model is left to
        # the subclass implementation of _load_model.
        self._load_model()
        return self

    def transform(self, x: DashAIDataset, y: DashAIDataset = None) -> DashAIDataset:
        """Transform the input data using the model, one batch at a time.

        The dataset is split into consecutive slices of ``self.batch_size``
        rows, each slice is passed to ``_process_batch``, and the per-batch
        results are concatenated. The result is rebuilt as a ``DashAIDataset``
        from the concatenated dataset's underlying table, and each output
        column's type is taken from ``get_output_type``.

        Parameters
        ----------
        x : DashAIDataset
            Dataset to transform.
        y : DashAIDataset, optional
            Unused by this method.

        Returns
        -------
        DashAIDataset
            The converted dataset.

        Raises
        ------
        ValueError
            If ``x`` is empty or ``self.batch_size`` is not positive. Both
            cases previously surfaced as less descriptive ``ValueError``s
            raised from ``range()`` or ``concatenate_datasets``.
        """
        batch_size = self.batch_size
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        n_rows = len(x)
        if n_rows == 0:
            raise ValueError("Input dataset is empty")

        # Process consecutive row slices; the final slice may be shorter.
        all_results = []
        for start in range(0, n_rows, batch_size):
            stop = min(start + batch_size, n_rows)
            batch = x.select(range(start, stop))
            all_results.append(self._process_batch(batch))

        concatenated_dataset = concatenate_datasets(all_results)
        converted_dataset = DashAIDataset(concatenated_dataset.data.table)

        # Set types for each column using the converter's get_output_type
        # method. A converter that raises NotImplementedError is tolerated:
        # the column keeps whatever type it already has and a warning is
        # printed instead of failing the whole transform.
        for col in converted_dataset.column_names:
            try:
                converted_dataset.types[col] = self.get_output_type(col)
            except NotImplementedError:
                print(
                    f"Warning: Converter {self.__class__.__name__} does not implement "
                    f"get_output_type. Column {col} type may not be properly set."
                )

        return converted_dataset