Source code for DashAI.back.optimizers.base_optimizer

"""Base Optimizer abstract class."""

import logging
from abc import ABCMeta, abstractmethod
from typing import Final

import numpy as np
import optuna
import plotly
import plotly.graph_objects as go
from optuna.importance import FanovaImportanceEvaluator

from DashAI.back.config_object import ConfigObject

log = logging.getLogger(__name__)


class BaseOptimizer(ConfigObject, metaclass=ABCMeta):
    """Abstract class for all hyperparameter optimizers.

    All optimizers must extend this class and implement the optimize method.
    """

    TYPE: Final[str] = "Optimizer"

    @abstractmethod
    def optimize(self, model, input, output, parameters, task):
        """Run the optimization process.

        Args:
            model (class): Class of the model from the current experiment.
            input: Input data used for training and validation.
            output: Target data used for training and validation.
            parameters (dict): Dict with the information to create the
                search space.
            task: Task associated with the current experiment.

        Returns
        -------
        None
        """
        raise NotImplementedError(
            "Optimization modules must implement optimize method."
        )

    @abstractmethod
    def get_model(self):
        """Get the model with the best set of hyperparameters found.

        Returns
        -------
        best_model (object): Instance of the model class with the best
            hyperparameters found.
        """
        raise NotImplementedError(
            "Optimization modules must implement get_model method."
        )

    @abstractmethod
    def get_trials_values(self):
        """Get the trial values from the hyperparameter optimization process.

        Returns
        -------
        trials (list): List with the hyperparameter values and the goal
            metric per trial.
        """
        raise NotImplementedError(
            "Optimization modules must implement get_trials_values method."
        )

    def history_objective_plot(self, trials):
        """Plot the goal metric achieved per trial.

        Args:
            trials (list): List with the hyperparameter values and the goal
                metric per trial.

        Returns
        -------
        fig (json): JSON string with the plot data.
        """
        x = list(range(1, len(trials) + 1))
        y = [trial["value"] for trial in trials]
        # Running maximum of the objective value across trials.
        max_cumulative = np.maximum.accumulate(y)
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                name="Optimization History",
                marker_color="blue",
                marker_size=8,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=max_cumulative,
                mode="lines",
                name="Current Max Value",
                line_color="red",
                line_width=2,
            )
        )
        fig.update_layout(
            title="Optimization History with Current Max Value",
            xaxis_title="Trial",
            yaxis_title="Objective Value",
        )
        return plotly.io.to_json(fig)

    def slice_plot(self, trials):
        """Plot the objective value over the search space of a single
        hyperparameter.

        Args:
            trials (list): List with the hyperparameter values and the goal
                metric per trial.

        Returns
        -------
        fig (json): JSON string with the plot data.
        """
        param_names = list(trials[0]["params"].keys())
        traces = []
        for param_name in param_names:
            x_values = [trial["params"][param_name] for trial in trials]
            y_values = [trial["value"] for trial in trials]
            trial_numbers = list(range(1, len(trials) + 1))
            trace = go.Scatter(
                x=x_values,
                y=y_values,
                mode="markers",
                marker={
                    "size": 8,
                    "color": trial_numbers,
                    "colorscale": "Blues",
                    "colorbar": {"title": "Trial Number"},
                    "showscale": True,
                    "line": {"width": 0.2, "color": "black"},
                },
                name=param_name,
                visible=False,
            )
            traces.append(trace)
        traces[0]["visible"] = True
        # One dropdown button per hyperparameter toggles the matching trace.
        buttons = []
        for i, param_name in enumerate(param_names):
            buttons.append(
                {
                    "method": "update",
                    "label": param_name,
                    "args": [
                        {"visible": [j == i for j in range(len(param_names))]},
                        {
                            "title": f"Slice plot for {param_name}",
                            "xaxis": {"title": param_name},
                        },
                    ],
                }
            )
        updatemenus = [{"buttons": buttons, "direction": "down", "showactive": True}]
        fig = go.Figure(data=traces)
        fig.update_layout(
            updatemenus=updatemenus,
            title=f"Slice plot for {param_names[0]}",
            xaxis_title=param_names[0],
            yaxis_title="Objective Value",
        )
        return plotly.io.to_json(fig)
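    # Note: the plotting helpers in this class expect `trials` to be a list
    # of dicts in the shape returned by get_trials_values, e.g.
    # (hypothetical values):
    #
    #     trials = [
    #         {"params": {"alpha": 0.1, "max_depth": 3}, "value": 0.82},
    #         {"params": {"alpha": 0.5, "max_depth": 7}, "value": 0.91},
    #     ]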
    def contour_plot(self, trials):
        """Contour plot of the goal metric over the search space of two
        hyperparameters.

        Args:
            trials (list): List with the hyperparameter values and the goal
                metric per trial.

        Returns
        -------
        fig (json): JSON string with the plot data.
        """
        param_names = list(trials[0]["params"].keys())
        traces = []
        scatter_traces = []
        # Build one contour trace and one scatter overlay per ordered pair
        # of distinct hyperparameters.
        for param_x in param_names:
            for param_y in param_names:
                if param_x != param_y:
                    x_values = [
                        trial["params"][param_x]
                        for trial in trials
                        if param_x in trial["params"]
                    ]
                    y_values = [
                        trial["params"][param_y]
                        for trial in trials
                        if param_y in trial["params"]
                    ]
                    z_values = [
                        trial["value"]
                        for trial in trials
                        if param_x in trial["params"] and param_y in trial["params"]
                    ]
                    contour_trace = go.Contour(
                        x=x_values,
                        y=y_values,
                        z=z_values,
                        colorscale="Blues",
                        colorbar={"title": "Objective Value"},
                        showscale=True,
                        name=f"{param_x} vs {param_y}",
                        visible=False,
                    )
                    traces.append(contour_trace)
                    scatter_trace = go.Scatter(
                        x=x_values,
                        y=y_values,
                        mode="markers",
                        marker={
                            "size": 8,
                            "color": z_values,
                            "colorscale": "Blues",
                            "colorbar": {"title": "Objective Value"},
                            "showscale": False,
                            "line": {"width": 0.2, "color": "black"},
                        },
                        name=f"{param_x} vs {param_y} points",
                        visible=False,
                    )
                    scatter_traces.append(scatter_trace)
        traces[0]["visible"] = True
        scatter_traces[0]["visible"] = True
        buttons = []
        for i in range(len(traces)):
            buttons.append(
                {
                    "method": "update",
                    "label": traces[i]["name"],
                    "args": [
                        {
                            "visible": [j == i for j in range(len(traces))]
                            + [j == i for j in range(len(scatter_traces))]
                        },
                        {"title": f'Contour plot for {traces[i]["name"]}'},
                    ],
                }
            )
        updatemenus = [{"buttons": buttons, "direction": "down", "showactive": True}]
        fig = go.Figure(data=traces + scatter_traces)
        fig.update_layout(
            updatemenus=updatemenus,
            title=f'Contour plot for {traces[0]["name"]}',
            xaxis_title=param_names[0],
            yaxis_title=param_names[1],
        )
        return plotly.io.to_json(fig)

    def importance_plot(self, trials):
        """Plot the relative importance of each hyperparameter involved in
        the hyperparameter optimization.

        Args:
            trials (list): List with the hyperparameter values and the goal
                metric per trial.

        Returns
        -------
        fig (json): JSON string with the plot data.
        """
        distributions = {}
        for param, (low, high) in self.parameters.items():
            if isinstance(low, int):
                distributions[param] = optuna.distributions.IntDistribution(low, high)
            elif isinstance(low, float):
                distributions[param] = optuna.distributions.FloatDistribution(
                    low, high
                )
        # Rebuild an Optuna study from the recorded trials so the FANOVA
        # importance evaluator can be applied to them.
        study = optuna.create_study(direction="maximize")
        for trial in trials:
            study.add_trial(
                optuna.trial.create_trial(
                    params=trial["params"],
                    distributions=distributions,
                    value=trial["value"],
                    state=optuna.trial.TrialState.COMPLETE,
                )
            )
        try:
            evaluator = FanovaImportanceEvaluator()
            importances = evaluator.evaluate(study)
        except RuntimeError:
            importances = {
                param: 1.0 / len(self.parameters) for param in self.parameters
            }
            log.warning(
                "Could not calculate parameter importance using FANOVA. "
                "Using equal importances."
            )
        sorted_items = sorted(importances.items(), key=lambda item: item[1])
        param_names, importance_values = zip(*sorted_items)
        fig = go.Figure(
            data=[
                go.Bar(
                    x=importance_values,
                    y=param_names,
                    orientation="h",
                    text=importance_values,
                    textposition="outside",
                    texttemplate="%{text:.2f}",
                )
            ]
        )
        fig.update_layout(
            title="Hyperparameter importance",
            xaxis_title="Importance",
            yaxis_title="Hyperparameter",
            yaxis={"tickangle": 0},
        )
        return plotly.io.to_json(fig)
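    # Note: importance_plot assumes that concrete optimizers expose
    # `self.parameters` as a mapping from parameter name to a (low, high)
    # search-range tuple, e.g. {"alpha": (0.001, 1.0), "max_depth": (2, 10)}
    # (hypothetical values); the type of the tuple elements selects between
    # Optuna's Int and Float distributions.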
    def create_plots(self, trials, run_id, n_params):
        """Create the available plots for a run.

        Args:
            trials (list): List with the hyperparameter values and the goal
                metric per trial.
            run_id (int): Id associated with the current run of the
                experiment.
            n_params (int): Number of different hyperparameters involved in
                the hyperparameter optimization process.

        Returns
        -------
        plots_filenames (list): Filenames under which to store each plot.
        plots_list (list): JSON plot data, in the same order as the
            filenames.
        """
        # The contour and importance plots only make sense with at least two
        # hyperparameters.
        if n_params >= 2:
            plots_filenames = [
                f"history_objective_plot_{run_id}.pickle",
                f"slice_plot_{run_id}.pickle",
                f"contour_plot_{run_id}.pickle",
                f"importance_plot_{run_id}.pickle",
            ]
            plots_list = [
                self.history_objective_plot(trials),
                self.slice_plot(trials),
                self.contour_plot(trials),
                self.importance_plot(trials),
            ]
            return plots_filenames, plots_list
        else:
            plots_filenames = [
                f"history_objective_plot_{run_id}.pickle",
                f"slice_plot_{run_id}.pickle",
            ]
            plots_list = [self.history_objective_plot(trials), self.slice_plot(trials)]
            return plots_filenames, plots_list
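
The abstract methods above leave the actual search loop to concrete optimizers. A minimal sketch of such a subclass, assuming an Optuna backend, a scikit-learn-style model class with fit/score methods, input/output dicts keyed "train"/"validation", and (low, high) range tuples in `parameters`, might look like the following; the class name SketchOptunaOptimizer and that data layout are hypothetical, not part of DashAI's API.

import optuna

from DashAI.back.optimizers.base_optimizer import BaseOptimizer


class SketchOptunaOptimizer(BaseOptimizer):
    """Hypothetical Optuna-backed optimizer, for illustration only."""

    def __init__(self, n_trials=10):
        self.n_trials = n_trials
        self.model = None
        self.parameters = {}
        self.study = None

    def optimize(self, model, input, output, parameters, task):
        self.parameters = parameters

        def objective(trial):
            # Suggest one value per hyperparameter from its (low, high)
            # range; integer bounds select an integer search space.
            kwargs = {}
            for name, (low, high) in parameters.items():
                if isinstance(low, int):
                    kwargs[name] = trial.suggest_int(name, low, high)
                else:
                    kwargs[name] = trial.suggest_float(name, low, high)
            candidate = model(**kwargs)
            # Assumed data layout: dicts keyed "train" / "validation".
            candidate.fit(input["train"], output["train"])
            return candidate.score(input["validation"], output["validation"])

        self.study = optuna.create_study(direction="maximize")
        self.study.optimize(objective, n_trials=self.n_trials)
        # Refit the model with the best hyperparameters found.
        self.model = model(**self.study.best_params)
        self.model.fit(input["train"], output["train"])

    def get_model(self):
        return self.model

    def get_trials_values(self):
        # Shape the trials as the plotting helpers in BaseOptimizer expect.
        return [
            {"params": t.params, "value": t.value}
            for t in self.study.trials
            if t.state == optuna.trial.TrialState.COMPLETE
        ]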