Source code for DashAI.back.job.predict_job

import json
import logging
import os
from pathlib import Path
from typing import Any, List

import numpy as np
from fastapi import Depends, status
from fastapi.exceptions import HTTPException
from kink import di, inject
from sqlalchemy import exc
from sqlalchemy.orm import Session, sessionmaker

from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset, load_dataset
from DashAI.back.dependencies.database.models import Dataset, Experiment, Run
from DashAI.back.dependencies.registry import ComponentRegistry
from DashAI.back.job.base_job import BaseJob, JobError
from DashAI.back.models.base_model import BaseModel

# Configure the root logger at DEBUG level when this module is imported.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger, named after this module, used by PredictJob below.
log = logging.getLogger(__name__)


class PredictJob(BaseJob):
    """Job that runs a trained model over a dataset and stores the predictions.

    The job reads its parameters from ``self.kwargs``:

    * ``run_id``: primary key of the ``Run`` whose trained model is used.
    * ``id``: primary key of the ``Dataset`` to predict on (``run`` only).
    * ``db``: SQLAlchemy ``Session`` used for every database lookup.
    * ``json_filename``: name, without extension, of the JSON file that
      receives the predictions (``run`` only).
    """

    def set_status_as_delivered(self) -> None:
        """Mark the run referenced by ``run_id`` as delivered and commit.

        Raises
        ------
        JobError
            If the run does not exist in the database, or if updating the
            status / committing raises a SQLAlchemy error.
        """
        run_id: int = self.kwargs["run_id"]
        db: Session = self.kwargs["db"]

        run: Run = db.get(Run, run_id)
        if not run:
            raise JobError(f"Run {run_id} does not exist in DB.")

        try:
            run.set_status_as_delivered()
            db.commit()
        except exc.SQLAlchemyError as e:
            log.exception(e)
            raise JobError(
                "Internal database error",
            ) from e

    @inject
    def run(
        self,
        component_registry: ComponentRegistry = lambda di: di["component_registry"],
        session_factory: sessionmaker = Depends(lambda: di["session_factory"]),
        config=lambda di: di["config"],
    ) -> List[Any]:
        """Predict with the run's trained model and save the result as JSON.

        The method looks up the run, its experiment and the dataset, loads
        the dataset from ``<dataset.file_path>/dataset/``, loads the trained
        model from ``run.run_path``, predicts over the experiment's input
        columns and writes ``<json_filename>.json`` under
        ``<config['DATASETS_PATH']>/predictions/``.

        Parameters
        ----------
        component_registry : ComponentRegistry
            Registry queried with ``run.model_name`` to obtain the model
            class (its ``"class"`` entry).
        session_factory : sessionmaker
            Not used by this method; kept so the signature stays unchanged.
        config
            Mapping read for the ``"DATASETS_PATH"`` key.

        Returns
        -------
        List[Any]
            NOTE(review): the annotation promises a list, but the method
            has no ``return`` statement, so it actually returns ``None``.
            Left unchanged so callers observe the same behavior.

        Raises
        ------
        HTTPException
            404 if the run, experiment or dataset is missing; 500 on a
            SQLAlchemy error during those lookups; 400 if selecting the
            columns or predicting raises ``ValueError``.
        JobError
            If the dataset or model cannot be loaded, the prediction fails
            with any other exception, or the JSON file cannot be written.
        """
        run_id: int = self.kwargs["run_id"]
        # Local name avoids shadowing the ``id`` builtin; the kwargs key is
        # part of the caller contract and stays ``"id"``.
        dataset_id: int = self.kwargs["id"]
        db: Session = self.kwargs["db"]
        json_filename: str = self.kwargs["json_filename"]

        # NOTE(review): HTTPException is raised from inside a job here,
        # whereas the other failures use JobError. Kept as-is because
        # callers may depend on the exception type.
        try:
            run: Run = db.get(Run, run_id)
            if not run:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND, detail="Run not found"
                )

            exp: Experiment = db.get(Experiment, run.experiment_id)
            if not exp:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Experiment not found",
                )

            dataset: Dataset = db.get(Dataset, dataset_id)
            if not dataset:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Dataset not found",
                )
        except exc.SQLAlchemyError as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Internal database error",
            ) from e

        try:
            loaded_dataset: DashAIDataset = load_dataset(
                str(Path(f"{dataset.file_path}/dataset/"))
            )
        except Exception as e:
            log.exception(e)
            # Fixed: the message was a plain string, so the placeholder was
            # emitted literally instead of the actual path.
            raise JobError(
                f"Can not load dataset from path {dataset.file_path}/dataset/"
            ) from e

        try:
            model = component_registry[run.model_name]["class"]
            trained_model: BaseModel = model.load(run.run_path)
        except Exception as e:
            log.exception(e)
            raise JobError(f"Model {run.model_name} not found in the registry") from e

        try:
            prepared_dataset = loaded_dataset.select_columns(exp.input_columns)
            y_pred_proba = np.array(trained_model.predict(prepared_dataset))
            # String outputs are used directly as the prediction; anything
            # else is reduced with argmax along axis 1.
            if isinstance(y_pred_proba[0], str):
                y_pred = y_pred_proba
            else:
                y_pred = np.argmax(y_pred_proba, axis=1)
        except ValueError as ve:
            log.error(f"Validation Error: {ve}")
            raise HTTPException(
                status_code=400,
                detail=f"Invalid columns selected: {str(ve)}",
            ) from ve
        except Exception as e:
            # log.exception keeps the traceback, like the other handlers.
            log.exception(e)
            raise JobError(
                "Model prediction failed",
            ) from e

        try:
            path = str(Path(f"{config['DATASETS_PATH']}/predictions/"))
            os.makedirs(path, exist_ok=True)

            # The new prediction id is one more than the highest id stored
            # in the existing prediction files (1 when there are none).
            existing_ids = []
            for f in os.listdir(path):
                if f.endswith(".json"):
                    file_path = os.path.join(path, f)
                    with open(file_path, "r", encoding="utf-8") as json_file:
                        data = json.load(json_file)
                    existing_ids.append(data["metadata"]["id"])
            next_id = max(existing_ids, default=0) + 1

            json_name = f"{json_filename}.json"
            json_data = {
                "metadata": {
                    "id": next_id,
                    "pred_name": json_name,
                    # NOTE(review): "run_name" receives run.model_name and
                    # "model_name" receives run.name, which looks swapped.
                    # Kept as-is because readers of these files may rely on
                    # the current mapping; confirm before changing.
                    "run_name": run.model_name,
                    "model_name": run.name,
                    "dataset_name": dataset.name,
                    "task_name": exp.task_name,
                },
                "prediction": y_pred.tolist(),
            }
            with open(
                os.path.join(path, json_name), "w", encoding="utf-8"
            ) as json_file:
                json.dump(json_data, json_file, indent=4)
        except Exception as e:
            log.exception(e)
            raise JobError(
                "Can not save prediction to json file",
            ) from e