# Source code for DashAI.back.dataloaders.classes.excel_dataloader

"""DashAI Excel Dataloader."""

import glob
import shutil
from typing import Any, Dict

import pandas as pd
from beartype import beartype
from datasets import Dataset, DatasetDict
from datasets.builder import DatasetGenerationError

from DashAI.back.core.schema_fields import (
    int_field,
    none_type,
    schema_field,
    string_field,
    union_type,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    to_dashai_dataset,
)
from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader


class ExcelDataloaderSchema(BaseSchema):
    # Parameters for ExcelDataLoader.load_data. In this module, load_data reads
    # ``sheet``, ``header`` and ``usecols`` from its ``params`` dict and passes
    # them to ``pandas.read_excel``; ``name`` is presumably consumed elsewhere
    # (dataset registration) — TODO confirm.
    # NOTE(review): field order is kept as-is; it may drive the order in which
    # the parameters are presented to users.

    # Optional dataset name. The second positional argument ("") is presumably
    # the placeholder, matching the ``placeholder=`` keyword used by the other
    # fields — confirm against the schema_field signature.
    name: schema_field(
        none_type(string_field()),
        "",
        (
            "Custom name to register your dataset. If no name is specified, "
            "the name of the uploaded file will be used."
        ),
    )  # type: ignore
    # Sheet selector: a non-negative integer index or a sheet name; passed to
    # read_excel as ``sheet_name``.
    sheet: schema_field(
        union_type(int_field(ge=0), string_field()),
        placeholder=0,
        description="""
        The name of the sheet to read or its zero-based index.
        If a string is provided, the reader will search for a sheet named exactly as
        the string.
        If an integer is provided, the reader will select the sheet at the corresponding
        index.
        By default, the first sheet will be read.
        """,
    )  # type: ignore
    # Optional non-negative header row index; passed to read_excel as ``header``
    # (None means the file has no header row).
    header: schema_field(
        none_type(int_field(ge=0)),
        placeholder=0,
        description="""
        The row number where the column names are located, indexed from 0.
        If null, the file will be considered to have no column names.
        """,
    )  # type: ignore
    # Optional Excel column-letter selection (e.g. "A:E"); passed to read_excel
    # as ``usecols`` (None loads every column).
    usecols: schema_field(
        none_type(string_field()),
        placeholder=None,
        description="""
        If None, the reader will load all columns.
        If str, then indicates comma separated list of Excel column letters and column
        ranges (e.g. “A:E” or “A,C,E:F”). Ranges are inclusive of both sides.
        """,
    )  # type: ignore


class ExcelDataLoader(BaseDataLoader):
    """Data loader for tabular data in Excel files."""

    COMPATIBLE_COMPONENTS = ["TabularClassificationTask"]
    SCHEMA = ExcelDataloaderSchema

    @staticmethod
    def _read_excel(file_path: str, params: Dict[str, Any]) -> pd.DataFrame:
        """Read one Excel file using the ``sheet``/``header``/``usecols`` params."""
        return pd.read_excel(
            io=file_path,
            sheet_name=params["sheet"],
            header=params["header"],
            usecols=params["usecols"],
        )

    @classmethod
    def _read_split(
        cls, file_paths: list, split_name: str, params: Dict[str, Any]
    ) -> pd.DataFrame:
        """Read every file of one split, in sorted path order, into one DataFrame.

        Parameters
        ----------
        file_paths : list
            Paths of the Excel files belonging to the split.
        split_name : str
            Human-readable split name, used only in the error message.
        params : Dict[str, Any]
            Dataloader parameters forwarded to ``_read_excel``.

        Returns
        -------
        pd.DataFrame
            The concatenation of all files of the split.

        Raises
        ------
        DatasetGenerationError
            If ``file_paths`` is empty. Without this check ``pd.concat`` would
            fail with a generic "No objects to concatenate" error.
        """
        if not file_paths:
            raise DatasetGenerationError(
                f"No Excel files were found for the '{split_name}' split."
            )
        return pd.concat(
            [cls._read_excel(path, params) for path in sorted(file_paths)]
        )

    @beartype
    def load_data(
        self,
        filepath_or_buffer: str,
        temp_path: str,
        params: Dict[str, Any],
    ) -> DashAIDataset:
        """Load the uploaded Excel file(s) into a DashAIDataset.

        Parameters
        ----------
        filepath_or_buffer : str
            Location of the uploaded dataset, passed unchanged to
            ``self.prepare_files``.
        temp_path : str
            The temporary path where the files will be extracted and then
            uploaded.
        params : Dict[str, Any]
            Dict with the dataloader parameters; ``sheet``, ``header`` and
            ``usecols`` are forwarded to ``pandas.read_excel``.

        Returns
        -------
        DashAIDataset
            The result of ``to_dashai_dataset`` applied to a DatasetDict. For a
            single file the DatasetDict has only a "train" split. For a
            directory it has "train", "test" and "validation" splits, read from
            the ``train/``, ``test/`` and ``val/`` or ``validation/``
            subdirectories.

        Raises
        ------
        DatasetGenerationError
            If pandas/pyarrow raise ``ValueError`` while reading or converting
            the data, or if a split directory contains no files.
        ValueError
            If ``prepare_files`` reports a kind other than "file" or "dir".
        """
        prepared_path = self.prepare_files(filepath_or_buffer, temp_path)
        # prepare_files is expected to return (path, kind), where kind is
        # "file" or "dir".
        path, kind = prepared_path[0], prepared_path[1]

        if kind == "file":
            try:
                dataset = self._read_excel(path, params)
                # The conversion stays inside the try, as in the directory
                # branch. pyarrow conversion errors (ArrowInvalid is a
                # ValueError subclass) are then reported as
                # DatasetGenerationError in both modes.
                dataset_dict = DatasetDict(
                    {"train": Dataset.from_pandas(dataset, preserve_index=False)}
                )
            except ValueError as e:
                raise DatasetGenerationError from e
        elif kind == "dir":
            train_files = glob.glob(path + "/train/*")
            test_files = glob.glob(path + "/test/*")
            # The validation split may be stored under either directory name.
            val_files = glob.glob(path + "/val/*") + glob.glob(
                path + "/validation/*"
            )
            try:
                train_df = self._read_split(train_files, "train", params)
                test_df = self._read_split(test_files, "test", params)
                val_df = self._read_split(val_files, "validation", params)
                # Concatenated frames carry duplicated indices; drop them so
                # they do not become a spurious column.
                dataset_dict = DatasetDict(
                    {
                        "train": Dataset.from_pandas(train_df, preserve_index=False),
                        "test": Dataset.from_pandas(test_df, preserve_index=False),
                        "validation": Dataset.from_pandas(
                            val_df, preserve_index=False
                        ),
                    }
                )
            except ValueError as e:
                raise DatasetGenerationError from e
            finally:
                # The extracted directory is removed whether or not loading
                # succeeded.
                shutil.rmtree(path)
        else:
            raise ValueError(f"Unexpected prepared file kind: {kind!r}.")

        return to_dashai_dataset(dataset_dict)