Source code for DashAI.back.job.dataset_job

import gc
import json
import logging
import os
import shutil
import uuid

from kink import inject
from sqlalchemy import exc
from sqlalchemy.orm import sessionmaker

from DashAI.back.api.api_v1.schemas.datasets_params import DatasetParams
from DashAI.back.api.utils import parse_params
from DashAI.back.dataloaders.classes.dashai_dataset import load_dataset, save_dataset
from DashAI.back.dependencies.database.models import Dataset, Notebook
from DashAI.back.job.base_job import BaseJob, JobError

log = logging.getLogger(__name__)


class DatasetJob(BaseJob):
    """Job for processing and uploading datasets using streaming data processing.

    Parameters
    ----------
    kwargs : Dict[str, Any]
        A dictionary containing the parameters for the job, including:
        - name: Name of the dataset
        - datatype_name: Name of the datatype
        - params: Parameters for the datatype
        - file_path: Path to the temporarily saved file
        - temp_dir: Directory containing the temporary file
        - filename: Name of the uploaded file
        - db: Database session
    """

    @inject
    def set_status_as_delivered(
        self, session_factory: sessionmaker = lambda di: di["session_factory"]
    ) -> None:
        """Set the status of the dataset as delivered."""
        dataset_id: int = self.kwargs["dataset_id"]
        with session_factory() as db:
            dataset: Dataset = db.get(Dataset, dataset_id)
            if dataset is None:
                raise JobError(f"Dataset with id {dataset_id} not found.")
            try:
                dataset.set_status_as_delivered()
                db.commit()
            except exc.SQLAlchemyError as e:
                log.exception(e)
                raise JobError(
                    "Error while setting the status of the dataset as delivered."
                ) from e

    @inject
    def set_status_as_error(
        self, session_factory: sessionmaker = lambda di: di["session_factory"]
    ) -> None:
        """Set the job status as error."""
        dataset_id: int = self.kwargs["dataset_id"]
        with session_factory() as db:
            dataset: Dataset = db.get(Dataset, dataset_id)
            if dataset is None:
                raise JobError(f"Dataset with id {dataset_id} not found.")
            try:
                dataset.set_status_as_error()
                db.commit()
            except exc.SQLAlchemyError as e:
                log.exception(e)
                raise JobError(
                    "Error while setting the status of the dataset as error."
                ) from e

    def get_job_name(self) -> str:
        """Get a descriptive name for the job."""
        name = self.kwargs.get("name", "")
        if name:
            return f"Dataset: {name}"
        params = self.kwargs.get("params", {})
        if params and isinstance(params, dict) and "name" in params:
            return f"Dataset: {params['name']}"
        return "Dataset load"

    @inject
    def run(self) -> None:
        from kink import di

        component_registry = di["component_registry"]
        session_factory = di["session_factory"]
        config = di["config"]

        dataset_id = self.kwargs.get("dataset_id")
        notebook_id = self.kwargs.get("notebook_id", None)
        params = self.kwargs.get("params", {})
        file_path = self.kwargs.get("file_path")
        temp_dir = self.kwargs.get("temp_dir")
        url = self.kwargs.get("url", "")

        try:
            with session_factory() as db:
                dataset = db.get(Dataset, dataset_id)
                if not dataset:
                    raise JobError(f"Dataset with ID {dataset_id} not found.")
                dataset.set_status_as_started()
                db.commit()
                db.refresh(dataset)

            random_name = str(uuid.uuid4())
            folder_path = config["DATASETS_PATH"] / random_name
            try:
                log.debug("Trying to create a new dataset path: %s", folder_path)
                folder_path.mkdir(parents=True)
            except FileExistsError as e:
                log.exception(e)
                raise JobError(
                    f"A dataset with the name {random_name} already exists."
                ) from e

            try:
                if notebook_id is not None:
                    log.debug(f"Copying dataset from notebook id {notebook_id}.")
                    with session_factory() as db:
                        notebook_dataset = (
                            db.query(Notebook)
                            .filter(Notebook.id == notebook_id)
                            .first()
                        )
                        if not notebook_dataset:
                            msg = (
                                "Notebook with ID "
                                f"{notebook_id}"
                                " has no associated dataset."
                            )
                            raise JobError(msg)
                        new_dataset = load_dataset(
                            os.path.join(notebook_dataset.file_path, "dataset")
                        )
                else:
                    parsed_params = parse_params(DatasetParams, json.dumps(params))
                    dataloader = component_registry[parsed_params.dataloader]["class"]()
                    log.debug("Storing dataset in %s", folder_path)
                    new_dataset = dataloader.load_data(
                        filepath_or_buffer=(
                            str(file_path) if file_path is not None else url
                        ),
                        temp_path=str(temp_dir),
                        params=parsed_params.model_dump(),
                    )

                # Calculate nan per column
                new_dataset.nan_per_column()

                gc.collect()

                dataset_save_path = folder_path / "dataset"
                log.debug("Saving dataset in %s", str(dataset_save_path))
                save_dataset(new_dataset, dataset_save_path)

            except Exception as e:
                log.exception(e)
                shutil.rmtree(folder_path, ignore_errors=True)
                raise JobError(f"Error loading dataset: {str(e)}") from e

            # Add dataset to database
            with session_factory() as db:
                log.debug("Storing dataset metadata in database.")
                try:
                    folder_path = os.path.realpath(folder_path)
                    dataset = db.get(Dataset, dataset_id)
                    dataset.file_path = folder_path
                    dataset.set_status_as_finished()
                    db.commit()
                    db.refresh(dataset)
                except exc.SQLAlchemyError as e:
                    log.exception(e)
                    shutil.rmtree(folder_path, ignore_errors=True)
                    raise JobError("Internal database error") from e

            log.debug("Dataset creation successfully finished.")

        except JobError as e:
            log.error(f"Dataset creation failed: {e}")
            with session_factory() as db:
                dataset = db.get(Dataset, dataset_id)
                if dataset:
                    dataset.set_status_as_error()
                    db.commit()
                    db.refresh(dataset)
            raise e

        finally:
            gc.collect()
            if temp_dir and os.path.exists(temp_dir):
                try:
                    shutil.rmtree(temp_dir, ignore_errors=True)
                except Exception as e:
                    log.exception(f"Error cleaning up temporary directory: {e}")
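
A minimal usage sketch for the class above, assuming that BaseJob stores its keyword arguments in self.kwargs, that a Dataset row with the given id already exists, and that the kink container (session_factory, component_registry, config) is configured as it is at application start-up. The dataloader name, file paths and parameter values below are hypothetical placeholders, not the actual DashAI upload endpoint code.

    # Hypothetical caller sketch: constructing and running a DatasetJob by hand.
    from DashAI.back.job.dataset_job import DatasetJob

    job = DatasetJob(
        dataset_id=1,                          # id of an existing Dataset row
        params={
            "name": "iris",                    # dataset name (placeholder)
            "dataloader": "CSVDataLoader",     # assumed registered dataloader name
        },
        file_path="/tmp/uploads/iris.csv",     # hypothetical temporary upload path
        temp_dir="/tmp/uploads",               # removed in run()'s finally block
        url="",                                # used only when file_path is None
    )

    job.set_status_as_delivered()  # mark the Dataset row as delivered
    job.run()                      # load, save and register the dataset

In the real application a job like this is created by the API layer after the upload is saved to a temporary directory; the sketch only illustrates which kwargs run() and the status methods expect.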