Source code for DashAI.back.job.model_job

import gc
import json
import logging
import os
import pickle
from typing import List

from kink import inject
from sqlalchemy import exc
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.attributes import flag_modified

from DashAI.back.core.enums.metrics import LevelEnum, SplitEnum
from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    load_dataset,
    prepare_for_experiment,
    select_columns,
    split_dataset,
)
from DashAI.back.dependencies.database.models import Dataset, Experiment, Metric, Run
from DashAI.back.job.base_job import BaseJob, JobError
from DashAI.back.metrics import BaseMetric
from DashAI.back.models import BaseModel
from DashAI.back.models.model_factory import ModelFactory
from DashAI.back.optimizers import BaseOptimizer
from DashAI.back.tasks import BaseTask

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)


[docs] class ModelJob(BaseJob): """ModelJob class to run the model training.""" @inject def set_status_as_delivered( self, session_factory: sessionmaker = lambda di: di["session_factory"] ) -> None: """Set the status of the job as delivered.""" run_id: int = self.kwargs["run_id"] with session_factory() as db: run: Run = db.get(Run, run_id) if not run: raise JobError(f"Run {run_id} does not exist in DB.") try: run.set_status_as_delivered() db.commit() except exc.SQLAlchemyError as e: log.exception(e) raise JobError( "Internal database error", ) from e @inject def set_status_as_error( self, session_factory: sessionmaker = lambda di: di["session_factory"] ) -> None: """Set the status of the job as error.""" run_id: int = self.kwargs.get("run_id") if run_id is None: return with session_factory() as db: run: Run = db.get(Run, run_id) if not run: return try: run.set_status_as_error() db.commit() except exc.SQLAlchemyError as e: log.exception(e) @inject def get_job_name(self) -> str: """Get a descriptive name for the job.""" run_id = self.kwargs.get("run_id") if not run_id: return "Model Training" from kink import di session_factory = di["session_factory"] try: with session_factory() as db: run: Run = db.get(Run, run_id) if run and run.name: return f"Train: {run.name}" except Exception: pass return f"Model Training ({run_id})" @inject def run( self, ) -> None: from kink import di component_registry = di["component_registry"] session_factory = di["session_factory"] config = di["config"] # Get the necessary parameters run_id: int = self.kwargs["run_id"] with session_factory() as db: run: Run = db.get(Run, run_id) run.huey_id = self.kwargs.get("huey_id", None) db.commit() try: # Get the experiment, dataset, task, metrics and splits experiment: Experiment = db.get(Experiment, run.experiment_id) if not experiment: raise JobError( f"Experiment {run.experiment_id} does not exist in DB." ) dataset: Dataset = db.get(Dataset, experiment.dataset_id) if not dataset: raise JobError( f"Dataset {experiment.dataset_id} does not exist in DB." ) try: loaded_dataset: DashAIDataset = load_dataset( f"{dataset.file_path}/dataset" ) except Exception as e: log.exception(e) raise JobError( f"Can not load dataset from path {dataset.file_path}", ) from e try: task: BaseTask = component_registry[experiment.task_name]["class"]() except Exception as e: log.exception(e) raise JobError( ( f"Unable to find Task with name {experiment.task_name} " "in registry" ), ) from e try: # Get metrics from experiment train_metrics: List[BaseMetric] = [ component_registry[m]["class"] for m in experiment.train_metrics ] validation_metrics: List[BaseMetric] = [ component_registry[m]["class"] for m in experiment.validation_metrics ] test_metrics: List[BaseMetric] = [ component_registry[m]["class"] for m in experiment.test_metrics ] except Exception as e: log.exception(e) raise JobError( "Unable to find metrics associated with" f"Task {experiment.task_name} in registry", ) from e try: prepared_dataset = task.prepare_for_task( dataset=loaded_dataset, input_columns=experiment.input_columns, output_columns=experiment.output_columns, ) n_labels = task.num_labels( prepared_dataset, experiment.output_columns[0] ) splits = json.loads(experiment.splits) prepared_dataset, splits = prepare_for_experiment( dataset=prepared_dataset, splits=splits, output_columns=experiment.output_columns, ) run.split_indexes = json.dumps( { "train_indexes": splits["train_indexes"], "test_indexes": splits["test_indexes"], "val_indexes": splits["val_indexes"], } ) x, y = select_columns( prepared_dataset, experiment.input_columns, experiment.output_columns, ) x = split_dataset(x) y = split_dataset(y) except Exception as e: log.exception(e) raise JobError( f"""Can not prepare Dataset {dataset.id} for Task {experiment.task_name}""", ) from e try: run_model_class = component_registry[run.model_name]["class"] except Exception as e: log.exception(e) raise JobError( f"Unable to find Model with name {run.model_name} in registry.", ) from e try: factory = ModelFactory( run_model_class, run.parameters, run_id, x, y, train_metrics, validation_metrics, test_metrics, n_labels=n_labels, ) model: BaseModel = factory.model run_optimizable_parameters = factory.optimizable_parameters except Exception as e: log.exception(e) raise JobError( f"Unable to instantiate model using run {run_id}", ) from e try: if run_optimizable_parameters: goal_metric = component_registry[run.goal_metric] except Exception as e: log.exception(e) raise JobError( f"Metric is not compatible with the Task. {e}", ) from e try: # Optimizer configuration if run_optimizable_parameters: run_optimizer_class = component_registry[run.optimizer_name][ "class" ] optimizer: BaseOptimizer = run_optimizer_class( **run.optimizer_parameters ) except Exception as e: log.exception(e) raise JobError( f"Error instantiating optimizer {run.optimizer_name}, {e}", ) from e try: run.set_status_as_started() db.commit() except exc.SQLAlchemyError as e: log.exception(e) raise JobError( "Connection with the database failed", ) from e try: # Hyperparameter Tunning plot_paths = [] if not run_optimizable_parameters: model.train( x["train"], y["train"], x["validation"], y["validation"] ) else: optimizer.optimize( model, x, y, run_optimizable_parameters, goal_metric, task, ) model = optimizer.get_model() best_params = optimizer.get_best_params() updated_params = run.parameters.copy() for param_name, param_value in best_params.items(): updated_params[param_name]["fixed_value"] = param_value run.parameters = updated_params flag_modified(run, "parameters") db.commit() # Generate hyperparameter plot trials = optimizer.get_trials_values() plot_filenames, plots = optimizer.create_plots( trials, run_id, n_params=len(run_optimizable_parameters), goal_metric=goal_metric, ) for filename, plot in zip(plot_filenames, plots): plot_path = os.path.join(config["RUNS_PATH"], filename) with open(plot_path, "wb") as file: pickle.dump(plot, file) plot_paths.append(plot_path) except Exception as e: log.exception(e) raise JobError( f"Model training failed {e}", ) from e try: paths = plot_paths + [None] * (4 - len(plot_paths)) ( run.plot_history_path, run.plot_slice_path, run.plot_contour_path, run.plot_importance_path, ) = paths[:4] db.commit() except Exception as e: log.exception(e) raise JobError( f"Hyperparameter plot path saving failed {e}", ) from e # Calculate metrics at the end of training if not done already try: last_train_metric = ( db.query(Metric) .filter_by(run_id=run.id, split="TRAIN", level="LAST") .first() ) if not last_train_metric: model.calculate_metrics( split=SplitEnum.TRAIN, level=LevelEnum.LAST, ) last_val_metric = ( db.query(Metric) .filter_by(run_id=run.id, split="VALIDATION", level="LAST") .first() ) if not last_val_metric: model.calculate_metrics( split=SplitEnum.VALIDATION, level=LevelEnum.LAST, ) last_test_metric = ( db.query(Metric) .filter_by(run_id=run.id, split="TEST", level="LAST") .first() ) if not last_test_metric: model.calculate_metrics( split=SplitEnum.TEST, level=LevelEnum.LAST, ) except Exception as e: log.exception(e) raise JobError( f"Metric calculation failed {e}", ) from e try: run_path = os.path.join(config["RUNS_PATH"], str(run.id)) model.save(run_path) except Exception as e: log.exception(e) raise JobError( "Model saving failed", ) from e try: run.run_path = run_path db.commit() except exc.SQLAlchemyError as e: log.exception(e) run.set_status_as_error() db.commit() raise JobError( "Connection with the database failed", ) from e try: run.set_status_as_finished() db.commit() except exc.SQLAlchemyError as e: log.exception(e) raise JobError( "Connection with the database failed", ) from e except Exception as e: run.set_status_as_error() db.commit() raise e finally: gc.collect()