Source code for DashAI.back.job.explainer_job

import json
import logging
import os
import pickle
from typing import Any, Dict, Tuple

from datasets import DatasetDict
from kink import inject
from sqlalchemy import exc
from sqlalchemy.orm import sessionmaker

from DashAI.back.dataloaders.classes.dashai_dataset import (
    load_dataset,
    prepare_for_experiment,
    select_columns,
    split_dataset,
)
from DashAI.back.dependencies.database.models import (
    Dataset,
    Experiment,
    GlobalExplainer,
    LocalExplainer,
    Run,
)
from DashAI.back.explainability.global_explainer import BaseGlobalExplainer
from DashAI.back.explainability.local_explainer import BaseLocalExplainer
from DashAI.back.job.base_job import BaseJob, JobError
from DashAI.back.models import BaseModel
from DashAI.back.tasks import BaseTask

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)


[docs] class ExplainerJob(BaseJob): """ExplainerJob class to calculate explanations.""" @inject def set_status_as_delivered( self, session_factory: sessionmaker = lambda di: di["session_factory"] ) -> None: """Set the status of the job as delivered.""" explainer_id: int = self.kwargs["explainer_id"] explainer_scope: str = self.kwargs["explainer_scope"] with session_factory() as db: if explainer_scope == "global": explainer: GlobalExplainer = db.get(GlobalExplainer, explainer_id) elif explainer_scope == "local": explainer: LocalExplainer = db.get(LocalExplainer, explainer_id) else: raise JobError(f"{explainer_scope} is an invalid explainer type") if not explainer: raise JobError( f"Explainer with id {explainer_id} does not exist in DB." ) try: explainer.set_status_as_delivered() db.commit() except exc.SQLAlchemyError as e: log.exception(e) raise JobError( "Internal database error", ) from e @inject def set_status_as_error( self, session_factory: sessionmaker = lambda di: di["session_factory"] ) -> None: """Set the status of the explainer as error.""" explainer_id: int = self.kwargs.get("explainer_id") explainer_scope: str = self.kwargs.get("explainer_scope", "") if explainer_id is None: return with session_factory() as db: try: if explainer_scope == "global": explainer = db.get(GlobalExplainer, explainer_id) elif explainer_scope == "local": explainer = db.get(LocalExplainer, explainer_id) else: return if explainer: explainer.set_status_as_error() db.commit() except exc.SQLAlchemyError as e: log.exception(e) @inject def get_job_name(self) -> str: """Get a descriptive name for the job.""" explainer_id = self.kwargs.get("explainer_id") explainer_scope = self.kwargs.get("explainer_scope", "") if not explainer_id: return f"{explainer_scope.capitalize()} Explanation" from kink import di session_factory = di["session_factory"] try: with session_factory() as db: if explainer_scope == "global": explainer = db.get(GlobalExplainer, explainer_id) elif explainer_scope == "local": explainer = db.get(LocalExplainer, explainer_id) else: return ( f"{explainer_scope.capitalize()} Explanation ({explainer_id})" ) if explainer and explainer.name: return f"Explain: {explainer.name}" if explainer and explainer.explainer_name: return f"Explain: {explainer.explainer_name.split('.')[-1]}" except Exception: pass return f"{explainer_scope.capitalize()} Explanation ({explainer_id})" @inject def _generate_global_explanation( self, explainer: BaseGlobalExplainer, dataset=Tuple[DatasetDict, DatasetDict], ) -> None: from kink import di explainer_id: int = self.kwargs["explainer_id"] session_factory = di["session_factory"] config = di["config"] with session_factory() as db: try: explanation = explainer.explain(dataset) plot = explainer.plot(explanation) except Exception as e: log.exception(e) raise JobError( "Failed to generate the explanation", ) from e try: explanation_filename = f"global_explanation_{explainer_id}.pickle" explanation_path = os.path.join( config["EXPLANATIONS_PATH"], explanation_filename ) with open(explanation_path, "wb") as file: pickle.dump(explanation, file) plot_filename = f"global_explanation_plot_{explainer_id}.pickle" plot_path = os.path.join(config["EXPLANATIONS_PATH"], plot_filename) with open(plot_path, "wb") as file: pickle.dump(plot, file) except Exception as e: log.exception(e) raise JobError( "Explanation file saving failed", ) from e try: self.explainer_db.explanation_path = explanation_path self.explainer_db.plot_path = plot_path db.commit() except Exception as e: log.exception(e) raise JobError( "Explanation path saving failed", ) from e @inject def _generate_local_explanation( self, explainer: BaseLocalExplainer, dataset: Tuple[DatasetDict, DatasetDict], splits: Dict[str, Any], task: BaseTask, same_dataset: bool, ) -> None: from kink import di explainer_id: int = self.kwargs["explainer_id"] session_factory = di["session_factory"] config = di["config"] explainer.fit(dataset, **self.explainer_db.fit_parameters) instance_id = self.explainer_db.dataset_id with session_factory() as db: instance: Dataset = db.get(Dataset, instance_id) if not instance: raise JobError( f"Dataset {instance_id} to be explained does not exist in DB." ) try: loaded_instance = load_dataset(f"{instance.file_path}/dataset") except Exception as e: log.exception(e) raise JobError( f"Can not load instance from path {instance.file_path}", ) from e try: prepared_instance = task.prepare_for_task( loaded_instance, outputs_columns=self.output_columns ) split = self.explainer_db.scope.get("split") if split not in ["train", "test", "val", "all"]: raise JobError(f"{split} is not a valid split") if split != "all": if not same_dataset: prepared_instance, splits = prepare_for_experiment( dataset=prepared_instance, splits=splits, output_columns=self.output_columns, ) prepared_instance = split_dataset( prepared_instance, train_indexes=splits["train_indexes"], test_indexes=splits["test_indexes"], val_indexes=splits["val_indexes"], )[split] prepared_instance = prepared_instance.select( range( max( 1, int( prepared_instance.num_rows * self.explainer_db.scope.get("percentage") / 100 ), ), ) ) prepared_instance = DatasetDict({"train": prepared_instance}) X, _ = select_columns( prepared_instance, self.input_columns, self.output_columns, ) except Exception as e: log.exception(e) raise JobError( f"""Can not prepare Dataset with {instance_id} to generate the local explanation.""", ) from e try: explanation = explainer.explain_instance(X) plots = explainer.plot(explanation) except Exception as e: log.exception(e) raise JobError( "Failed to generate the explanation", ) from e try: explanation_filename = f"local_explanation_{explainer_id}.json" explanation_path = os.path.join( config["EXPLANATIONS_PATH"], explanation_filename ) with open(explanation_path, "wb") as file: pickle.dump(explanation, file) plots_filename = f"local_explanation_plots_{explainer_id}.pickle" plots_path = os.path.join(config["EXPLANATIONS_PATH"], plots_filename) with open(plots_path, "wb") as file: pickle.dump(plots, file) except Exception as e: log.exception(e) raise JobError( "Explanation file saving failed", ) from e try: self.explainer_db.explanation_path = explanation_path self.explainer_db.plots_path = plots_path db.commit() except Exception as e: log.exception(e) raise JobError( "Explanation path saving failed", ) from e @inject def run( self, ) -> None: from kink import di component_registry = di["component_registry"] session_factory = di["session_factory"] explainer_id: int = self.kwargs["explainer_id"] explainer_scope: str = self.kwargs["explainer_scope"] with session_factory() as db: if explainer_scope == "global": self.explainer_db: GlobalExplainer = db.get( GlobalExplainer, explainer_id ) elif explainer_scope == "local": self.explainer_db: LocalExplainer = db.get(LocalExplainer, explainer_id) else: raise JobError(f"{explainer_scope} is an invalid explainer type") try: run: Run = db.get(Run, self.explainer_db.run_id) if not run: raise JobError( f"Run {self.explainer_db.run_id} does not exist in DB." ) experiment: Experiment = db.get(Experiment, run.experiment_id) if not experiment: raise JobError( f"Experiment {run.experiment_id} does not exist in DB." ) dataset: Dataset = db.get(Dataset, experiment.dataset_id) if not dataset: raise JobError( f"Dataset {self.explainer_db.dataset_id} does not exist in DB." ) self.input_columns = experiment.input_columns self.output_columns = experiment.output_columns try: run_model_class = component_registry[run.model_name]["class"] except Exception as e: log.exception(e) raise JobError( f"Unable to find Model with name {run.model_name} in registry.", ) from e try: model: BaseModel = run_model_class(**run.parameters) except Exception as e: log.exception(e) raise JobError("Unable to instantiate model") from e try: trained_model = model.load(run.run_path) except Exception as e: log.exception(e) raise JobError( f"Can not load model from path {run.run_path}" ) from e try: explainer_class = component_registry[ self.explainer_db.explainer_name ]["class"] except Exception as e: log.exception(e) raise JobError( f"""Unable to find the {explainer_scope} explainer with name {self.explainer_db.explainer_name} in registry.""", ) from e try: explainer = explainer_class( model=trained_model, **self.explainer_db.parameters ) except Exception as e: log.exception(e) raise JobError( f"Unable to instantiate {explainer_scope} explainer.", ) from e try: loaded_dataset: DatasetDict = load_dataset( f"{dataset.file_path}/dataset" ) except Exception as e: log.exception(e) raise JobError( f"Can not load dataset from path {dataset.file_path}", ) from e try: task: BaseTask = component_registry[experiment.task_name]["class"]() except Exception as e: log.exception(e) raise JobError( ( f"Unable to find Task with name {experiment.task_name} " "in registry" ), ) from e try: splits = json.loads(run.split_indexes) loaded_dataset = split_dataset( loaded_dataset, train_indexes=splits["train_indexes"], test_indexes=splits["test_indexes"], val_indexes=splits["val_indexes"], ) prepared_dataset: DatasetDict = task.prepare_for_task( datasetdict=loaded_dataset, outputs_columns=self.output_columns, ) data = select_columns( prepared_dataset, self.input_columns, self.output_columns, ) data_x = split_dataset( data[0], train_indexes=splits["train_indexes"], test_indexes=splits["test_indexes"], val_indexes=splits["val_indexes"], ) data_y = split_dataset( data[1], train_indexes=splits["train_indexes"], test_indexes=splits["test_indexes"], val_indexes=splits["val_indexes"], ) except Exception as e: log.exception(e) raise JobError( f"""Can not prepare dataset {dataset.id} for the explanation""", ) from e try: self.explainer_db.set_status_as_started() db.commit() except exc.SQLAlchemyError as e: log.exception(e) raise JobError( "Connection with the database failed", ) from e if explainer_scope == "global": self._generate_global_explanation( explainer=explainer, dataset=(data_x, data_y) ) elif explainer_scope == "local": same_dataset = experiment.dataset_id == self.explainer_db.dataset_id if not same_dataset: splits = experiment.splits self._generate_local_explanation( explainer=explainer, dataset=(data_x, data_y), splits=splits, task=task, same_dataset=same_dataset, ) else: raise JobError(f"{explainer_scope} is an invalid explainer type") self.explainer_db.set_status_as_finished() db.commit() except Exception as e: self.explainer_db.set_status_as_error() db.commit() raise e