Source code for DashAI.back.models.base_model

"""Base Model abstract class."""

from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Final, Optional, final

from kink import di

from DashAI.back.config_object import ConfigObject
from DashAI.back.core.enums.metrics import LevelEnum, SplitEnum
from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset
from DashAI.back.dependencies.database.models import Metric


class BaseModel(ConfigObject, metaclass=ABCMeta):
    """Abstract class of all machine learning models.

    All models must extend this class and implement the save, load, train and
    predict methods.
    """

    TYPE: Final[str] = "Model"
    DISPLAY_NAME: str = ""
    DESCRIPTION: str = ""
    COLOR: str = "#795548"
    ICON: str = "Science"

    @classmethod
    def get_metadata(cls) -> Dict[str, Any]:
        """Get metadata values for the current model.

        Returns
        -------
        Dict[str, Any]
            Dictionary with the metadata, including the icon.
        """
        metadata: Dict[str, Any] = {}
        metadata["icon"] = cls.ICON if cls.ICON else "Science"
        return metadata

    @abstractmethod
    def save(self, filename: str) -> None:
        """Store an instance of a model.

        Parameters
        ----------
        filename : str
            Indicates where to store the model. If filename is None, this
            method returns a bytes array with the model instead.
        """
        raise NotImplementedError

    @abstractmethod
    def load(self, filename: str) -> Any:
        """Restore an instance of a model.

        Parameters
        ----------
        filename : str
            Indicates where the model was stored.
        """
        raise NotImplementedError

    @abstractmethod
    def train(
        self,
        x_train: DashAIDataset,
        y_train: DashAIDataset,
        x_validation: Optional[DashAIDataset] = None,
        y_validation: Optional[DashAIDataset] = None,
    ) -> "BaseModel":
        """Train the model with the provided data.

        Parameters
        ----------
        x_train : DashAIDataset
            The input features for training.
        y_train : DashAIDataset
            The target labels for training.
        x_validation : DashAIDataset, optional
            The input features for validation.
        y_validation : DashAIDataset, optional
            The target labels for validation.

        Returns
        -------
        BaseModel
            The trained model instance.
        """
        raise NotImplementedError

    @final
    def _save_metrics(
        self,
        split: SplitEnum,
        level: LevelEnum,
        results: Dict[str, float],
        log_index: Optional[int] = None,
    ):
        with di["session_factory"]() as db:
            # Initialize the tracking dict if it does not exist yet.
            if not hasattr(self, "_metric_step_counters"):
                self._metric_step_counters = {}

            # Create a unique key for this run/split/level combination.
            counter_key = (self.run_id, split, level)

            # 1. Determine log_index.
            if counter_key not in self._metric_step_counters:
                steps = (
                    db.query(Metric.step)
                    .filter_by(run_id=self.run_id, split=split, level=level)
                    .order_by(Metric.step.desc())
                    .limit(2)
                    .all()
                )
                if not steps:
                    current, previous = 0, 0
                elif len(steps) == 1:
                    current, previous = steps[0][0], 0
                else:
                    current, previous = steps[0][0], steps[1][0]
                self._metric_step_counters[counter_key] = {
                    "current": current,
                    "previous": previous,
                }

            counter = self._metric_step_counters[counter_key]
            current_max = counter["current"]
            previous_max = counter["previous"]

            # Compute the delta between the last two steps (preserve spacing).
            delta = current_max - previous_max
            if delta <= 0:
                delta = 1

            # If no log_index was given, or it does not advance the counter,
            # advance naturally by the preserved delta.
            if log_index is None or log_index <= current_max:
                log_index = current_max + delta

            # Update the in-memory tracker.
            counter["previous"] = current_max
            counter["current"] = log_index

            # 2. Handle 'LAST' level replacement logic: each metric name keeps
            # a single row that is updated in place.
            if level == LevelEnum.LAST:
                for name, value in results.items():
                    existing = (
                        db.query(Metric)
                        .filter_by(
                            run_id=self.run_id, split=split, level=level, name=name
                        )
                        .first()
                    )
                    if existing:
                        existing.value = value
                        existing.step = log_index
                    else:
                        db.add(
                            Metric(
                                run_id=self.run_id,
                                split=split,
                                level=level,
                                name=name,
                                value=value,
                                step=log_index,
                            )
                        )
            # 3. Standard logging (STEP, BATCH, TRIAL): just insert new rows.
            else:
                metric_entries = [
                    Metric(
                        run_id=self.run_id,
                        split=split,
                        level=level,
                        name=name,
                        value=score,
                        step=log_index,
                    )
                    for name, score in results.items()
                ]
                db.add_all(metric_entries)
            db.commit()

    @final
    def calculate_metrics(
        self,
        split: SplitEnum = SplitEnum.VALIDATION,
        level: LevelEnum = LevelEnum.LAST,
        log_index: Optional[int] = None,
        x_data: Optional[DashAIDataset] = None,
        y_data: Optional[DashAIDataset] = None,
    ):
        """Calculate and save metrics for a given data split and level.

        Parameters
        ----------
        split : SplitEnum, default=SplitEnum.VALIDATION
            The data split (TRAIN, VALIDATION, TEST).
        level : LevelEnum, default=LevelEnum.LAST
            The metric level (LAST, TRIAL, STEP, BATCH).
        log_index : int, optional
            The index for logging purposes. If None, the metric is saved at
            the last index + 1.
        x_data : DashAIDataset, optional
            The input features for the split. If None, the stored dataset
            associated with the split is used.
        y_data : DashAIDataset, optional
            The target labels for the split. If None, the stored labels
            associated with the split are used.
        """
        # Get the appropriate metrics based on the split.
        metrics_attr = f"{split.value}_metrics"
        metrics = getattr(self, metrics_attr, None)

        # If there are no metrics or no run_id, skip the calculation.
        if not metrics or not self.run_id:
            return

        # Load the data if it was not provided.
        if x_data is None or y_data is None:
            if self.x_data is None or self.y_data is None:
                return
            x_data = self.x_data[split.value]
            y_data = self.y_data[split.value]

        # If the data is still empty after retrieval, skip the calculation.
        if x_data is None or y_data is None:
            return

        # Make predictions and transform the outputs.
        y_pred = self.predict(x_data)
        y_transformed = self.prepare_output(y_data, is_fit=False)

        # Calculate the metric scores.
        results = {}
        for metric in metrics:
            score = metric.score(y_transformed, y_pred)
            results[metric.__name__] = score

        # Save to the database.
        self._save_metrics(
            split=split, level=level, results=results, log_index=log_index
        )

    def prepare_dataset(
        self, dataset: DashAIDataset, is_fit: bool = False
    ) -> DashAIDataset:
        """Hook for model-specific preprocessing of input features.

        Override in subclasses that need custom tokenization/encoding.
        Must not mutate the input in place.

        Parameters
        ----------
        dataset : DashAIDataset
            The dataset to be transformed.
        is_fit : bool
            Whether the dataset is for fitting or not.

        Returns
        -------
        DashAIDataset
            The prepared dataset, ready to be converted to a format accepted
            by the model.
        """
        return dataset

    def prepare_output(
        self, dataset: DashAIDataset, is_fit: bool = False
    ) -> DashAIDataset:
        """Hook for model-specific preprocessing of output targets.

        Parameters
        ----------
        dataset : DashAIDataset
            The output dataset to be transformed.
        is_fit : bool
            Whether the dataset is for fitting or not.

        Returns
        -------
        DashAIDataset
            The prepared output dataset.
        """
        return self.prepare_dataset(dataset, is_fit)
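
A minimal sketch of the subclass contract, for illustration only: SketchModel, its estimator attribute, and the pandas conversion are hypothetical choices, not part of DashAI. It assumes a scikit-learn style estimator and uses joblib for persistence.

# Illustrative sketch (not part of this module): a minimal BaseModel subclass.
import joblib
from sklearn.linear_model import LogisticRegression


class SketchModel(BaseModel):
    DISPLAY_NAME: str = "Sketch Model"
    DESCRIPTION: str = "Minimal example of the BaseModel contract."

    def __init__(self) -> None:
        self.estimator = LogisticRegression()

    def save(self, filename: str) -> None:
        joblib.dump(self.estimator, filename)

    def load(self, filename: str) -> "SketchModel":
        self.estimator = joblib.load(filename)
        return self

    def train(
        self,
        x_train: DashAIDataset,
        y_train: DashAIDataset,
        x_validation: Optional[DashAIDataset] = None,
        y_validation: Optional[DashAIDataset] = None,
    ) -> "SketchModel":
        # Assumes DashAIDataset exposes a to_pandas() conversion (it wraps a
        # Hugging Face Dataset); adapt to the project's actual helpers.
        self.estimator.fit(
            x_train.to_pandas(), y_train.to_pandas().iloc[:, 0]
        )
        return self

    def predict(self, x_data: DashAIDataset):
        # calculate_metrics() calls self.predict(), so subclasses provide it.
        return self.estimator.predict(x_data.to_pandas())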
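
The log_index bookkeeping in _save_metrics preserves the spacing between the last two recorded steps, so a model that logs every N steps keeps doing so when no explicit index is given. A standalone sketch of just that arithmetic (the function name is illustrative, not part of DashAI):

# Illustrative sketch (not part of this module): the step-spacing rule.
from typing import Optional


def next_log_index(current: int, previous: int, requested: Optional[int]) -> int:
    # Reuse the gap between the last two steps; fall back to 1 at the start.
    delta = current - previous
    if delta <= 0:
        delta = 1
    # A missing or stale request (<= current) advances by the preserved delta.
    if requested is None or requested <= current:
        return current + delta
    return requested


# With steps 10 and 15 already recorded, the next automatic index is 20.
assert next_log_index(current=15, previous=10, requested=None) == 20
# A stale request (<= current) is bumped forward the same way.
assert next_log_index(current=15, previous=10, requested=12) == 20
# A genuinely newer request is honored as-is.
assert next_log_index(current=15, previous=10, requested=30) == 30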