"""DashAI CSV Dataloader."""
import shutil
from itertools import islice
from typing import Any, Dict

import pandas as pd
from beartype import beartype
from datasets import Dataset, IterableDatasetDict, load_dataset

from DashAI.back.core.schema_fields import (
bool_field,
enum_field,
int_field,
none_type,
schema_field,
string_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.core.utils import MultilingualString
from DashAI.back.dataloaders.classes.dashai_dataset import (
DashAIDataset,
to_dashai_dataset,
)
from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader


class CSVDataloaderSchema(BaseSchema):
    """Schema of the user-configurable parameters for the CSV dataloader."""

name: schema_field(
string_field(),
"",
description=MultilingualString(
en=(
"Custom name to register your dataset. If no name is specified, "
"the name of the uploaded file will be used."
),
es=(
"Nombre personalizado para registrar su dataset. Si no se especifica "
"un nombre, se usará el nombre del archivo subido."
),
),
alias=MultilingualString(en="Name", es="Nombre"),
) # type: ignore
separator: schema_field(
enum_field([",", ";", "blank space", "tab"]),
",",
description=MultilingualString(
en="A separator character delimits the data in a CSV file.",
es="Un carácter separador delimita los datos en un archivo CSV.",
),
alias=MultilingualString(en="Separator", es="Separador"),
) # type: ignore
header: schema_field(
string_field(),
"infer",
description=MultilingualString(
en=(
"Row number(s) containing column labels and marking the start of the "
"data (zero-indexed). Default behavior is to infer the column names. "
"If column names are passed explicitly, this should be set to '0'. "
"Header can also be a list of integers that specify row locations "
"for MultiIndex on the columns."
),
es=(
"Número(s) de fila que contienen las etiquetas de columna y marcan "
"el inicio de los datos (indexado desde cero). El comportamiento "
"predeterminado es inferir los nombres de columna. Si los nombres de "
"columna se pasan explícitamente, esto debe establecerse en '0'. "
"Header también puede ser una lista de enteros que especifican las "
"ubicaciones de fila para MultiIndex en las columnas."
),
),
alias=MultilingualString(en="Header", es="Encabezado"),
) # type: ignore
names: schema_field(
none_type(string_field()),
None,
description=MultilingualString(
en=(
"Comma-separated list of column names to use. If the file contains a "
"header row, then you should explicitly pass header=0 to override the "
"column names. Example: 'col1,col2,col3'. Leave empty to use file "
"headers."
),
es=(
"Lista de nombres de columna separados por comas. Si el archivo "
"contiene una fila de encabezado, debe pasar explícitamente header=0 "
"para sobrescribir los nombres de columna. Ejemplo: 'col1,col2,col3'. "
"Deje vacío para usar los encabezados del archivo."
),
),
alias=MultilingualString(en="Names", es="Nombres"),
) # type: ignore
encoding: schema_field(
enum_field(["utf-8", "latin1", "cp1252", "iso-8859-1"]),
"utf-8",
description=MultilingualString(
en=(
"Encoding to use for UTF when reading/writing. Most common encodings "
"provided."
),
es=(
"Codificación a usar para UTF al leer/escribir. Se proporcionan las "
"codificaciones más comunes."
),
),
alias=MultilingualString(en="Encoding", es="Codificación"),
) # type: ignore
na_values: schema_field(
none_type(string_field()),
None,
description=MultilingualString(
en=(
"Comma-separated additional strings to recognize as NA/NaN. "
"Example: 'NULL,missing,n/a'"
),
es=(
"Cadenas adicionales separadas por comas para reconocer como NA/NaN. "
"Ejemplo: 'NULL,missing,n/a'"
),
),
alias=MultilingualString(en="NA values", es="Valores NA"),
) # type: ignore
keep_default_na: schema_field(
bool_field(),
True,
description=MultilingualString(
en=(
"Whether to include the default NaN values when parsing the data "
"(True recommended)."
),
es=(
"Si se deben incluir los valores NaN predeterminados al analizar los "
"datos (se recomienda True)."
),
),
alias=MultilingualString(en="Keep default NA", es="Mantener NA predeterminado"),
) # type: ignore
true_values: schema_field(
none_type(string_field()),
None,
description=MultilingualString(
en="Comma-separated values to consider as True. Example: 'yes,true,1,on'",
es=(
"Valores separados por comas a considerar como True. "
"Ejemplo: 'yes,true,1,on'"
),
),
alias=MultilingualString(en="True values", es="Valores verdaderos"),
) # type: ignore
false_values: schema_field(
none_type(string_field()),
None,
description=MultilingualString(
en="Comma-separated values to consider as False. Example: 'no,false,0,off'",
es=(
"Valores separados por comas a considerar como False. "
"Ejemplo: 'no,false,0,off'"
),
),
alias=MultilingualString(en="False values", es="Valores falsos"),
) # type: ignore
skip_blank_lines: schema_field(
bool_field(),
True,
description=MultilingualString(
en="If True, skip over blank lines rather than interpreting as NaN values.",
es=(
"Si es True, omitir líneas en blanco en lugar de interpretarlas como "
"valores NaN."
),
),
alias=MultilingualString(en="Skip blank lines", es="Omitir líneas en blanco"),
) # type: ignore
skiprows: schema_field(
none_type(int_field()),
None,
description=MultilingualString(
en=(
"Number of lines to skip at the beginning of the file. "
"Leave empty to skip none."
),
es=(
"Número de líneas a omitir al inicio del archivo. "
"Deje vacío para no omitir ninguna."
),
),
alias=MultilingualString(en="Skip rows", es="Omitir filas"),
) # type: ignore
nrows: schema_field(
none_type(int_field()),
None,
description=MultilingualString(
en="Number of rows to read from the file. Leave empty to read all rows.",
es=(
"Número de filas a leer del archivo. Deje vacío para leer todas las "
"filas."
),
),
alias=MultilingualString(en="N rows", es="N filas"),
) # type: ignore
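
# Illustrative (assumed, not from the source) raw parameters matching the
# schema above, as they arrive before `_check_params` normalizes them:
#   {"separator": "tab", "header": "infer", "names": "col1,col2,col3"}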


class CSVDataLoader(BaseDataLoader):
"""Data loader for tabular data in CSV files."""
COMPATIBLE_COMPONENTS = ["TabularClassificationTask"]
SCHEMA = CSVDataloaderSchema
DESCRIPTION: str = MultilingualString(
en=(
"Data loader for tabular data in CSV files. "
"All uploaded CSV files must have the same column structure and use "
"consistent separators."
),
es=(
"Cargador de datos para datos tabulares en archivos CSV. "
"Todos los archivos CSV subidos deben tener la misma estructura de "
"columnas y usar separadores consistentes."
),
)
DISPLAY_NAME: str = MultilingualString(
en="CSV Data Loader",
es="Cargador de Datos CSV",
)

    def _check_params(
        self,
        params: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Validate the dataloader params and map them to pandas kwargs."""
if "separator" not in params:
raise ValueError(
"Error trying to load the CSV dataset: "
"separator parameter was not provided."
)
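        # Illustrative mapping (assumed, not from the source): an input of
        # {"separator": "tab", "names": "a,b"} is normalized below into
        # {"delimiter": "\t", "names": ["a", "b"]} for pandas.read_csv.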
clean_params = {}
separator = params["separator"]
if separator == "blank space":
separator = " "
elif separator == "tab":
separator = "\t"
if not isinstance(separator, str):
raise TypeError(
f"Param separator should be a string, got {type(params['separator'])}"
)
clean_params["delimiter"] = separator
        if params.get("header") is not None:
            header = params["header"]
            if isinstance(header, str) and header != "infer":
                # pandas expects an int (or list of ints) rather than the
                # string "0", so convert values like "0" or "0,1" accordingly.
                parsed = [int(val.strip()) for val in header.split(",")]
                header = parsed[0] if len(parsed) == 1 else parsed
            clean_params["header"] = header
list_params = ["names", "na_values", "true_values", "false_values"]
for param in list_params:
if param in params and params[param]:
clean_params[param] = [val.strip() for val in params[param].split(",")]
bool_params = ["keep_default_na", "skip_blank_lines"]
for param in bool_params:
if param in params and params[param] is not None:
clean_params[param] = params[param]
int_params = ["skiprows", "nrows"]
for param in int_params:
if param in params and params[param] is not None:
if not isinstance(params[param], int):
raise TypeError(
f"Param {param} should be an integer, got {type(params[param])}"
)
clean_params[param] = params[param]
if "encoding" in params and params["encoding"]:
valid_encodings = ["utf-8", "latin1", "cp1252", "iso-8859-1"]
if params["encoding"] not in valid_encodings:
raise ValueError(f"Invalid encoding: {params['encoding']}")
clean_params["encoding"] = params["encoding"]
return clean_params

    @beartype
def load_data(
self,
filepath_or_buffer: str,
temp_path: str,
params: Dict[str, Any],
n_sample: int | None = None,
) -> DashAIDataset:
"""Load the uploaded CSV files into a DatasetDict.
Parameters
----------
filepath_or_buffer : str, optional
An URL where the dataset is located or a FastAPI/Uvicorn uploaded file
object.
temp_path : str
The temporary path where the files will be extracted and then uploaded.
params : Dict[str, Any]
Dict with the dataloader parameters. The options are:
- `separator` (str): The character that delimits the CSV data.
n_sample : int | None
Indicates how many rows load from the dataset, all rows if null.
Returns
-------
DatasetDict
A HuggingFace's Dataset with the loaded data.
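
        Examples
        --------
        A minimal sketch; the file path and params are illustrative, not from
        the source.

        >>> loader = CSVDataLoader()
        >>> dataset = loader.load_data(
        ...     "datasets/example.csv", "/tmp/dashai", {"separator": ","}
        ... )  # doctest: +SKIP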
"""
clean_params = self._check_params(params)
prepared_path = self.prepare_files(filepath_or_buffer, temp_path)
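        # When n_sample is set, stream the dataset so only the sampled rows
        # are materialized instead of reading the whole file into memory.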
if prepared_path[1] == "file":
dataset = load_dataset(
"csv",
data_files=prepared_path[0],
**clean_params,
streaming=bool(n_sample),
)
else:
dataset = load_dataset(
"csv",
data_dir=prepared_path[0],
**clean_params,
streaming=bool(n_sample),
)
shutil.rmtree(prepared_path[0])
        if n_sample:
            # Streaming returns an IterableDatasetDict; take the "train" split
            # before materializing the first n_sample rows.
            if isinstance(dataset, IterableDatasetDict):
                dataset = dataset["train"]
            dataset = Dataset.from_list(list(dataset.take(n_sample)))
return to_dashai_dataset(dataset)

    def load_preview(
self,
filepath_or_buffer: str,
params: Dict[str, Any],
n_rows: int = 100,
) -> pd.DataFrame:
"""
Load a preview of the CSV dataset using streaming.
Parameters
----------
filepath_or_buffer : str
Path to the CSV file.
params : Dict[str, Any]
Parameters for loading the CSV (separator, encoding, etc.).
n_rows : int, optional
Number of rows to preview. Default is 100.

        Returns
-------
pd.DataFrame
A DataFrame containing the preview rows.
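
        Examples
        --------
        A minimal sketch; the file path is illustrative, not from the source.

        >>> loader = CSVDataLoader()
        >>> df = loader.load_preview(
        ...     "datasets/example.csv", {"separator": ","}, n_rows=5
        ... )  # doctest: +SKIP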
"""
clean_params = self._check_params(params)
dataset_stream = load_dataset(
"csv",
data_files=filepath_or_buffer,
streaming=True,
split="train",
**clean_params,
)
sample_rows = list(islice(dataset_stream, n_rows))
df_preview = pd.DataFrame(sample_rows)
return df_preview