Source code for DashAI.back.job.converter_job

import logging
import re
from importlib import import_module
from pathlib import Path
from typing import Dict, List

import pyarrow as pa
from datasets.arrow_dataset import update_metadata_with_features
from datasets.features import Features
from kink import inject
from sqlalchemy import exc

from DashAI.back.api.api_v1.endpoints.converters import ConverterParams
from DashAI.back.converters.scikit_learn.converter_chain import ConverterChain
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    load_dataset,
    save_dataset,
)
from DashAI.back.dependencies.database.models import ConverterList
from DashAI.back.dependencies.database.models import Dataset as DatasetModel
from DashAI.back.dependencies.registry import ComponentRegistry
from DashAI.back.job.base_job import BaseJob, JobError

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)


def _rebuild_dataset_with_transformed_columns(
    base: DashAIDataset,
    transformed: DashAIDataset,
    scope_column_names: List[str],
    scope_column_indexes: List[int],
) -> DashAIDataset:
    """
    Replace specific columns in the base dataset with the corresponding columns
    from the transformed dataset, preserving their original positions. Any
    additional columns produced by the converter are appended at the end, and the
    features and metadata are kept consistent.

    Parameters
    ----------
    base : DashAIDataset
        The original dataset before transformation.

    transformed : DashAIDataset
        The dataset resulting from applying a converter, containing updated and/or
        new columns.

    scope_column_names : List[str]
        Names of the columns that were originally selected for transformation.

    scope_column_indexes : List[int]
        The indices of the columns in the base dataset that were replaced.
        Must match the order of scope_column_names.

    Returns
    -------
    DashAIDataset
        A new dataset with the specified columns replaced in-place, new columns
        appended, and original metadata and split information preserved.
    """

    original_columns = base.column_names
    original_without_scope = base.remove_columns(scope_column_names)

    transformed_cols = transformed.column_names
    replacement_cols = transformed_cols[: len(scope_column_indexes)]
    new_cols = transformed_cols[len(scope_column_indexes) :]

    index_to_replacement = dict(zip(scope_column_indexes, replacement_cols))
    new_columns_order = []
    for i, col in enumerate(original_columns):
        if i in index_to_replacement:
            new_columns_order.append(index_to_replacement[i])
        else:
            new_columns_order.append(col)
    new_columns_order.extend(new_cols)

    original_table = original_without_scope.arrow_table
    transformed_table = transformed.arrow_table

    new_arrays = []
    for col in new_columns_order:
        if col in original_table.column_names:
            new_arrays.append(original_table[col])
        elif col in transformed_table.column_names:
            new_arrays.append(transformed_table[col])
        else:
            raise ValueError(f"Column '{col}' not found in any dataset")

    new_table = pa.Table.from_arrays(new_arrays, names=new_columns_order)
    new_dataset = DashAIDataset(new_table, splits=base.splits)

    features = base.features.copy()
    features.update(
        {
            col: transformed.features[col]
            for col in transformed.column_names
            if col in new_columns_order and col in transformed.features
        }
    )
    new_dataset._info.features = Features(
        {col: features[col] for col in new_columns_order if col in features}
    )

    new_dataset._data = update_metadata_with_features(
        new_dataset._data, new_dataset.features
    )

    return new_dataset
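
# Illustrative sketch (not part of the job logic): assuming a DashAIDataset can be
# built directly from a pyarrow Table, as done above, replacing the single scoped
# column "b" of a two-column dataset would look roughly like this:
#
#     base = DashAIDataset(pa.table({"a": [1, 2], "b": [3.0, 4.0]}), splits=...)
#     transformed = DashAIDataset(
#         pa.table({"b_scaled": [0.0, 1.0]}), splits=base.splits
#     )
#     rebuilt = _rebuild_dataset_with_transformed_columns(
#         base, transformed, scope_column_names=["b"], scope_column_indexes=[1]
#     )
#     # rebuilt.column_names -> ["a", "b_scaled"]
#
# The column names here are hypothetical, and the splits argument follows whatever
# format DashAIDataset expects.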


class ConverterListJob(BaseJob):
    """ConverterListJob class to modify a dataset by applying a sequence of converters."""

    def set_status_as_delivered(self) -> None:
        """Set the status of the list as delivered."""
        converter_list_id = self.kwargs["converter_list_id"]
        db = self.kwargs["db"]

        converter_list = db.get(ConverterList, converter_list_id)
        if converter_list is None:
            raise JobError(
                f"Converter list with id {converter_list_id} does not exist in DB."
            )
        try:
            converter_list.set_status_as_delivered()
            db.commit()
        except exc.SQLAlchemyError as e:
            log.exception(e)
            raise JobError("Error setting converter list status as delivered") from e

    @inject
    def run(
        self,
        component_registry: ComponentRegistry = lambda di: di["component_registry"],
    ) -> None:
        def instantiate_converters(
            converter_name: str,
            converter_params: ConverterParams,
            camel_to_snake: re.Pattern,
            converter_submodule_inverse_index: Dict,
        ) -> object:
            # Get converter constructor and parameters
            converter_filename = camel_to_snake.sub("_", converter_name).lower()
            submodule = converter_submodule_inverse_index[converter_filename]
            module_path = f"DashAI.back.converters.{submodule}.{converter_filename}"

            # Import the converter
            try:
                module = import_module(module_path)
                converter_constructor = getattr(module, converter_name)
            except ImportError as e:
                log.exception(e)
                raise JobError(
                    f"Error importing converter {converter_name}: {e}"
                ) from e

            # Get parameters or empty dict if none
            converter_parameters = converter_params.get("params", {})
            return converter_constructor(**converter_parameters)

        def instantiate_chain(
            steps: List,
            camel_to_snake: re.Pattern,
            converter_submodule_inverse_index: Dict,
        ) -> ConverterChain:
            converter_instances = []
            for converter_name, converter_params in steps:
                converter_instance = instantiate_converters(
                    converter_name,
                    converter_params,
                    camel_to_snake,
                    converter_submodule_inverse_index,
                )
                converter_instances.append(converter_instance)
            return ConverterChain(steps=converter_instances)

        # Extract job parameters
        converter_list_id = self.kwargs["converter_list_id"]
        target_column_index = self.kwargs["target_column_index"]
        db = self.kwargs["db"]

        # Validate input parameters
        try:
            if converter_list_id is None or target_column_index is None:
                raise JobError(
                    "Converter list ID and target column index are required"
                )

            converter_list = db.get(ConverterList, converter_list_id)
            if not converter_list:
                raise JobError(f"Converter list with id {converter_list_id} not found")

            converter_list.set_status_as_started()
            db.commit()
        except exc.SQLAlchemyError as e:
            log.exception(e)
            raise JobError("Error loading converter list info") from e

        # Get dataset
        try:
            dataset_id = converter_list.dataset_id
            dataset = db.get(DatasetModel, dataset_id)
            if not dataset:
                raise JobError(f"Dataset with id {dataset_id} not found")
        except exc.SQLAlchemyError as e:
            log.exception(e)
            converter_list.set_status_as_error()
            db.commit()
            raise JobError("Error loading dataset info") from e

        # Load dataset
        try:
            dataset_path = f"{dataset.file_path}/dataset"
            loaded_dataset = load_dataset(dataset_path)

            # Validate target column index (1-based)
            if int(target_column_index) < 1 or int(target_column_index) > len(
                loaded_dataset.features
            ):
                raise JobError(
                    f"Target column index {target_column_index} is out of bounds"
                )
        except Exception as e:
            log.exception(e)
            converter_list.set_status_as_error()
            db.commit()
            raise JobError(f"Cannot load dataset from {dataset_path}") from e

        try:
            # Regex to convert camel case to snake case
            camel_to_snake = re.compile(
                r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])"
            )

            # Get the absolute path to the converters directory
            current_file = Path(__file__)
            # Go up three levels (job -> back -> DashAI) to reach the package root
            project_root = current_file.parent.parent.parent
            converters_base_path = project_root / "back" / "converters"

            if not converters_base_path.exists():
                raise JobError(
                    f"Converters directory not found at {converters_base_path}"
                )

            # Build converter filename to submodule mapping
            converter_submodule_inverse_index = {
                file.stem: submodule.name
                for submodule in converters_base_path.iterdir()
                if submodule.is_dir()
                for file in submodule.glob("*.py")
                # Skip __init__.py and other special files
                if not file.name.startswith("_")
            }

            # Get stored converter configurations
            converters_stored_info = converter_list.converters
            dataset_original_columns = loaded_dataset.column_names

            # Sort converters by order
            converters_sorted_list = sorted(
                converters_stored_info.items(), key=lambda x: x[1]["order"]
            )

            # Process converters
            i = 0
            converter_instances = []
            while i < len(converters_sorted_list):
                converter_name = converters_sorted_list[i][0]
                converter_params = converters_sorted_list[i][1]

                # Check if it's a chain of converters
                if converter_name == "ConverterChain":
                    try:
                        n_steps = int(converter_params["params"]["steps"])

                        # Get the steps that belong to the chain
                        chain_steps = converters_sorted_list[i + 1 : i + n_steps + 1]

                        # Instantiate chain of converters
                        chain_instance = instantiate_chain(
                            chain_steps,
                            camel_to_snake,
                            converter_submodule_inverse_index,
                        )

                        # Get scope or use default
                        scope = converter_params.get(
                            "scope", {"columns": [], "rows": []}
                        )

                        # Add converter chain to instances
                        converter_instances.append(
                            {
                                "name": "ConverterChain",
                                "instance": chain_instance,
                                "scope": scope,
                            }
                        )
                        i += n_steps + 1
                    except Exception as e:
                        log.exception(e)
                        raise JobError(
                            f"Error instantiating converter chain: {e}"
                        ) from e
                else:
                    # Regular converter
                    converter_instance = instantiate_converters(
                        converter_name,
                        converter_params,
                        camel_to_snake,
                        converter_submodule_inverse_index,
                    )

                    # Get scope or use default
                    scope = converter_params.get("scope", {"columns": [], "rows": []})

                    # Add to instances
                    converter_instances.append(
                        {
                            "name": converter_name,
                            "instance": converter_instance,
                            "scope": scope,
                        }
                    )
                    i += 1

            # Apply each converter in sequence
            for converter_info in converter_instances:
                converter = converter_info["instance"]
                converter_scope = converter_info["scope"]

                # Process columns scope (1-based indexes from the user)
                columns_scope = [column - 1 for column in converter_scope["columns"]]
                scope_column_indexes = sorted(set(columns_scope))

                # If no columns specified, use all columns
                if not scope_column_indexes:
                    scope_column_indexes = list(range(len(loaded_dataset.features)))
                scope_column_names = [
                    dataset_original_columns[index] for index in scope_column_indexes
                ]

                # Process rows scope
                rows_scope = [row - 1 for row in converter_scope["rows"]]
                scope_rows_indexes = sorted(set(rows_scope))

                # Adjust target column index (0-based internally)
                target_column_index_0based = int(target_column_index) - 1
                target_column_name = dataset_original_columns[
                    target_column_index_0based
                ]

                # Select data for fitting using DashAIDataset operations
                X_dataset = loaded_dataset.select_columns(scope_column_names)
                y_dataset = loaded_dataset.select_columns([target_column_name])

                # Select specified rows if provided
                if scope_rows_indexes:
                    X_dataset = X_dataset.select(scope_rows_indexes)
                    y_dataset = y_dataset.select(scope_rows_indexes)

                try:
                    converter = converter.fit(X_dataset, y_dataset)
                except Exception as e:
                    log.exception(e)
                    raise JobError(
                        f"Error fitting converter {converter_info['name']}: {e}"
                    ) from e

                # Transform data using the full dataset for the selected columns
                X_full = loaded_dataset.select_columns(scope_column_names)
                y_full = loaded_dataset.select_columns([target_column_name])
                try:
                    transformed_dataset = converter.transform(X_full, y_full)
                except Exception as e:
                    log.exception(e)
                    raise JobError(f"Error transforming data: {e}") from e

                # Merge the transformed data back into the original dataset,
                # preserving the original column positions
                loaded_dataset = _rebuild_dataset_with_transformed_columns(
                    loaded_dataset,
                    transformed_dataset,
                    scope_column_names,
                    scope_column_indexes,
                )

            # Save the final dataset
            save_dataset(loaded_dataset, dataset_path)

            converter_list.set_status_as_finished()
            db.commit()
            db.refresh(dataset)
        except Exception as e:
            log.exception(e)
            converter_list.set_status_as_error()
            db.commit()
            raise JobError(
                f"Error applying converters to dataset {dataset_id}: {e}"
            ) from e
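
# Note on converter resolution (illustrative): the camel_to_snake regex used in
# run() turns a converter class name into its expected module filename, which is
# then looked up in converter_submodule_inverse_index to build the import path:
#
#     pattern = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")
#     pattern.sub("_", "StandardScaler").lower()   # -> "standard_scaler"
#     pattern.sub("_", "PCATransformer").lower()   # -> "pca_transformer"
#
# so a hypothetical converter class StandardScaler defined in
# DashAI/back/converters/<submodule>/standard_scaler.py would be imported from
# "DashAI.back.converters.<submodule>.standard_scaler". The class names above are
# examples only, not necessarily converters shipped with DashAI.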