import json
import logging
import os
import pickle
from typing import List
from kink import inject
from sqlalchemy import exc
from sqlalchemy.orm import Session
from DashAI.back.dataloaders.classes.dashai_dataset import (
DashAIDataset,
load_dataset,
prepare_for_experiment,
select_columns,
split_dataset,
)
from DashAI.back.dependencies.database.models import Dataset, Experiment, Run
from DashAI.back.dependencies.registry import ComponentRegistry
from DashAI.back.job.base_job import BaseJob, JobError
from DashAI.back.metrics import BaseMetric
from DashAI.back.models import BaseModel
from DashAI.back.models.model_factory import ModelFactory
from DashAI.back.optimizers import BaseOptimizer
from DashAI.back.tasks import BaseTask
# NOTE(review): calling basicConfig at import time configures the *root*
# logger for the entire process, and at DEBUG level — confirm this is
# intended here rather than in the application entry point.
logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
class ModelJob(BaseJob):
    """Job that trains (and optionally hyper-optimizes) the model of a Run.

    ``self.kwargs`` must provide:
        run_id (int): id of the ``Run`` row this job operates on.
        db (Session): an open SQLAlchemy session.
    """

    def set_status_as_delivered(self) -> None:
        """Set the status of the job as delivered.

        Raises
        ------
        JobError
            If the run does not exist in the DB or the update fails.
        """
        run_id: int = self.kwargs["run_id"]
        db: Session = self.kwargs["db"]
        run: Run = db.get(Run, run_id)
        if not run:
            raise JobError(f"Run {run_id} does not exist in DB.")
        try:
            run.set_status_as_delivered()
            db.commit()
        except exc.SQLAlchemyError as e:
            log.exception(e)
            raise JobError(
                "Internal database error",
            ) from e

    @inject
    def run(
        self,
        component_registry: ComponentRegistry = lambda di: di["component_registry"],
        config=lambda di: di["config"],
    ) -> None:
        """Execute the full training pipeline for this job's run.

        Loads the experiment's dataset, prepares it for the task, trains the
        model (delegating to the configured optimizer when the model exposes
        optimizable parameters), evaluates the metrics and persists the
        trained model, the metric values and any hyperparameter plots.

        Parameters are resolved by kink's ``@inject`` from the DI container;
        the lambda defaults are container lookups, not plain callables.

        Raises
        ------
        JobError
            On failure of any pipeline step (lookup, preparation, training,
            evaluation or persistence).
        """
        from DashAI.back.api.api_v1.endpoints.components import (
            _intersect_component_lists,
        )

        # Get the necessary parameters
        run_id: int = self.kwargs["run_id"]
        db: Session = self.kwargs["db"]
        run: Run = db.get(Run, run_id)
        # Fix: guard against a missing run. Previously a None run raised
        # AttributeError inside the try below, and the generic handler then
        # crashed again on run.set_status_as_error().
        if not run:
            raise JobError(f"Run {run_id} does not exist in DB.")
        try:
            # Get the experiment, dataset, task, metrics and splits
            experiment: Experiment = db.get(Experiment, run.experiment_id)
            if not experiment:
                raise JobError(f"Experiment {run.experiment_id} does not exist in DB.")
            dataset: Dataset = db.get(Dataset, experiment.dataset_id)
            if not dataset:
                raise JobError(f"Dataset {experiment.dataset_id} does not exist in DB.")
            try:
                loaded_dataset: DashAIDataset = load_dataset(
                    f"{dataset.file_path}/dataset"
                )
            except Exception as e:
                log.exception(e)
                raise JobError(
                    f"Can not load dataset from path {dataset.file_path}",
                ) from e
            try:
                task: BaseTask = component_registry[experiment.task_name]["class"]()
            except Exception as e:
                log.exception(e)
                raise JobError(
                    f"Unable to find Task with name {experiment.task_name} in registry",
                ) from e
            try:
                # Get all the metrics
                all_metrics = {
                    component_dict["name"]: component_dict
                    for component_dict in component_registry.get_components_by_types(
                        select="Metric"
                    )
                }
                # Get the intersection between the metrics and the task
                # related components
                selected_metrics = _intersect_component_lists(
                    all_metrics,
                    component_registry.get_related_components(experiment.task_name),
                )
                metrics: List[BaseMetric] = [
                    metric["class"] for metric in selected_metrics.values()
                ]
            except Exception as e:
                log.exception(e)
                # Fix: the adjacent literals previously concatenated to
                # "...associated withTask ..." (missing space).
                raise JobError(
                    "Unable to find metrics associated with "
                    f"Task {experiment.task_name} in registry",
                ) from e
            try:
                prepared_dataset = task.prepare_for_task(
                    loaded_dataset, experiment.output_columns
                )
                splits = json.loads(experiment.splits)
                prepared_dataset = prepare_for_experiment(
                    dataset=prepared_dataset,
                    splits=splits,
                    output_columns=experiment.output_columns,
                )
                x, y = select_columns(
                    prepared_dataset,
                    experiment.input_columns,
                    experiment.output_columns,
                )
                # split_dataset exposes the train/validation/test mapping, so
                # x["train"], y["train"], ... are available below.
                x = split_dataset(x)
                y = split_dataset(y)
            except Exception as e:
                log.exception(e)
                # Fix: the message was a triple-quoted literal that embedded a
                # newline and source indentation into the user-facing error.
                raise JobError(
                    f"Can not prepare Dataset {dataset.id} "
                    f"for Task {experiment.task_name}",
                ) from e
            try:
                run_model_class = component_registry[run.model_name]["class"]
            except Exception as e:
                log.exception(e)
                raise JobError(
                    f"Unable to find Model with name {run.model_name} in registry.",
                ) from e
            try:
                factory = ModelFactory(run_model_class, run.parameters)
                model: BaseModel = factory.model
                run_optimizable_parameters = factory.optimizable_parameters
            except Exception as e:
                log.exception(e)
                raise JobError(
                    f"Unable to instantiate model using run {run_id}",
                ) from e
            # NOTE(review): the optimizer and goal metric are only configured
            # for these two tasks; if any other task yields optimizable
            # parameters, the optimization branch below would hit an undefined
            # `optimizer` — confirm the intended task coverage.
            if experiment.task_name in [
                "TextClassificationTask",
                "TabularClassificationTask",
            ]:
                try:
                    # Optimizer configuration
                    run_optimizer_class = component_registry[run.optimizer_name][
                        "class"
                    ]
                except Exception as e:
                    log.exception(e)
                    # Fix: message said "Model" for an optimizer lookup.
                    raise JobError(
                        f"Unable to find Optimizer with name {run.optimizer_name} in "
                        "registry.",
                    ) from e
                try:
                    goal_metric = selected_metrics[run.goal_metric]
                except Exception as e:
                    log.exception(e)
                    raise JobError(
                        "Metric is not compatible with the Task",
                    ) from e
                try:
                    optimizer: BaseOptimizer = run_optimizer_class(
                        **run.optimizer_parameters
                    )
                except Exception as e:
                    log.exception(e)
                    raise JobError(
                        "Optimizer parameters are not compatible with the optimizer",
                    ) from e
            try:
                run.set_status_as_started()
                db.commit()
            except exc.SQLAlchemyError as e:
                log.exception(e)
                raise JobError(
                    "Connection with the database failed",
                ) from e
            try:
                # Hyperparameter Tunning
                if not run_optimizable_parameters:
                    # No tunable parameters: plain fit on the train split.
                    model.fit(x["train"], y["train"])
                else:
                    optimizer.optimize(
                        model,
                        x,
                        y,
                        run_optimizable_parameters,
                        goal_metric,
                        experiment.task_name,
                    )
                    model = optimizer.get_model()
                    # Generate hyperparameter plot
                    trials = optimizer.get_trials_values()
                    plot_filenames, plots = optimizer.create_plots(
                        trials, run_id, n_params=len(run_optimizable_parameters)
                    )
                    plot_paths = []
                    for filename, plot in zip(plot_filenames, plots):
                        plot_path = os.path.join(config["RUNS_PATH"], filename)
                        with open(plot_path, "wb") as file:
                            pickle.dump(plot, file)
                        plot_paths.append(plot_path)
            except Exception as e:
                log.exception(e)
                raise JobError(
                    "Model training failed",
                ) from e
            # Persist plot paths only when optimization actually ran.
            # Fix: was `!= {}`, inconsistent with the truthiness test above
            # and wrong for a None value.
            if run_optimizable_parameters:
                if len(run_optimizable_parameters) >= 2:
                    # Multi-parameter runs also produce contour and
                    # importance plots.
                    try:
                        run.plot_history_path = plot_paths[0]
                        run.plot_slice_path = plot_paths[1]
                        run.plot_contour_path = plot_paths[2]
                        run.plot_importance_path = plot_paths[3]
                        db.commit()
                    except Exception as e:
                        log.exception(e)
                        raise JobError(
                            "Hyperparameter plot path saving failed",
                        ) from e
                else:
                    try:
                        run.plot_history_path = plot_paths[0]
                        run.plot_slice_path = plot_paths[1]
                        db.commit()
                    except Exception as e:
                        log.exception(e)
                        raise JobError(
                            "Hyperparameter plot path saving failed",
                        ) from e
            try:
                run.set_status_as_finished()
                db.commit()
            except exc.SQLAlchemyError as e:
                log.exception(e)
                raise JobError(
                    "Connection with the database failed",
                ) from e
            try:
                model_metrics = factory.evaluate(x, y, metrics)
            except Exception as e:
                log.exception(e)
                raise JobError(
                    "Metrics calculation failed",
                ) from e
            run.train_metrics = model_metrics["train"]
            run.validation_metrics = model_metrics["validation"]
            run.test_metrics = model_metrics["test"]
            try:
                run_path = os.path.join(config["RUNS_PATH"], str(run.id))
                model.save(run_path)
            except Exception as e:
                log.exception(e)
                raise JobError(
                    "Model saving failed",
                ) from e
            try:
                run.run_path = run_path
                db.commit()
            except exc.SQLAlchemyError as e:
                log.exception(e)
                run.set_status_as_error()
                db.commit()
                raise JobError(
                    "Connection with the database failed",
                ) from e
        except Exception as e:
            # Mark the run as failed before propagating the error.
            run.set_status_as_error()
            db.commit()
            raise e