"""DashAI JSON Dataloader."""

import shutil
from typing import Any, Dict

from beartype import beartype
from datasets import load_dataset

from DashAI.back.core.schema_fields import (
none_type,
schema_field,
string_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.dataloaders.classes.dashai_dataset import (
DashAIDataset,
to_dashai_dataset,
)
from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader


class JSONDataloaderSchema(BaseSchema):
name: schema_field(
none_type(string_field()),
"",
(
"Custom name to register your dataset. If no name is specified, "
"the name of the uploaded file will be used."
),
) # type: ignore
    data_key: schema_field(
        none_type(string_field()),
        placeholder="data",
        description="""
        If the data has the form {"data": [{"col1": val1, "col2": val2, ...}, ...]}
        (similar to the "table" orient in pandas), set this parameter to the
        name of the field that holds the list of records ("data" in the
        example).
        If the file is just a list of dictionaries (the "records" orient in
        pandas), set this value to null.
        """,
    ) # type: ignore
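

# Illustrative note (not used by the dataloader itself): the two layouts that
# the data_key parameter distinguishes between. Column names and values here
# are made up for the example.
#
#   data_key = "data"  ->  {"data": [{"col1": 1, "col2": "a"}, {"col1": 2, "col2": "b"}]}
#   data_key = None    ->  [{"col1": 1, "col2": "a"}, {"col1": 2, "col2": "b"}]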


class JSONDataLoader(BaseDataLoader):
"""Data loader for tabular data in JSON files."""
COMPATIBLE_COMPONENTS = [
"TabularClassificationTask",
"TextClassificationTask",
"TranslationTask",
]
SCHEMA = JSONDataloaderSchema

    def _check_params(self, params: Dict[str, Any]) -> None:
if "data_key" not in params:
raise ValueError(
"Error trying to load the JSON dataset: "
"data_key parameter was not provided."
)
if not (isinstance(params["data_key"], str) or params["data_key"] is None):
raise TypeError(
"params['data_key'] should be a string or None, "
f"got {type(params['data_key'])}"
)

    @beartype
def load_data(
self,
filepath_or_buffer: str,
temp_path: str,
params: Dict[str, Any],
) -> DashAIDataset:
"""Load the uploaded JSON dataset into a DatasetDict.
Parameters
----------
filepath_or_buffer : str
An URL where the dataset is located or a FastAPI/Uvicorn uploaded file
object.
temp_path : str
The temporary path where the files will be extracted and then uploaded.
params : Dict[str, Any]
Dict with the dataloader parameters. The options are:
- data_key (str): The key of the json where the data is contained.
Returns
-------
DatasetDict
A HuggingFace's Dataset with the loaded data.
"""
self._check_params(params)
field = params["data_key"]
prepared_path = self.prepare_files(filepath_or_buffer, temp_path)
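        # prepare_files returns a (path, kind) tuple: when kind is "file" the
        # path points to a single JSON file; otherwise it is treated as a
        # directory of files to load.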
if prepared_path[1] == "file":
dataset = load_dataset(
"json",
data_files=prepared_path[0],
field=field,
)
else:
dataset = load_dataset("json", data_dir=prepared_path[0], field=field)
shutil.rmtree(prepared_path[0])
return to_dashai_dataset(dataset)
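

# Minimal usage sketch (illustrative only; the file path, temp directory and
# column names are hypothetical):
#
#     loader = JSONDataLoader()
#     dataset = loader.load_data(
#         filepath_or_buffer="/tmp/uploads/iris.json",
#         temp_path="/tmp/dashai_tmp",
#         params={"data_key": None},  # the file is a plain list of records
#     )
#
# For a file shaped like {"data": [...]}, pass params={"data_key": "data"}
# instead.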