import os
import pathlib
import plotly.express as px
from beartype.typing import Any, Dict, List
from plotly.graph_objs import Figure
from plotly.io import read_json
from DashAI.back.core.schema_fields import (
int_field,
none_type,
schema_field,
string_field,
union_type,
)
from DashAI.back.dataloaders.classes.dashai_dataset import ( # ClassLabel, Value,
DashAIDataset,
)
from DashAI.back.dependencies.database.models import Exploration, Explorer
from DashAI.back.exploration.base_explorer import BaseExplorer, BaseExplorerSchema
class ScatterPlotSchema(BaseExplorerSchema):
    """Optional parameters accepted by the scatter plot explorer.

    Each field is declared with the same type expression: a string or an
    integer ``>= 0`` wrapped in ``none_type``, with ``None`` passed as the
    second argument of ``schema_field``. The string/integer is read by the
    explorer as a column name or a positional column index respectively.
    """

    # Column whose values group the points by color.
    color_group: schema_field(
        none_type(union_type(string_field(), int_field(ge=0))),
        None,
        "The columnName or columnIndex to take for grouping colored points.",
    )  # type: ignore

    # Column whose values group the points by marker symbol.
    # NOTE(review): "simbol" is a misspelling of "symbol"; the field name is
    # part of the public parameter interface, so it is kept as is.
    simbol_group: schema_field(
        none_type(union_type(string_field(), int_field(ge=0))),
        None,
        "The columnName or columnIndex to take for grouping simbol of the points.",
    )  # type: ignore

    # Column whose values set the size of each point.
    point_size: schema_field(
        none_type(union_type(string_field(), int_field(ge=0))),
        None,
        "The columnName or columnIndex to take for size of each point.",
    )  # type: ignore
