"""DashAI JSON Dataloader."""
import shutil
from itertools import islice
from typing import Any, Dict
import pandas as pd
from beartype import beartype
from datasets import Dataset, IterableDatasetDict, load_dataset
from DashAI.back.core.schema_fields import none_type, schema_field, string_field
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.core.utils import MultilingualString
from DashAI.back.dataloaders.classes.dashai_dataset import (
DashAIDataset,
to_dashai_dataset,
)
from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader
class JSONDataloaderSchema(BaseSchema):
name: schema_field(
string_field(),
"",
description=MultilingualString(
en=(
"Custom name to register your dataset. If no name is specified, "
"the name of the uploaded file will be used."
),
es=(
"Nombre personalizado para registrar su dataset. Si no se especifica "
"un nombre, se usará el nombre del archivo subido."
),
),
alias=MultilingualString(en="Name", es="Nombre"),
) # type: ignore
data_key: schema_field(
none_type(string_field()),
placeholder="data",
description=MultilingualString(
            en=(
                "If the data has the form "
                '{"data": [{"col1": val1, "col2": val2, ...}]} '
                '(the "table" orient in pandas), the name of the field '
                '(here, "data") that contains the list of record '
                "dictionaries. If the format is just a list of dictionaries "
                '(the "records" orient in pandas), set this value to null.'
            ),
es=(
"En caso de que los datos tengan la forma "
'{"data": [{"col1": val1, "col2": val2, ...}]} '
'(también conocido como "table" en pandas), nombre del campo "data", '
"donde se debe encontrar la lista con diccionarios con los datos. "
"En caso de que el formato sea solo una lista de diccionarios "
'(también conocido como orientación "records" en pandas), '
"establezca este valor como null."
),
),
alias=MultilingualString(en="Data key", es="Clave de datos"),
) # type: ignore
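
# Illustrative JSON payloads for the `data_key` option above (column names
# and values are hypothetical):
#
#   data_key = "data"  ->  {"data": [{"col1": 1, "col2": "a"}, ...]}
#   data_key = None    ->  [{"col1": 1, "col2": "a"}, ...]
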
class JSONDataLoader(BaseDataLoader):
"""Data loader for tabular data in JSON files."""
COMPATIBLE_COMPONENTS = [
"TabularClassificationTask",
"TextClassificationTask",
"TranslationTask",
]
SCHEMA = JSONDataloaderSchema
DESCRIPTION: str = MultilingualString(
en=(
"Data loader for tabular data in JSON files. "
"Supports both standard JSON array format (a list of dictionaries) "
"and nested JSON data where records are contained within a specific key."
),
es=(
"Cargador de datos para datos tabulares en archivos JSON. "
"Soporta tanto el formato de array JSON estándar (una lista de "
"diccionarios) como datos JSON anidados donde los registros están "
"contenidos dentro de una clave específica."
),
)
DISPLAY_NAME: str = MultilingualString(
en="JSON Data Loader",
es="Cargador de Datos JSON",
)
def _check_params(self, params: Dict[str, Any]) -> None:
if "data_key" not in params:
raise ValueError(
"Error trying to load the JSON dataset: "
"data_key parameter was not provided."
)
if not (isinstance(params["data_key"], str) or params["data_key"] is None):
raise TypeError(
"params['data_key'] should be a string or None, "
f"got {type(params['data_key'])}"
)
@beartype
def load_data(
self,
filepath_or_buffer: str,
temp_path: str,
params: Dict[str, Any],
n_sample: int | None = None,
) -> DashAIDataset:
"""Load the uploaded JSON dataset into a DatasetDict.
Parameters
----------
        filepath_or_buffer : str
            A URL or path where the dataset is located, or a FastAPI/Uvicorn
            uploaded file object.
temp_path : str
The temporary path where the files will be extracted and then uploaded.
params : Dict[str, Any]
Dict with the dataloader parameters. The options are:
            - data_key (str | None): The key of the JSON object that
              contains the data, or None if the file is a plain list of
              records.
        n_sample : int | None
            Number of rows to load from the dataset; all rows are loaded
            if None.
Returns
-------
        DashAIDataset
            A DashAIDataset with the loaded data.
"""
self._check_params(params)
field = params["data_key"]
prepared_path = self.prepare_files(filepath_or_buffer, temp_path)
if prepared_path[1] == "file":
dataset = load_dataset(
"json",
data_files=prepared_path[0],
field=field,
streaming=bool(n_sample),
)
else:
dataset = load_dataset(
"json", data_dir=prepared_path[0], field=field, streaming=bool(n_sample)
)
shutil.rmtree(prepared_path[0])
        if n_sample:
            # In streaming mode load_dataset returns an IterableDatasetDict;
            # materialize only the first n_sample rows of the "train" split.
            if isinstance(dataset, IterableDatasetDict):
                dataset = dataset["train"]
            dataset = Dataset.from_list(list(dataset.take(n_sample)))
return to_dashai_dataset(dataset)
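
    # Minimal usage sketch for load_data; the file path, temporary directory,
    # and sample size below are hypothetical:
    #
    #   loader = JSONDataLoader()
    #   dataset = loader.load_data(
    #       filepath_or_buffer="uploads/example.json",
    #       temp_path="/tmp/dashai",
    #       params={"data_key": "data"},
    #       n_sample=100,
    #   )
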
def load_preview(
self,
filepath_or_buffer: str,
params: Dict[str, Any],
n_rows: int = 100,
) -> pd.DataFrame:
"""
Load a preview of the JSON dataset using streaming.
Parameters
----------
filepath_or_buffer : str
Path to the JSON file.
params : Dict[str, Any]
            Dict with the dataloader parameters (only data_key is used).
n_rows : int, optional
Number of rows to preview. Default is 100.
Returns
-------
pd.DataFrame
A DataFrame containing the preview rows.
"""
self._check_params(params)
field = params.get("data_key")
dataset_stream = load_dataset(
"json",
data_files=filepath_or_buffer,
field=field,
streaming=True,
split="train",
)
sample_rows = list(islice(dataset_stream, n_rows))
return pd.DataFrame(sample_rows)
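
# Minimal preview sketch, assuming a local "records"-oriented JSON file
# (the path below is hypothetical):
#
#   loader = JSONDataLoader()
#   preview = loader.load_preview("example.json", {"data_key": None}, n_rows=5)
#   print(preview)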