"""DashAI CSV Dataloader."""
import shutil
from typing import Any, Dict
from beartype import beartype
from datasets import load_dataset
from DashAI.back.core.schema_fields import (
bool_field,
enum_field,
int_field,
none_type,
schema_field,
string_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.dataloaders.classes.dashai_dataset import (
DashAIDataset,
to_dashai_dataset,
)
from DashAI.back.dataloaders.classes.dataloader import BaseDataLoader
class CSVDataloaderSchema(BaseSchema):
name: schema_field(
none_type(string_field()),
"",
(
"Custom name to register your dataset. If no name is specified, "
"the name of the uploaded file will be used."
),
) # type: ignore
separator: schema_field(
enum_field([",", ";", "blank space", "tab"]),
",",
"A separator character delimits the data in a CSV file.",
) # type: ignore
header: schema_field(
string_field(),
"infer",
(
"Row number(s) containing column labels and marking the start of the data "
"(zero-indexed). Default behavior is to infer the column names. If column "
"names are passed explicitly, this should be set to '0'. "
"Header can also be a list of integers that specify row locations "
"for MultiIndex on the columns."
),
) # type: ignore
names: schema_field(
none_type(string_field()),
None,
(
"Comma-separated list of column names to use. If the file contains a "
"header row, "
"then you should explicitly pass header=0 to override the column names. "
"Example: 'col1,col2,col3'. Leave empty to use file headers."
),
) # type: ignore
encoding: schema_field(
enum_field(["utf-8", "latin1", "cp1252", "iso-8859-1"]),
"utf-8",
"Encoding to use for UTF when reading/writing. Most common encodings provided.",
) # type: ignore
na_values: schema_field(
none_type(string_field()),
None,
(
"Comma-separated additional strings to recognize as NA/NaN. "
"Example: 'NULL,missing,n/a'"
),
) # type: ignore
keep_default_na: schema_field(
bool_field(),
True,
(
"Whether to include the default NaN values when parsing the data "
"(True recommended)."
),
) # type: ignore
true_values: schema_field(
none_type(string_field()),
None,
"Comma-separated values to consider as True. Example: 'yes,true,1,on'",
) # type: ignore
false_values: schema_field(
none_type(string_field()),
None,
"Comma-separated values to consider as False. Example: 'no,false,0,off'",
) # type: ignore
skip_blank_lines: schema_field(
bool_field(),
True,
"If True, skip over blank lines rather than interpreting as NaN values.",
) # type: ignore
skiprows: schema_field(
none_type(int_field()),
None,
"Number of lines to skip at the beginning of the file. "
"Leave empty to skip none.",
) # type: ignore
nrows: schema_field(
none_type(int_field()),
None,
"Number of rows to read from the file. Leave empty to read all rows.",
) # type: ignore
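# A minimal sketch of a parameter dict matching the schema above (the values
# are illustrative, not defaults):
#
#     params = {
#         "separator": ";",
#         "header": "0",
#         "names": "sepal_length,sepal_width,species",
#         "encoding": "utf-8",
#         "na_values": "NULL,n/a",
#         "nrows": 1000,
#     }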
class CSVDataLoader(BaseDataLoader):
"""Data loader for tabular data in CSV files."""
COMPATIBLE_COMPONENTS = ["TabularClassificationTask"]
SCHEMA = CSVDataloaderSchema
DESCRIPTION: str = """
Data loader for tabular data in CSV files.
All uploaded CSV files must have the same column structure and use
consistent separators.
"""
    def _check_params(
        self,
        params: Dict[str, Any],
    ) -> Dict[str, Any]:
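        """Validate the dataloader parameters and map them to keyword
        arguments accepted by the underlying CSV reader.

        Raises
        ------
        ValueError
            If the separator is not provided or the encoding is unsupported.
        TypeError
            If a parameter does not have the expected type.
        """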
if "separator" not in params:
raise ValueError(
"Error trying to load the CSV dataset: "
"separator parameter was not provided."
)
clean_params = {}
        separator = params["separator"]
        if not isinstance(separator, str):
            raise TypeError(
                f"Param separator should be a string, got {type(separator)}"
            )
        # Map the human-readable schema options to the actual characters.
        if separator == "blank space":
            separator = " "
        elif separator == "tab":
            separator = "\t"
        clean_params["delimiter"] = separator
if params.get("header") is not None:
clean_params["header"] = params["header"]
list_params = ["names", "na_values", "true_values", "false_values"]
for param in list_params:
if param in params and params[param]:
clean_params[param] = [val.strip() for val in params[param].split(",")]
bool_params = ["keep_default_na", "skip_blank_lines"]
for param in bool_params:
if param in params and params[param] is not None:
clean_params[param] = params[param]
int_params = ["skiprows", "nrows"]
for param in int_params:
if param in params and params[param] is not None:
if not isinstance(params[param], int):
raise TypeError(
f"Param {param} should be an integer, got {type(params[param])}"
)
clean_params[param] = params[param]
if "encoding" in params and params["encoding"]:
valid_encodings = ["utf-8", "latin1", "cp1252", "iso-8859-1"]
if params["encoding"] not in valid_encodings:
raise ValueError(f"Invalid encoding: {params['encoding']}")
clean_params["encoding"] = params["encoding"]
return clean_params
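    # For illustration, a call such as
    #     self._check_params({"separator": "tab", "names": "a, b", "nrows": 10})
    # returns {"delimiter": "\t", "names": ["a", "b"], "nrows": 10}.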
@beartype
def load_data(
self,
filepath_or_buffer: str,
temp_path: str,
params: Dict[str, Any],
) -> DashAIDataset:
"""Load the uploaded CSV files into a DatasetDict.
Parameters
----------
filepath_or_buffer : str, optional
An URL where the dataset is located or a FastAPI/Uvicorn uploaded file
object.
temp_path : str
The temporary path where the files will be extracted and then uploaded.
params : Dict[str, Any]
Dict with the dataloader parameters. The options are:
- `separator` (str): The character that delimits the CSV data.
Returns
-------
DatasetDict
A HuggingFace's Dataset with the loaded data.
"""
print("parameters are", params)
clean_params = self._check_params(params)
print("cleaned parameters are", clean_params)
        path, file_type = self.prepare_files(filepath_or_buffer, temp_path)
        if file_type == "file":
            dataset = load_dataset(
                "csv",
                data_files=path,
                **clean_params,
            )
        else:
            dataset = load_dataset(
                "csv",
                data_dir=path,
                **clean_params,
            )
        shutil.rmtree(path)
        return to_dashai_dataset(dataset)
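# A minimal usage sketch; the file path and temporary directory below are
# hypothetical:
#
#     loader = CSVDataLoader()
#     dataset = loader.load_data(
#         filepath_or_buffer="/tmp/uploads/data.csv",
#         temp_path="/tmp/dashai_extract",
#         params={"separator": ","},
#     )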