import logging
import uuid
from pathlib import Path
from typing import Any, List, Optional
import numpy as np
from fastapi import status
from fastapi.exceptions import HTTPException
from kink import inject
from sqlalchemy import exc
from sqlalchemy.orm import sessionmaker
from DashAI.back.dataloaders.classes.dashai_dataset import (
DashAIDataset,
load_dataset,
save_dataset,
to_dashai_dataset,
)
from DashAI.back.dependencies.database.models import Dataset, Experiment, Prediction
from DashAI.back.job.base_job import BaseJob, JobError
from DashAI.back.models.base_model import BaseModel
from DashAI.back.tasks import BaseTask
logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)
class PredictJob(BaseJob):
"""PredictJob class to run the prediction."""
@inject
def set_status_as_delivered(
self, session_factory: sessionmaker = lambda di: di["session_factory"]
) -> None:
"""Set the status of the job as delivered."""
prediction_id = self.kwargs.get("prediction_id")
try:
with session_factory() as db:
prediction: Prediction = db.get(Prediction, prediction_id)
if prediction:
prediction.set_status_as_delivered()
db.commit()
else:
log.error(f"Prediction with id {prediction_id} not found.")
except exc.SQLAlchemyError as e:
log.exception(f"Database error while setting prediction status: {e}")
@inject
def set_status_as_error(
self, session_factory: sessionmaker = lambda di: di["session_factory"]
) -> None:
"""Set the status of the prediction job as error."""
prediction_id = self.kwargs.get("prediction_id")
try:
with session_factory() as db:
prediction: Prediction = db.get(Prediction, prediction_id)
if prediction:
prediction.set_status_as_error()
db.commit()
else:
log.error(f"Prediction with id {prediction_id} not found.")
except exc.SQLAlchemyError as e:
log.exception(f"Database error while setting prediction status: {e}")
@inject
def get_job_name(self) -> str:
"""Get a descriptive name for the job."""
prediction_id = self.kwargs.get("prediction_id")
dataset_id = self.kwargs.get("dataset_id")
if prediction_id:
from kink import di
session_factory = di["session_factory"]
try:
with session_factory() as db:
prediction = db.get(Prediction, prediction_id)
dataset = db.get(Dataset, dataset_id)
if prediction and dataset:
return f"Predict: {prediction.run.name} on {dataset.name}"
except Exception:
pass
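        # Fall back to a generic name if the records cannot be loaded.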
return f"Prediction (Prediction:{prediction_id}, Dataset:{dataset_id})"
@inject
def run(
self,
) -> List[Any]:
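        """Run the prediction job.

        Loads the trained model, builds the input dataset (from a stored
        dataset or from manual input rows), runs the prediction, and saves
        the results to disk, updating the Prediction record's status along
        the way.
        """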
from kink import di
component_registry = di["component_registry"]
session_factory = di["session_factory"]
config = di["config"]
prediction_id: int = self.kwargs["prediction_id"]
manual_input_data: List[dict] = self.kwargs.get("manual_input_data", [])
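        # Predictions run either on a stored dataset (dataset_id) or on
        # rows passed directly as manual_input_data.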
with session_factory() as db:
try:
# Retrieve Prediction
prediction: Prediction = db.get(Prediction, prediction_id)
if not prediction:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Prediction not found for id {prediction_id}",
)
# Set huey_id and update status to STARTED
prediction.huey_id = self.kwargs.get("huey_id", None)
prediction.set_status_as_started()
db.commit()
dataset_id = prediction.dataset_id
# Validate input data
if not manual_input_data and not dataset_id:
prediction.set_status_as_error()
db.commit()
raise JobError(
"Either dataset_id or manual_input_data must be provided."
)
# Retrieve Experiment
exp: Experiment = db.get(Experiment, prediction.run.experiment_id)
if not exp:
prediction.set_status_as_error()
db.commit()
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Experiment not found",
)
# Retrieve Dataset if dataset_id is provided
                dataset: Optional[Dataset] = None
dataset_trained: Dataset = db.get(Dataset, exp.dataset_id)
if not dataset_trained:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Training dataset not found",
)
if dataset_id:
                    dataset = db.get(Dataset, dataset_id)
except exc.SQLAlchemyError as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal database error",
) from e
# Retrieve Task
try:
task: BaseTask = component_registry[exp.task_name]["class"]()
except Exception as e:
prediction.set_status_as_error()
db.commit()
log.exception(e)
raise JobError(
f"Task {exp.task_name} not found in the registry",
) from e
# Load Model
try:
model = component_registry[prediction.run.model_name]["class"]
trained_model: BaseModel = model.load(prediction.run.run_path)
except Exception as e:
prediction.set_status_as_error()
db.commit()
log.exception(e)
raise JobError(
f"Model {prediction.run.model_name} not found in the registry"
) from e
# Load Dataset and make Predictions
try:
# Load training dataset for type info and label processing
train_dataset: DashAIDataset = load_dataset(
str(Path(f"{dataset_trained.file_path}/dataset/"))
)
            except Exception as e:
                prediction.set_status_as_error()
                db.commit()
                log.exception(e)
raise JobError(
f"Cannot load training dataset from "
f"{dataset_trained.file_path}/dataset/"
) from e
try:
# Load or create prediction dataset
if dataset_id:
loaded_dataset: DashAIDataset = load_dataset(
str(Path(f"{dataset.file_path}/dataset/"))
)
else:
dataset_trained_path = str(
Path(f"{dataset_trained.file_path}/dataset/")
)
loaded_dataset = task.process_manual_input(
manual_input_data, dataset_trained_path
)
# Select input columns and make prediction
prepared_dataset = loaded_dataset.select_columns(exp.input_columns)
y_pred_proba = np.array(trained_model.predict(prepared_dataset))
# Process predictions (convert to labels for classification)
y_pred = task.process_predictions(
train_dataset, y_pred_proba, exp.output_columns[0]
)
except ValueError as ve:
prediction.set_status_as_error()
db.commit()
log.error(f"Validation Error: {ve}")
raise HTTPException(
status_code=400,
detail=f"Invalid input data: {str(ve)}",
) from ve
            except TypeError as te:
                prediction.set_status_as_error()
                db.commit()
                log.error(f"Type Error: {te}")
raise HTTPException(
status_code=400,
detail=f"Type validation failed: {str(te)}",
) from te
except Exception as e:
prediction.set_status_as_error()
db.commit()
log.error(e)
raise JobError(
"Model prediction failed",
) from e
# Save Predictions to Arrow file
try:
# Create unique folder for predictions
path = str(Path(f"{config['DATASETS_PATH']}/predictions/"))
folder_name = str(uuid.uuid4())
full_path = Path(path) / folder_name
full_path.mkdir(parents=True, exist_ok=True)
# Add predictions to loaded dataset
dataset_with_prediction = to_dashai_dataset(
prepared_dataset.add_column(exp.output_columns[0], y_pred)
)
# Filter schema from trained dataset
trained_schema = train_dataset.types
filtered_schema = {
key: value.to_string()
for key, value in trained_schema.items()
if key in exp.input_columns + exp.output_columns
}
# Store num of rows, columns, and column names
dataset_with_prediction.compute_base_metadata()
# Save dataset with predictions
save_dataset(
dataset_with_prediction,
str(full_path / "dataset"),
filtered_schema,
)
# Update Prediction record
prediction.results_path = str(full_path)
prediction.set_status_as_finished()
db.commit()
except Exception as e:
prediction.set_status_as_error()
db.commit()
log.exception(e)
raise JobError(
"Can not save prediction to json file",
) from e