Source code for DashAI.back.dataloaders.classes.dashai_dataset

"""DashAI Dataset implementation."""

import json
import os
from typing import Dict, List, Literal, Tuple, Union

import numpy as np
import pyarrow as pa
import pyarrow.ipc as ipc
from beartype import beartype
from datasets import ClassLabel, Dataset, DatasetDict, Value, concatenate_datasets
from sklearn.model_selection import train_test_split


def get_arrow_table(ds: Dataset) -> pa.Table:
    """
    Retrieve the underlying PyArrow table from a Hugging Face Dataset.
    This function abstracts away the need to access private attributes.

    Parameters:
        ds (Dataset): A Hugging Face Dataset.

    Returns:
        pa.Table: The underlying PyArrow table.

    Raises:
        ValueError: If the arrow table cannot be retrieved.
    """
    if hasattr(ds, "arrow_table"):
        return ds.arrow_table
    elif hasattr(ds, "data") and hasattr(ds.data, "table"):
        return ds.data.table
    else:
        raise ValueError("Unable to retrieve underlying arrow table from the dataset.")


[docs]class DashAIDataset(Dataset): """DashAI dataset wrapper for Huggingface datasets with extra metadata."""
[docs] @beartype def __init__( self, table: pa.Table, splits: dict = None, *args, **kwargs, ): """Initialize a new instance of a DashAI dataset. Parameters ---------- table : Table Arrow table from which the dataset will be created """ super().__init__(table, *args, **kwargs) self.splits = splits or {}
@beartype def cast(self, *args, **kwargs) -> "DashAIDataset": """Override of the cast method to leave it in DashAI dataset format. Returns ------- DatasetDashAI Dataset after cast """ ds = super().cast(*args, **kwargs) arrow_tbl = get_arrow_table(ds) return DashAIDataset(arrow_tbl, splits=self.splits) @property def arrow_table(self) -> pa.Table: """ Provides a clean way to access the underlying PyArrow table. Returns: pa.Table: The underlying PyArrow table. """ try: return self._data.table except AttributeError: raise ValueError("Unable to retrieve the underlying Arrow table.") from None def keys(self) -> List[str]: """Return the available splits in the dataset. Returns ------- List[str] List of split names (e.g., ['train', 'test', 'validation']) """ if "split_indices" in self.splits: return list(self.splits["split_indices"].keys()) return [] @beartype def save_to_disk(self, dataset_path: Union[str, os.PathLike]) -> None: """ Overrides the default save_to_disk method to save the dataset as a single directory with: - "data.arrow": the dataset's Arrow table. - "splits.json": the dataset's splits (e.g., original split indices). Parameters ---------- dataset_path : Union[str, os.PathLike] path where the dataset will be saved """ save_dataset(self, dataset_path) @beartype def change_columns_type(self, column_types: Dict[str, str]) -> "DashAIDataset": """Change the type of some columns. Note: this is a temporal method, and it will probably will delete in the future. Parameters ---------- column_types : Dict[str, str] dictionary whose keys are the names of the columns to be changed and the values the new types. Returns ------- DashAIDataset The dataset after columns type changes. """ if not isinstance(column_types, dict): raise TypeError(f"types should be a dict, got {type(column_types)}") for column in column_types: if column in self.column_names: pass else: raise ValueError( f"Error while changing column types: column '{column}' does not " "exist in dataset." ) new_features = self.features.copy() for column in column_types: if column_types[column] == "Categorical": names = list(set(self[column])) new_features[column] = ClassLabel(names=names) elif column_types[column] == "Numerical": new_features[column] = Value("float32") dataset = self.cast(new_features) return dataset @beartype def remove_columns(self, column_names: Union[str, List[str]]) -> "DashAIDataset": """Remove one or several column(s) in the dataset and the features associated to them. Parameters ---------- column_names : Union[str, List[str]] Name, or list of names of columns to be removed. Returns ------- DashAIDataset The dataset after columns removal. """ if isinstance(column_names, str): column_names = [column_names] # Remove column from features modified_dataset = super().remove_columns(column_names) # Update self with modified dataset attributes self.__dict__.update(modified_dataset.__dict__) return self @beartype def sample( self, n: int = 1, method: Literal["head", "tail", "random"] = "head", seed: Union[int, None] = None, ) -> Dict[str, List]: """Return sample rows from dataset. Parameters ---------- n : int number of samples to return. method: Literal[str] method for selecting samples. Possible values are: 'head' to select the first n samples, 'tail' to select the last n samples and 'random' to select n random samples. seed : int, optional seed for random number generator when using 'random' method. Returns ------- Dict A dictionary with selected samples. """ if n > len(self): raise ValueError( "Number of samples must be less than or equal to the length " f"of the dataset. Number of samples: {n}, " f"dataset length: {len(self)}" ) if method == "random": rng = np.random.default_rng(seed=seed) indexes = rng.integers(low=0, high=(len(self) - 1), size=n) sample = self.select(indexes).to_dict() elif method == "head": sample = self[:n] elif method == "tail": sample = self[-n:] return sample @beartype def get_split(self, split_name: str) -> "DashAIDataset": """ Returns a new DashAIDataset corresponding to the specified split. This method uses the metadata 'split_indices' stored in the original DashAIDataset to obtain the list of indices for the desired split, then it creates a new dataset containing only those rows. Parameters: split_name (str): The name of the split to extract (e.g., "train", "test", "validation"). Returns: DashAIDataset: A new DashAIDataset instance containing only the rows of the specified split. Raises: ValueError: If the specified split is not found in the splits of the dataset. """ splits = self.splits.get("split_indices", {}) if split_name not in splits: raise ValueError(f"Split '{split_name}' not found in dataset splits.") indices = splits[split_name] subset = self.select(indices) new_splits = {"split_indices": {split_name: indices}} arrow_table = subset.with_format("arrow")[:] subset = DashAIDataset(arrow_table, splits=new_splits) return subset
@beartype def merge_splits_with_metadata(dataset_dict: DatasetDict) -> DashAIDataset: """ Merges the splits from a DatasetDict into a single DashAIDataset and records the original indices for each split in the metadata. Parameters: dataset_dict (DatasetDict): A Hugging Face DatasetDict containing multiple splits. Returns: DashAIDataset: A unified dataset with merged data and metadata containing the original split indices. """ concatenated_datasets = [] split_index = {} current_index = 0 if len(dataset_dict.keys()) == 1: arrow_tbl = get_arrow_table(dataset_dict["train"]) return DashAIDataset(arrow_tbl) for split in sorted(dataset_dict.keys()): ds = dataset_dict[split] n_rows = len(ds) split_index[split] = list(range(current_index, current_index + n_rows)) current_index += n_rows concatenated_datasets.append(ds) merged_dataset = concatenate_datasets(concatenated_datasets) arrow_tbl = get_arrow_table(merged_dataset) dashai_dataset = DashAIDataset(arrow_tbl, splits={"split_indices": split_index}) return dashai_dataset @beartype def save_dataset(dataset: DashAIDataset, path: Union[str, os.PathLike]) -> None: """ Saves a DashAIDataset in a custom format using two files in the specified directory: - "data.arrow": contains the dataset's PyArrow table. - "splits.json": contains the dataset's splits indices. Parameters: dataset (DashAIDataset): The dataset to save. path (Union[str, os.PathLike]): The directory path where the files will be saved. """ os.makedirs(path, exist_ok=True) table = dataset.arrow_table data_filepath = os.path.join(path, "data.arrow") with pa.OSFile(data_filepath, "wb") as sink: writer = ipc.new_file(sink, table.schema) writer.write_table(table) writer.close() metadata_filepath = os.path.join(path, "splits.json") with open(metadata_filepath, "w") as f: json.dump(dataset.splits, f, indent=2, sort_keys=True, ensure_ascii=False) @beartype def load_dataset(dataset_path: Union[str, os.PathLike]) -> DashAIDataset: """ Loads a DashAIDataset previously saved with save_dataset. It expects the directory at 'path' to contain: - "data.arrow": the saved PyArrow table. - "splits.json": the saved split indices. Parameters: path (Union[str, os.PathLike]): The directory path where the dataset was saved. Returns: DashAIDataset: The loaded dataset with data and metadata. """ data_filepath = os.path.join(dataset_path, "data.arrow") with pa.OSFile(data_filepath, "rb") as source: reader = ipc.open_file(source) data = reader.read_all() metadata_filepath = os.path.join(dataset_path, "splits.json") if os.path.exists(metadata_filepath): with open(metadata_filepath, "r") as f: splits = json.load(f) else: splits = {} return DashAIDataset(data, splits=splits) @beartype def check_split_values( train_size: float, test_size: float, val_size: float, ) -> None: if train_size < 0 or train_size > 1: raise ValueError( "train_size should be in the (0, 1) range " f"(0 and 1 not included), got {val_size}" ) if test_size < 0 or test_size > 1: raise ValueError( "test_size should be in the (0, 1) range " f"(0 and 1 not included), got {val_size}" ) if val_size < 0 or val_size > 1: raise ValueError( "val_size should be in the (0, 1) range " f"(0 and 1 not included), got {val_size}" ) @beartype def split_indexes( total_rows: int, train_size: float, test_size: float, val_size: float, seed: Union[int, None] = None, shuffle: bool = True, stratify: bool = False, labels: Union[List, None] = None, ) -> Tuple[List, List, List]: """Generate lists with train, test and validation indexes. The algorithm for splitting the dataset is as follows: 1. The dataset is divided into a training and a test-validation split (sum of test_size and val_size). 2. The test and validation set is generated from the test-validation set, where the size of the test-validation set is now considered to be 100%. Therefore, the sizes of the test and validation sets will now be calculated as 100%, i.e. as val_size/(test_size+val_size) and test_size/(test_size+val_size) respectively. Example: If we split a dataset into 0.8 training, a 0.1 test, and a 0.1 validation, in the first process we split the training data with 80% of the data, and the test-validation data with the remaining 20%; and then in the second process we split this 20% into 50% test and 50% validation. Parameters ---------- total_rows : int Size of the Dataset. train_size : float Proportion of the dataset for train split (in 0-1). test_size : float Proportion of the dataset for test split (in 0-1). val_size : float Proportion of the dataset for validation split (in 0-1). seed : Union[int, None], optional Set seed to control to enable replicability, by default None shuffle : bool, optional If True, the data will be shuffled when splitting the dataset, by default True. stratify : bool, optional If True, the data will be stratified when splitting the dataset, by default False. Returns ------- Tuple[List, List, List] Train, Test and Validation indexes. """ # Generate shuffled indexes if seed is None: np.random.seed(seed) indexes = np.arange(total_rows) test_val = test_size + val_size val_proportion = test_size / test_val stratify_labels = np.array(labels) if stratify else None train_indexes, test_val_indexes = train_test_split( indexes, train_size=train_size, random_state=seed, shuffle=shuffle, stratify=stratify_labels, ) stratify_labels_test_val = stratify_labels[test_val_indexes] if stratify else None test_indexes, val_indexes = train_test_split( test_val_indexes, train_size=val_proportion, random_state=seed, shuffle=shuffle, stratify=stratify_labels_test_val, ) return list(train_indexes), list(test_indexes), list(val_indexes) @beartype def split_dataset( dataset: DashAIDataset, train_indexes: List = None, test_indexes: List = None, val_indexes: List = None, ) -> DatasetDict: """ Split the dataset in train, test and validation subsets. If indexes are not provided, it will use the split indices from the dataset's splits. Parameters ---------- dataset : DashAIDataset A HuggingFace DashAIDataset containing the dataset to be split. train_indexes : List, optional Train split indexes. If None, uses indices from splits. test_indexes : List, optional Test split indexes. If None, uses indices from splits. val_indexes : List, optional Validation split indexes. If None, uses indices from splits. Returns ------- DatasetDict The split dataset. Raises ------- ValueError Must provide all indexes or none. """ if all(idx is None for idx in [train_indexes, test_indexes, val_indexes]): train_dataset = dataset.get_split("train") test_dataset = dataset.get_split("test") val_dataset = dataset.get_split("validation") return DatasetDict( { "train": train_dataset, "test": test_dataset, "validation": val_dataset, } ) elif any(idx is None for idx in [train_indexes, test_indexes, val_indexes]): raise ValueError("Must provide all indexes or none.") # Get the number of records n = len(dataset) # Convert the indexes into boolean masks train_mask = np.isin(np.arange(n), train_indexes) test_mask = np.isin(np.arange(n), test_indexes) val_mask = np.isin(np.arange(n), val_indexes) # Get the underlying table table = dataset.arrow_table dataset.splits["split_indices"] = { "train": train_indexes, "test": test_indexes, "validation": val_indexes, } # Create separate tables for each split train_table = table.filter(pa.array(train_mask)) test_table = table.filter(pa.array(test_mask)) val_table = table.filter(pa.array(val_mask)) separate_dataset_dict = DatasetDict( { "train": DashAIDataset(train_table), "test": DashAIDataset(test_table), "validation": DashAIDataset(val_table), } ) return separate_dataset_dict def to_dashai_dataset( dataset: Union[DatasetDict, Dataset, DashAIDataset], ) -> DashAIDataset: """ Converts a DatasetDict into a unified DashAIDataset. If the DatasetDict has only one split, it simply wraps it in a DashAIDataset and records its indices. If there are multiple splits, it merges them using merge_splits_with_metadata. Parameters: dataset_dict (DatasetDict): The original dataset with one or more splits. Returns: DashAIDataset: A unified dataset containing all data and metadata about the original splits. """ if isinstance(dataset, DashAIDataset): return dataset if isinstance(dataset, Dataset): arrow_tbl = get_arrow_table(dataset) return DashAIDataset(arrow_tbl) elif len(dataset) == 1: key = list(dataset.keys())[0] ds = dataset[key] arrow_tbl = get_arrow_table(ds) return DashAIDataset(arrow_tbl) else: return merge_splits_with_metadata(dataset) @beartype def validate_inputs_outputs( datasetdict: Union[DatasetDict, DashAIDataset], inputs: List[str], outputs: List[str], ) -> None: """Validate the columns to be chosen as input and output. The algorithm considers those that already exist in the dataset. Parameters ---------- names : List[str] Dataset column names. inputs : List[str] List of input column names. outputs : List[str] List of output column names. """ datasetdict = to_dashai_dataset(datasetdict) dataset_features = list((datasetdict.features).keys()) if len(inputs) == 0 or len(outputs) == 0: raise ValueError( "Inputs and outputs columns lists to validate must not be empty" ) if len(inputs) + len(outputs) > len(dataset_features): raise ValueError( "Inputs and outputs cannot have more elements than names. " f"Number of inputs: {len(inputs)}, " f"number of outputs: {len(outputs)}, " f"number of names: {len(dataset_features)}. " ) # Validate that inputs and outputs only contain elements that exist in names if not set(dataset_features).issuperset(set(inputs + outputs)): raise ValueError( f"Inputs and outputs can only contain elements that exist in names. " f"Extra elements: " f"{', '.join(set(inputs + outputs).difference(set(dataset_features)))}" ) @beartype def get_column_names_from_indexes( dataset: Union[DashAIDataset, DatasetDict], indexes: List[int] ) -> List[str]: """Obtain the column labels that correspond to the provided indexes. Note: indexing starts from 1. Parameters ---------- datasetdict : DatasetDict Path where the dataset is stored. indices : List[int] List with the indices of the columns. Returns ------- List[str] List with the labels of the columns """ dataset = to_dashai_dataset(dataset) dataset_features = list((dataset.features).keys()) col_names = [] for index in indexes: if index > len(dataset_features): raise ValueError( f"The list of indices can only contain elements within" f" the amount of columns. " f"Index {index} is greater than the total of columns." ) col_names.append(dataset_features[index - 1]) return col_names @beartype def select_columns( dataset: Union[DatasetDict, DashAIDataset], input_columns: List[str], output_columns: List[str], ) -> Tuple[DashAIDataset, DashAIDataset]: """Divide the dataset into a dataset with only the input columns in it and other dataset only with the output columns Parameters ---------- dataset : Union[DatasetDict, DashAIDataset] Dataset to divide input_columns : List[str] List with the input columns labels output_columns : List[str] List with the output columns labels Returns ------- DashAIDataset Tuple with the separated datasets x and y """ dataset = to_dashai_dataset(dataset) input_columns_dataset = to_dashai_dataset(dataset.select_columns(input_columns)) output_columns_dataset = to_dashai_dataset(dataset.select_columns(output_columns)) return (input_columns_dataset, output_columns_dataset) @beartype def get_columns_spec(dataset_path: str) -> Dict[str, Dict]: """Return the column with their respective types Parameters ---------- dataset_path : str Path where the dataset is stored. Returns ------- Dict Dict with the columns and types """ dataset = load_dataset(dataset_path) dataset_features = dataset.features column_types = {} for column in dataset_features: if dataset_features[column]._type == "Value": column_types[column] = { "type": "Value", "dtype": dataset_features[column].dtype, } elif dataset_features[column]._type == "ClassLabel": column_types[column] = { "type": "Classlabel", "dtype": "", } return column_types @beartype def update_columns_spec(dataset_path: str, columns: Dict) -> DashAIDataset: """Update the column specification of some dataset on secondary memory. Parameters ---------- dataset_path : str Path where the dataset is stored. columns : Dict Dict with columns and types to change Returns ------- Dict Dict with the columns and types """ if not isinstance(columns, dict): raise TypeError(f"types should be a dict, got {type(columns)}") # load the dataset from where its stored dataset = load_dataset(dataset_path) # copy the features with the columns ans types new_features = dataset.features for column in columns: if columns[column].type == "ClassLabel": names = list(set(dataset[column])) new_features[column] = ClassLabel(names=names) elif columns[column].type == "Value": new_features[column] = Value(columns[column].dtype) # cast the column types with the changes try: dataset = dataset.cast(new_features) except ValueError as e: raise ValueError("Error while trying to cast the columns") from e return dataset def get_dataset_info(dataset_path: str) -> object: """Return the info of the dataset with the number of rows, number of columns and splits size. Parameters ---------- dataset_path : str Path where the dataset is stored. Returns ------- object Dictionary with the information of the dataset """ dataset = load_dataset(dataset_path=dataset_path) total_rows = dataset.num_rows total_columns = len(dataset.features) splits = dataset.splits.get("split_indices", {}) train_indices = splits.get("train", []) test_indices = splits.get("test", []) val_indices = splits.get("validation", []) train_size = len(train_indices) test_size = len(test_indices) val_size = len(val_indices) dataset_info = { "total_rows": total_rows, "total_columns": total_columns, "train_size": train_size, "test_size": test_size, "val_size": val_size, "train_indices": train_indices, "test_indices": test_indices, "val_indices": val_indices, } return dataset_info @beartype def update_dataset_splits( dataset: DashAIDataset, new_splits: object, is_random: bool ) -> DashAIDataset: """Update the metadata splits of a DashAIDataset. The splits could be random by giving numbers between 0 and 1 in new_splits parameters and setting the is_random parameter to True, or the could be manually selected by giving lists of indices to new_splits parameter and setting the is_random parameter to False. Args: dataset (DashAIDataset: Dataset to update splits new_splits (object): Object with the new train, test and validation config is_random (bool): If the new splits are random by percentage Returns: DashAIDataset: New DashAIDataset with the new splits configuration. """ n = dataset.num_rows if is_random: check_split_values( new_splits["train"], new_splits["test"], new_splits["validation"] ) train_indexes, test_indexes, val_indexes = split_indexes( n, new_splits["train"], new_splits["test"], new_splits["validation"] ) else: train_indexes = new_splits["train"] test_indexes = new_splits["test"] val_indexes = new_splits["validation"] dataset.splits["split_indices"] = { "train": train_indexes, "test": test_indexes, "validation": val_indexes, } return dataset def prepare_for_experiment( dataset: DashAIDataset, splits: dict, output_columns: List[str] ) -> DatasetDict: """Prepare the dataset for an experiment by updating the splits configuration""" splitType = splits.get("splitType") if splitType == "manual" or splitType == "predefined": splits_index = splits prepared_dataset = split_dataset( dataset, train_indexes=splits_index["train"], test_indexes=splits_index["test"], val_indexes=splits_index["validation"], ) else: n = len(dataset) if splits.get("stratify", False): output_column = output_columns labels = dataset[output_column] else: labels = None train_indexes, test_indexes, val_indexes = split_indexes( n, splits["train"], splits["test"], splits["validation"], shuffle=splits.get("shuffle", False), seed=splits.get("seed"), stratify=splits.get("stratify", False), labels=labels, ) prepared_dataset = split_dataset( dataset, train_indexes=train_indexes, test_indexes=test_indexes, val_indexes=val_indexes, ) return prepared_dataset