Source code for DashAI.back.dataloaders.classes.csv_dataloader

"""DashAI CSV Dataloader."""

import shutil
from typing import Any, Dict

from beartype import beartype
from datasets import load_dataset

from DashAI.back.core.schema_fields import (
    enum_field,
    none_type,
    schema_field,
    string_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    to_dashai_dataset,
)
from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader


class CSVDataloaderSchema(BaseSchema):
    name: schema_field(
        none_type(string_field()),
        "",
        (
            "Custom name to register your dataset. If no name is specified, "
            "the name of the uploaded file will be used."
        ),
    )  # type: ignore
    separator: schema_field(
        enum_field([",", ";", "\u0020", "\t"]),
        ",",
        "The character that delimits the data in the CSV file.",
    )  # type: ignore


class CSVDataLoader(BaseDataLoader):
    """Data loader for tabular data in CSV files."""

    COMPATIBLE_COMPONENTS = ["TabularClassificationTask"]
    SCHEMA = CSVDataloaderSchema

    def _check_params(
        self,
        params: Dict[str, Any],
    ) -> None:
        if "separator" not in params:
            raise ValueError(
                "Error trying to load the CSV dataset: "
                "separator parameter was not provided."
            )
        separator = params["separator"]
        if not isinstance(separator, str):
            raise TypeError(
                f"Param separator should be a string, got {type(separator)}"
            )

    @beartype
    def load_data(
        self,
        filepath_or_buffer: str,
        temp_path: str,
        params: Dict[str, Any],
    ) -> DashAIDataset:
        """Load the uploaded CSV files into a DashAIDataset.

        Parameters
        ----------
        filepath_or_buffer : str
            A URL where the dataset is located or a FastAPI/Uvicorn uploaded
            file object.
        temp_path : str
            The temporary path where the files will be extracted and then
            uploaded.
        params : Dict[str, Any]
            Dict with the dataloader parameters. The options are:
            - `separator` (str): The character that delimits the CSV data.

        Returns
        -------
        DashAIDataset
            A DashAI dataset with the loaded data.
        """
        self._check_params(params)
        separator = params["separator"]
        prepared_path = self.prepare_files(filepath_or_buffer, temp_path)
        if prepared_path[1] == "file":
            dataset = load_dataset(
                "csv",
                data_files=prepared_path[0],
                delimiter=separator,
            )
        else:
            dataset = load_dataset(
                "csv",
                data_dir=prepared_path[0],
                delimiter=separator,
            )
        shutil.rmtree(prepared_path[0])
        return to_dashai_dataset(dataset)
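
A minimal usage sketch (not part of the module above): it assumes DashAI is installed, that prepare_files accepts a plain path to a local CSV file, and that the file and scratch directory named here exist; both names are hypothetical.

    from DashAI.back.dataloaders.classes.csv_dataloader import CSVDataLoader

    loader = CSVDataLoader()
    dataset = loader.load_data(
        filepath_or_buffer="iris.csv",  # hypothetical local CSV file
        temp_path="tmp",                # hypothetical scratch directory
        params={"separator": ","},      # keys match CSVDataloaderSchema
    )
    print(dataset)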