Source code for DashAI.back.tasks.text_classification_task

from typing import List, Union

from datasets import ClassLabel, DatasetDict, Value

from DashAI.back.dataloaders.classes.dashai_dataset import (
    DashAIDataset,
    to_dashai_dataset,
)
from DashAI.back.tasks.classification_task import ClassificationTask


class TextClassificationTask(ClassificationTask):
    """Classification task over text: one ``Value`` input, one ``ClassLabel`` output."""

    # Declares the feature types and the number of columns this task accepts
    # as inputs and produces as outputs.
    metadata: dict = {
        "inputs_types": [Value],
        "outputs_types": [ClassLabel],
        "inputs_cardinality": 1,
        "outputs_cardinality": 1,
    }

    # Human-readable summary of the task. The leading and trailing space are
    # part of the runtime value and are kept as-is.
    DESCRIPTION: str = (
        " Text classification is an essential Natural Language Processing (NLP) "
        "task that involves automatically assigning pre-defined categories or "
        "labels to text documents based on their content. It serves as the "
        "foundation for applications like sentiment analysis, spam filtering, "
        "topic classification, and document categorization. "
    )

    DISPLAY_NAME: str = "Text Classification"

    def prepare_for_task(
        self, datasetdict: Union[DatasetDict, DashAIDataset], outputs_columns: List[str]
    ) -> DashAIDataset:
        """Cast the output columns to the categorical type this task expects.

        The input is first normalised with ``to_dashai_dataset`` and every
        column named in ``outputs_columns`` is then requested as
        ``"Categorical"`` through ``change_columns_type``.

        Parameters
        ----------
        datasetdict : Union[DatasetDict, DashAIDataset]
            Dataset whose output columns must be retyped.
        outputs_columns : List[str]
            Names of the columns that act as classification targets.

        Returns
        -------
        DashAIDataset
            The dataset returned by ``change_columns_type``.

        Notes
        -----
        NOTE(review): the previous docstring stated that a copy of the dataset
        is created; that depends on ``to_dashai_dataset`` and
        ``change_columns_type`` and should be confirmed there.
        """
        # Every target column gets the same "Categorical" type label.
        column_types = {column: "Categorical" for column in outputs_columns}
        converted = to_dashai_dataset(datasetdict)
        return converted.change_columns_type(column_types)