Source code for DashAI.back.dataloaders.classes.json_dataloader

"""DashAI JSON Dataloader."""

import shutil
from typing import Any, Dict

from beartype import beartype
from datasets import load_dataset

from DashAI.back.core.schema_fields import (
    none_type,
    schema_field,
    string_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    to_dashai_dataset,
)
from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader


class JSONDataloaderSchema(BaseSchema):
    # Parameter schema for the JSON dataloader; JSONDataLoader.SCHEMA points here.
    # Notes are kept as comments (not a class docstring) so that nothing the
    # schema machinery may read from the class at runtime is altered.

    # Optional custom name for the dataset. Per the description text, the
    # uploaded file's name is used when no name is given.
    name: schema_field(
        none_type(string_field()),
        "",
        (
            "Custom name to register your dataset. If no name is specified, "
            "the name of the uploaded file will be used."
        ),
    )  # type: ignore
    # Key under which the list of records lives in the JSON document, or null
    # when the document is already a bare list of records. JSONDataLoader
    # forwards this value as the `field` argument of `load_dataset`.
    data_key: schema_field(
        none_type(string_field()),
        placeholder="data",
        description="""
            In case the data has the form {“data”: [{“col1”: val1, “col2”: val2, ...}]}
            (also known as “table” in pandas), name of the field "data",
            where the list with dictionaries with the data should be found.

            In case the format is only a list of dictionaries (also known as
            "records" orient in pandas), set this value as null.
        """,
    )  # type: ignore


class JSONDataLoader(BaseDataLoader):
    """Data loader for tabular data in JSON files."""

    COMPATIBLE_COMPONENTS = [
        "TabularClassificationTask",
        "TextClassificationTask",
        "TranslationTask",
    ]
    SCHEMA = JSONDataloaderSchema

    def _check_params(self, params: Dict[str, Any]) -> None:
        """Validate the parameters required to load a JSON dataset.

        Parameters
        ----------
        params : Dict[str, Any]
            Dataloader parameters. Must contain the key ``"data_key"``, whose
            value must be a string or None.

        Raises
        ------
        ValueError
            If ``"data_key"`` is missing from ``params``.
        TypeError
            If ``params["data_key"]`` is neither a string nor None.
        """
        if "data_key" not in params:
            raise ValueError(
                "Error trying to load the JSON dataset: "
                "data_key parameter was not provided."
            )

        data_key = params["data_key"]
        if data_key is not None and not isinstance(data_key, str):
            raise TypeError(
                "params['data_key'] should be a string or None, "
                f"got {type(data_key)}"
            )

    @beartype
    def load_data(
        self,
        filepath_or_buffer: str,
        temp_path: str,
        params: Dict[str, Any],
    ) -> DashAIDataset:
        """Load the uploaded JSON dataset into a DashAIDataset.

        Parameters
        ----------
        filepath_or_buffer : str
            Location of the dataset (originally documented as a URL or an
            uploaded file). It is forwarded unchanged to ``prepare_files``.
        temp_path : str
            The temporary path where the files will be extracted and then
            uploaded. It is forwarded unchanged to ``prepare_files``.
        params : Dict[str, Any]
            Dict with the dataloader parameters. The options are:
            - data_key (str or None): The key of the json where the data is
              contained, or None when the file is a bare list of records.

        Returns
        -------
        DashAIDataset
            The loaded data, converted with ``to_dashai_dataset``.

        Raises
        ------
        ValueError
            If ``"data_key"`` is missing from ``params``.
        TypeError
            If ``params["data_key"]`` is neither a string nor None.
        """
        self._check_params(params)
        field = params["data_key"]

        prepared_path = self.prepare_files(filepath_or_buffer, temp_path)
        source = prepared_path[0]
        if prepared_path[1] == "file":
            dataset = load_dataset(
                "json",
                data_files=source,
                field=field,
            )
        else:
            # Anything other than a single file is loaded as a directory.
            # Remove that directory whether or not loading succeeds, so a
            # failed load does not leave the extracted files behind.
            try:
                dataset = load_dataset("json", data_dir=source, field=field)
            finally:
                shutil.rmtree(source)
        return to_dashai_dataset(dataset)