class ScatterPlotExplorer(BaseExplorer):
    """
    ScatterPlotExplorer is an explorer that returns a scatter plot
    of selected columns of a dataset.

    The first two explorer columns are used as the x and y axes. The optional
    ``color_group``, ``simbol_group`` and ``point_size`` parameters select
    extra dataset columns (by name or positional index) that are passed to
    ``plotly.express.scatter`` as ``color``, ``symbol`` and ``size``.
    """

    DISPLAY_NAME = "Scatter Plot"
    DESCRIPTION = (
        "ScatterPlotExplorer is an explorer that returns a scatter plot "
        "of selected columns of a dataset."
    )

    SCHEMA = ScatterPlotSchema
    metadata: Dict[str, Any] = {
        "allowed_dtypes": ["*"],
        "restricted_dtypes": [],
        "input_cardinality": {"exact": 2},
    }

    def __init__(self, **kwargs) -> None:
        """Store the optional column selectors and initialise the base class.

        Parameters
        ----------
        **kwargs
            Explorer parameters. ``color_group``, ``simbol_group`` and
            ``point_size`` are read here (missing keys become ``None``); the
            full mapping is then forwarded unchanged to ``BaseExplorer``.
        """
        # Selectors hold either a column name, a column index or None until
        # ``prepare_dataset`` normalises them to column names.
        self.color_column = kwargs.get("color_group")
        self.simbol_column = kwargs.get("simbol_group")
        self.size_column = kwargs.get("point_size")
        super().__init__(**kwargs)

    @staticmethod
    def _resolve_extra_column(
        selector: Any,
        dataset_columns: List[Any],
        columns: List[Dict[str, Any]],
    ) -> Any:
        """Turn a selector into a column name and register it in ``columns``.

        Parameters
        ----------
        selector
            ``None``, a column name, or a positional index into
            ``dataset_columns``. Floats with an integral value (e.g. ``2.0``)
            are treated as the corresponding integer index.
        dataset_columns
            Column names of the loaded dataset, in positional order.
        columns
            Explorer column descriptors; extended in place when the resolved
            column is not listed yet.

        Returns
        -------
        The resolved column name, or ``None`` when ``selector`` is ``None``.

        Raises
        ------
        TypeError
            If ``selector`` is a float without an integral value.
        IndexError
            If ``selector`` is an index outside ``dataset_columns``.
        """
        if selector is None:
            return None

        if isinstance(selector, float):
            # A float cannot index a list; accept it only when it denotes a
            # whole number and convert it, otherwise fail explicitly.
            if not selector.is_integer():
                raise TypeError(
                    f"Column index must be an integer, got {selector!r}."
                )
            selector = int(selector)

        if isinstance(selector, int):
            try:
                column_name = dataset_columns[selector]
            except IndexError as err:
                raise IndexError(
                    f"Column index {selector} is out of range for a dataset "
                    f"with {len(dataset_columns)} columns."
                ) from err
            entry = {"id": selector, "columnName": column_name}
        else:
            column_name = selector
            entry = {"columnName": column_name}

        # Check against the *current* list so a column added for one selector
        # is not appended again when another selector resolves to it.
        if all(col["columnName"] != column_name for col in columns):
            columns.append(entry)
        return column_name

    def prepare_dataset(
        self, loaded_dataset: DashAIDataset, columns: List[Dict[str, Any]]
    ) -> DashAIDataset:
        """Add the color/symbol/size columns to the selection and delegate.

        Each configured selector is resolved to a column name (indices are
        looked up in ``loaded_dataset.column_names``) and stored back on the
        instance. Columns not already present in ``columns`` are appended to
        that list in place before it is handed to the base implementation.

        Parameters
        ----------
        loaded_dataset
            Dataset whose ``column_names`` are used to resolve indices.
        columns
            Explorer column descriptors, each with a ``"columnName"`` key.
            This list is mutated when extra columns are required.

        Returns
        -------
        Whatever ``BaseExplorer.prepare_dataset`` returns for the extended
        column list.
        """
        dataset_columns = loaded_dataset.column_names

        self.color_column = self._resolve_extra_column(
            self.color_column, dataset_columns, columns
        )
        self.simbol_column = self._resolve_extra_column(
            self.simbol_column, dataset_columns, columns
        )
        self.size_column = self._resolve_extra_column(
            self.size_column, dataset_columns, columns
        )

        return super().prepare_dataset(loaded_dataset, columns)

    def launch_exploration(self, dataset: DashAIDataset, explorer_info: Explorer):
        """Build the scatter plot figure.

        Parameters
        ----------
        dataset
            Prepared dataset; converted to a pandas DataFrame for plotting.
        explorer_info
            Provides ``columns`` (the first two entries are the x and y axes)
            and ``name`` (used as the title when not ``None`` or empty).

        Returns
        -------
        The figure produced by ``plotly.express.scatter``.
        """
        frame = dataset.to_pandas()
        axis_columns = [col["columnName"] for col in explorer_info.columns]
        x_column = axis_columns[0]
        y_column = axis_columns[1]

        def _if_present(column):
            # Only forward a styling column that really exists in the frame;
            # unset (None) selectors fall through to None as well.
            return column if column in frame.columns else None

        fig = px.scatter(
            frame,
            x=x_column,
            y=y_column,
            color=_if_present(self.color_column),
            symbol=_if_present(self.simbol_column),
            size=_if_present(self.size_column),
            title=f"Scatter Plot of {x_column} vs {y_column}",
        )

        # A user-provided exploration name overrides the generated title.
        if explorer_info.name is not None and explorer_info.name != "":
            fig.update_layout(title=f"{explorer_info.name}")

        return fig

    def save_exploration(
        self,
        __exploration_info__: Exploration,
        explorer_info: Explorer,
        save_path: pathlib.Path,
        result: Figure,
    ) -> str:
        """Write the figure as JSON under ``save_path`` and return its path.

        Parameters
        ----------
        __exploration_info__
            Unused here.
        explorer_info
            Its ``id`` is used as the file stem.
        save_path
            Directory in which the file is written.
        result
            Figure to serialise with ``write_json``.

        Returns
        -------
        The POSIX-style path of the written file.
        """
        # NOTE(review): the content is JSON (written with ``write_json``) even
        # though the suffix says ".pickle"; the name is kept unchanged so the
        # on-disk naming stays the same as before.
        filename = f"{explorer_info.id}.pickle"
        target = (pathlib.Path(save_path) / filename).as_posix()

        result.write_json(target)
        return target

    def get_results(
        self, exploration_path: str, options: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Load a stored figure and wrap it in the result payload.

        Parameters
        ----------
        exploration_path
            Path of the JSON file to read with ``plotly.io.read_json``.
        options
            Unused here.

        Returns
        -------
        A dict with the figure re-serialised as a JSON string under
        ``"data"``, ``"plotly_json"`` under ``"type"`` and an empty dict under
        ``"config"``.
        """
        figure = read_json(exploration_path)
        return {"data": figure.to_json(), "type": "plotly_json", "config": {}}