import logging
import re
from importlib import import_module
from pathlib import Path
from typing import Dict, List
import pyarrow as pa
from datasets.arrow_dataset import update_metadata_with_features
from datasets.features import Features
from kink import inject
from sqlalchemy import exc
from DashAI.back.api.api_v1.endpoints.converters import ConverterParams
from DashAI.back.converters.scikit_learn.converter_chain import ConverterChain
from DashAI.back.dataloaders.classes.dashai_dataset import (
DashAIDataset,
load_dataset,
save_dataset,
)
from DashAI.back.dependencies.database.models import ConverterList
from DashAI.back.dependencies.database.models import Dataset as DatasetModel
from DashAI.back.dependencies.registry import ComponentRegistry
from DashAI.back.job.base_job import BaseJob, JobError
logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
def _rebuild_dataset_with_transformed_columns(
base: DashAIDataset,
transformed: DashAIDataset,
scope_column_names: List[str],
scope_column_indexes: List[int],
) -> DashAIDataset:
"""
Replaces specific columns in the base dataset with columns from the transformed
dataset, preserving their original positions. Also appends any additional columns
that were generated by the transformer at the end. Keeps the features and metadata
consistent.
Parameters
----------
base : DashAIDataset
The original dataset before transformation.
transformed : DashAIDataset
The dataset resulting from applying a transformer, containing updated and/or
new columns.
scope_column_names : List[str]
Names of the columns that were originally selected for transformation.
scope_column_indexes : List[int]
The indices of the columns in the base dataset that were replaced.
Must match the order of scope_column_names.
Returns
-------
DashAIDataset
        A new dataset with the specified columns replaced in their original
        positions, new columns appended, and split information and metadata
        preserved.
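
    Examples
    --------
    Illustrative sketch with hypothetical column names: if ``base`` has columns
    ``['age', 'height', 'label']`` and a converter replaced ``'height'`` (index 1),
    producing ``['height_scaled', 'height_sq']``, the rebuilt dataset has columns
    ``['age', 'height_scaled', 'label', 'height_sq']``.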
"""
original_columns = base.column_names
original_without_scope = base.remove_columns(scope_column_names)
transformed_cols = transformed.column_names
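    # Assumes the transformed dataset lists the replaced columns first (in the
    # same order as scope_column_indexes), followed by any newly generated columns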
replacement_cols = transformed_cols[: len(scope_column_indexes)]
new_cols = transformed_cols[len(scope_column_indexes) :]
index_to_replacement = dict(zip(scope_column_indexes, replacement_cols))
new_columns_order = []
for i, col in enumerate(original_columns):
if i in index_to_replacement:
new_columns_order.append(index_to_replacement[i])
else:
new_columns_order.append(col)
new_columns_order.extend(new_cols)
original_table = original_without_scope.arrow_table
transformed_table = transformed.arrow_table
new_arrays = []
for col in new_columns_order:
if col in original_table.column_names:
new_arrays.append(original_table[col])
elif col in transformed_table.column_names:
new_arrays.append(transformed_table[col])
else:
raise ValueError(f"Column '{col}' not found in any dataset")
new_table = pa.Table.from_arrays(new_arrays, names=new_columns_order)
new_dataset = DashAIDataset(new_table, splits=base.splits)
features = base.features.copy()
features.update(
{
col: transformed.features[col]
for col in transformed.column_names
if col in new_columns_order and col in transformed.features
}
)
new_dataset._info.features = Features(
{col: features[col] for col in new_columns_order if col in features}
)
new_dataset._data = update_metadata_with_features(
new_dataset._data, new_dataset.features
)
return new_dataset
class ConverterListJob(BaseJob):
    """Job that modifies a dataset by applying a sequence of converters."""
def set_status_as_delivered(self) -> None:
"""Set the status of the list as delivered."""
converter_list_id = self.kwargs["converter_list_id"]
db = self.kwargs["db"]
converter_list = db.get(ConverterList, converter_list_id)
if converter_list is None:
raise JobError(
f"Converter list with id {converter_list_id} does not exist in DB."
)
try:
converter_list.set_status_as_delivered()
db.commit()
except exc.SQLAlchemyError as e:
log.exception(e)
raise JobError("Error setting converter list status as delivered") from e
@inject
def run(
self,
component_registry: ComponentRegistry = lambda di: di["component_registry"],
) -> None:
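        """Apply the stored converter list to its dataset.

        Instantiates every converter (and converter chain) stored in the
        ConverterList record, fits and applies each one over its column/row
        scope, merges the transformed columns back into the dataset, and
        saves the result.
        """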
def instantiate_converters(
converter_name: str,
converter_params: ConverterParams,
camel_to_snake: re.Pattern,
converter_submodule_inverse_index: Dict,
) -> object:
# Get converter constructor and parameters
converter_filename = camel_to_snake.sub("_", converter_name).lower()
submodule = converter_submodule_inverse_index[converter_filename]
module_path = f"DashAI.back.converters.{submodule}.{converter_filename}"
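            # e.g. module_path == "DashAI.back.converters.scikit_learn.converter_chain"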
# Import the converter
try:
module = import_module(module_path)
converter_constructor = getattr(module, converter_name)
            except (ImportError, AttributeError) as e:
log.exception(e)
raise JobError(
f"Error importing converter {converter_name}: {e}"
) from e
# Get parameters or empty dict if none
converter_parameters = converter_params.get("params", {})
return converter_constructor(**converter_parameters)
def instantiate_chain(
steps: List,
camel_to_snake: re.Pattern,
converter_submodule_inverse_index: Dict,
) -> ConverterChain:
converter_instances = []
for converter_name, converter_params in steps:
converter_instance = instantiate_converters(
converter_name,
converter_params,
camel_to_snake,
converter_submodule_inverse_index,
)
converter_instances.append(converter_instance)
return ConverterChain(steps=converter_instances)
# Extract job parameters
converter_list_id = self.kwargs["converter_list_id"]
target_column_index = self.kwargs["target_column_index"]
db = self.kwargs["db"]
# Validate input parameters
try:
if converter_list_id is None or target_column_index is None:
raise JobError("Converter list ID and target column index are required")
converter_list = db.get(ConverterList, converter_list_id)
if not converter_list:
raise JobError(f"Converter list with id {converter_list_id} not found")
converter_list.set_status_as_started()
db.commit()
except exc.SQLAlchemyError as e:
log.exception(e)
raise JobError("Error loading converter list info") from e
# Get dataset
try:
dataset_id = converter_list.dataset_id
dataset = db.get(DatasetModel, dataset_id)
if not dataset:
raise JobError(f"Dataset with id {dataset_id} not found")
except exc.SQLAlchemyError as e:
log.exception(e)
converter_list.set_status_as_error()
db.commit()
raise JobError("Error loading dataset info") from e
# Load dataset
try:
dataset_path = f"{dataset.file_path}/dataset"
loaded_dataset = load_dataset(dataset_path)
# Validate target column index
if int(target_column_index) < 1 or int(target_column_index) > len(
loaded_dataset.features
):
raise JobError(
f"Target column index {target_column_index} is out of bounds"
)
except Exception as e:
log.exception(e)
converter_list.set_status_as_error()
db.commit()
raise JobError(f"Cannot load dataset from {dataset_path}") from e
try:
# Regex to convert camel case to snake case
camel_to_snake = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")
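            # e.g. camel_to_snake.sub("_", "StandardScaler").lower() -> "standard_scaler"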
# Get the absolute path to the converters directory
current_file = Path(__file__)
            project_root = (
                current_file.parent.parent.parent
            )  # Go up three levels (job -> back -> DashAI) to reach the package root
converters_base_path = project_root / "back" / "converters"
if not converters_base_path.exists():
raise JobError(
f"Converters directory not found at {converters_base_path}"
)
            # Build an inverse index mapping converter module file names to their submodule
converter_submodule_inverse_index = {
file.stem: submodule.name
for submodule in converters_base_path.iterdir()
if submodule.is_dir()
for file in submodule.glob("*.py")
if not file.name.startswith(
"_"
) # Skip __init__.py and other special files
}
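            # Example entry: {"converter_chain": "scikit_learn"} for
            # DashAI/back/converters/scikit_learn/converter_chain.py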
# Get stored converter configurations
converters_stored_info = converter_list.converters
dataset_original_columns = loaded_dataset.column_names
# Sort converters by order
converters_sorted_list = sorted(
converters_stored_info.items(), key=lambda x: x[1]["order"]
)
# Process converters
i = 0
converter_instances = []
while i < len(converters_sorted_list):
                converter_name, converter_params = converters_sorted_list[i]
# Check if it's a chain of converters
if converter_name == "ConverterChain":
try:
n_steps = int(converter_params["params"]["steps"])
                        # The next n_steps entries in the sorted list are this chain's steps
chain_steps = converters_sorted_list[i + 1 : i + n_steps + 1]
# Instantiate chain of converters
chain_instance = instantiate_chain(
chain_steps,
camel_to_snake,
converter_submodule_inverse_index,
)
# Get scope or use default
scope = converter_params.get(
"scope", {"columns": [], "rows": []}
)
# Add converter chain to instances
converter_instances.append(
{
"name": "ConverterChain",
"instance": chain_instance,
"scope": scope,
}
)
i += n_steps + 1
except Exception as e:
log.exception(e)
raise JobError(
f"Error instantiating converter chain: {e}"
) from e
else:
# Regular converter
converter_instance = instantiate_converters(
converter_name,
converter_params,
camel_to_snake,
converter_submodule_inverse_index,
)
# Get scope or use default
scope = converter_params.get("scope", {"columns": [], "rows": []})
# Add to instances
converter_instances.append(
{
"name": converter_name,
"instance": converter_instance,
"scope": scope,
}
)
i += 1
# Apply each converter in sequence
for converter_info in converter_instances:
converter = converter_info["instance"]
converter_scope = converter_info["scope"]
                # Convert the 1-based column scope to 0-based indexes
columns_scope = [column - 1 for column in converter_scope["columns"]]
scope_column_indexes = sorted(set(columns_scope))
# If no columns specified, use all columns
if not scope_column_indexes:
scope_column_indexes = list(range(len(loaded_dataset.features)))
scope_column_names = [
dataset_original_columns[index] for index in scope_column_indexes
]
                # Convert the 1-based row scope to 0-based indexes
rows_scope = [row - 1 for row in converter_scope["rows"]]
scope_rows_indexes = sorted(set(rows_scope))
# Adjust target column index (0-based internally)
target_column_index_0based = int(target_column_index) - 1
target_column_name = dataset_original_columns[
target_column_index_0based
]
# Select data for fitting using DashAIDataset operations
X_dataset = loaded_dataset.select_columns(scope_column_names)
y_dataset = loaded_dataset.select_columns([target_column_name])
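                # The target column is selected separately and passed to fit/transform as y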
# Select specified rows if provided
if scope_rows_indexes:
X_dataset = X_dataset.select(scope_rows_indexes)
y_dataset = y_dataset.select(scope_rows_indexes)
try:
converter = converter.fit(X_dataset, y_dataset)
except Exception as e:
log.exception(e)
                    raise JobError(
                        f"Error fitting converter {converter_info['name']}: {e}"
                    ) from e
# Transform data using full dataset for selected columns
X_full = loaded_dataset.select_columns(scope_column_names)
y_full = loaded_dataset.select_columns([target_column_name])
try:
transformed_dataset = converter.transform(X_full, y_full)
except Exception as e:
log.exception(e)
raise JobError(f"Error transforming data: {e}") from e
# Now we need to merge the transformed data back into the original
# dataset, preserving their original positions
loaded_dataset = _rebuild_dataset_with_transformed_columns(
loaded_dataset,
transformed_dataset,
scope_column_names,
scope_column_indexes,
)
# Save the final dataset
            save_dataset(loaded_dataset, dataset_path)
converter_list.set_status_as_finished()
db.commit()
db.refresh(dataset)
except Exception as e:
log.exception(e)
converter_list.set_status_as_error()
db.commit()
raise JobError(
f"Error applying converters to dataset {dataset_id}: {e}"
) from e