import os
import pathlib
import plotly.express as px
from beartype.typing import Any, Dict
from plotly.graph_objs import Figure
from plotly.io import read_json
from DashAI.back.core.schema_fields import bool_field, enum_field, schema_field
from DashAI.back.dataloaders.classes.dashai_dataset import (  # ClassLabel, Value,
    DashAIDataset,
)
from DashAI.back.dependencies.database.models import Explorer, Notebook
from DashAI.back.exploration.base_explorer import BaseExplorerSchema
from DashAI.back.exploration.distribution_explorer import DistributionExplorer
class BoxPlotSchema(BaseExplorerSchema):
    horizontal: schema_field(
        bool_field(),
        False,
        ("If True, the box plot will be horizontal, otherwise vertical."),
    )  # type: ignore
    points: schema_field(
        enum_field(["all", "outliers", "False"]),
        "outliers",
        ("One of 'all', 'outliers', or 'False'. Determines which points are shown."),
    )  # type: ignore
[docs]
class BoxPlotExplorer(DistributionExplorer):
    """
    BoxPlotExplorer is an explorer that returns a box plot
    of selected columns of a dataset.
    """
    DISPLAY_NAME = "Box Plot"
    DESCRIPTION = (
        "BoxPlotExplorer is an explorer that returns a box plot "
        "of selected columns of a dataset."
    )
    IMAGE_PREVIEW = "box_plot.png"
    SCHEMA = BoxPlotSchema
    metadata: Dict[str, Any] = {
        "allowed_dtypes": ["*"],
        "restricted_dtypes": [],
        "input_cardinality": {"min": 1, "max": 2},
    }
[docs]
    def __init__(self, **kwargs) -> None:
        self.horizontal = kwargs.get("horizontal", False)
        if kwargs.get("points") == "False":
            kwargs["points"] = False
        self.points = kwargs.get("points", "outliers")
        super().__init__(**kwargs) 
    def launch_exploration(self, dataset: DashAIDataset, explorer_info: Explorer):
        _df = dataset.to_pandas()
        cols = [col["columnName"] for col in explorer_info.columns]
        if len(cols) == 1:
            fig = px.box(
                _df,
                x=cols[0] if self.horizontal else None,
                y=None if self.horizontal else cols[0],
                title=f"Boxplot of {cols[0]}",
                points=self.points,
            )
        elif len(cols) == 2:
            fig = px.box(
                _df,
                x=cols[1] if self.horizontal else cols[0],
                y=cols[0] if self.horizontal else cols[1],
                title=f"Boxplot of {cols[0]} vs {cols[1]}",
                points=self.points,
            )
        else:
            raise ValueError("BoxPlotExplorer can only handle 1 or 2 columns")
        if explorer_info.name is not None and explorer_info.name != "":
            fig.update_layout(title=f"{explorer_info.name}")
        return fig
    def save_notebook(
        self,
        __notebook_info__: Notebook,
        explorer_info: Explorer,
        save_path: pathlib.Path,
        result: Figure,
    ) -> str:
        filename = f"{explorer_info.id}.json"
        path = pathlib.Path(os.path.join(save_path, filename))
        result.write_json(path.as_posix())
        return path.as_posix()
    def get_results(
        self, exploration_path: str, options: Dict[str, Any]
    ) -> Dict[str, Any]:
        resultType = "plotly_json"
        config = {}
        result = read_json(exploration_path)
        result = result.to_json()
        return {"data": result, "type": resultType, "config": config}