"""OpusMtEnESTransformer model for english-spanish translation DashAI implementation."""
import shutil
from pathlib import Path
from typing import List, Optional, Union
import torch
from datasets import Dataset
from sklearn.exceptions import NotFittedError
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

from DashAI.back.core.schema_fields import (
    BaseSchema,
    enum_field,
    float_field,
    int_field,
    schema_field,
)
from DashAI.back.models.translation_model import TranslationModel

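# Build the list of selectable devices once, at import time; the available
# "cuda:i" options therefore reflect the GPUs visible when this module is
# first imported.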
if torch.cuda.is_available():
    DEVICE_ENUM = [f"cuda:{i}" for i in range(torch.cuda.device_count())] + ["cpu"]
    DEVICE_PLACEHOLDER = "cuda:0"
else:
    DEVICE_ENUM = ["cpu"]
    DEVICE_PLACEHOLDER = "cpu"


class OpusMtEnESTransformerSchema(BaseSchema):
    """opus-mt-en-es is a transformer pre-trained model that allows translation of
    texts from English to Spanish.
    """
    num_train_epochs: schema_field(
        int_field(ge=1),
        placeholder=1,
        description="Total number of training epochs to perform.",
    )  # type: ignore
    batch_size: schema_field(
        int_field(ge=1),
        placeholder=4,
        description="The batch size per GPU/TPU core/CPU for training",
    )  # type: ignore
    learning_rate: schema_field(
        float_field(ge=0.0),
        placeholder=2e-5,
        description="The initial learning rate for AdamW optimizer",
    )  # type: ignore
    device: schema_field(
        enum_field(enum=DEVICE_ENUM),
        placeholder=DEVICE_PLACEHOLDER,
        description="Hardware on which the training is run. If available, GPU is "
        "recommended for efficiency reasons. Otherwise, use CPU.",
    )  # type: ignore
    weight_decay: schema_field(
        float_field(ge=0.0),
        placeholder=0.01,
        description="Weight decay is a regularization technique used in training "
        "neural networks to prevent overfitting. In the context of the AdamW "
        "optimizer, the 'weight_decay' parameter is the rate at which the weights of "
        "all layers are reduced during training, provided that this rate is not zero.",
    )  # type: ignore


class OpusMtEnESTransformer(TranslationModel):
    """Pre-trained transformer for english-spanish translation.
    This model fine-tunes the pre-trained model opus-mt-en-es.
    """
    SCHEMA = OpusMtEnESTransformerSchema
    DISPLAY_NAME: str = "Opus MT En-Es Transformer"
    COLOR: str = "#4CAF50"
    def __init__(self, model=None, **kwargs):
        """Initialize the transformer.

        This process includes the instantiation of the pre-trained model and
        the associated tokenizer.
        """
        kwargs = self.validate_and_transform(kwargs)
        self.model_name = "Helsinki-NLP/opus-mt-en-es"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Split the validated kwargs: batch_size and device are handled by this
        # class, while the remaining keys are forwarded verbatim to
        # Seq2SeqTrainingArguments in fit(). These attributes are set even when
        # a pre-loaded model is passed in, so that save() works after load().
        self.batch_size = kwargs.pop("batch_size", 16)
        self.device = kwargs.pop("device", DEVICE_PLACEHOLDER)
        self.training_args = kwargs
        self.model = (
            model
            if model is not None
            else AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        )
        self.fitted = model is not None

    def tokenize_data(self, x: Dataset, y: Optional[Dataset] = None) -> Dataset:
        """Tokenize input and output.

        Parameters
        ----------
        x: Dataset
            Dataset with the input data to preprocess.
        y: Optional Dataset
            Dataset with the output data to preprocess.

        Returns
        -------
        Dataset
            Dataset with the processed data.
        """
        is_y = y is not None
        if y is None:
            # Placeholder targets so that the zip below yields pairs even at
            # prediction time, when no output dataset is available.
            y = Dataset.from_list([{"foo": 0}] * len(x))
        dataset = []
        input_column_name = x.column_names[0]
        output_column_name = y.column_names[0]
        for input_sample, output_sample in zip(x, y):
            tokenized_input = self.tokenizer(
                input_sample[input_column_name],
                truncation=True,
                padding="max_length",
                max_length=512,
            )
            tokenized_output = (
                self.tokenizer(
                    output_sample[output_column_name],
                    truncation=True,
                    padding="max_length",
                    max_length=512,
                )
                if is_y
                else None
            )
            sample = {
                "input_ids": tokenized_input["input_ids"],
                "attention_mask": tokenized_input["attention_mask"],
                "labels": (
                    tokenized_output["input_ids"] if is_y else y[output_column_name]
                ),
            }
            dataset.append(sample)
        return Dataset.from_list(dataset)

    def fit(self, x_train: Dataset, y_train: Dataset):
        """Fine-tune the pre-trained model.

        Parameters
        ----------
        x_train : Dataset
            Dataset with input training data.
        y_train : Dataset
            Dataset with output training data.
        """
        dataset = self.tokenize_data(x_train, y_train)
        dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
        checkpoint_dir = "DashAI/back/user_models/temp_checkpoints_opus-mt-en-es"
        training_args = Seq2SeqTrainingArguments(
            output_dir=checkpoint_dir,
            save_steps=1,
            save_total_limit=1,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            # self.device is "cuda:<i>" or "cpu" (see DEVICE_ENUM); train on
            # CPU only when it was explicitly selected.
            use_cpu=self.device == "cpu",
            **self.training_args,
        )
        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
        )
        trainer.train()
        self.fitted = True
        # Remove the temporary checkpoints once training has finished.
        shutil.rmtree(checkpoint_dir, ignore_errors=True)
        return self

    def predict(self, x_pred: Dataset) -> List:
        """Predict with the fine-tuned model.

        Parameters
        ----------
        x_pred : Dataset
            Dataset with text data.

        Returns
        -------
        List
            List of translations made by the model.
        """
        if not self.fitted:
            raise NotFittedError(
                f"This {self.__class__.__name__} instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this "
                "estimator."
            )
        dataset = self.tokenize_data(x_pred)
        dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
        translations = []
        for example in dataset:
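            # Add a batch dimension and move each tensor to the model's device.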
            inputs = {
                k: v.unsqueeze(0).to(self.model.device) for k, v in example.items()
            }
            outputs = self.model.generate(**inputs)
            translated_text = self.tokenizer.decode(
                outputs[0], skip_special_tokens=True
            )
            translations.append(translated_text)
        return translations

    def save(self, filename: Union[str, Path]) -> None:
        """Save the fine-tuned model and its training parameters to filename."""
        self.model.save_pretrained(filename)
        # Persist the DashAI-specific parameters inside the Hugging Face config
        # so that load() can rebuild an equivalent instance.
        config = AutoConfig.from_pretrained(filename)
        config.custom_params = {
            "num_train_epochs": self.training_args.get("num_train_epochs"),
            "batch_size": self.batch_size,
            "learning_rate": self.training_args.get("learning_rate"),
            "device": self.device,
            "weight_decay": self.training_args.get("weight_decay"),
            "fitted": self.fitted,
        }
        config.save_pretrained(filename)

    @classmethod
    def load(cls, filename: Union[str, Path]):
        """Load a model saved with save() and restore its training parameters."""
        model = AutoModelForSeq2SeqLM.from_pretrained(filename)
        config = AutoConfig.from_pretrained(filename)
        custom_params = getattr(config, "custom_params", {})
        loaded_model = cls(
            model=model,
            num_train_epochs=custom_params.get("num_train_epochs"),
            batch_size=custom_params.get("batch_size"),
            learning_rate=custom_params.get("learning_rate"),
            device=custom_params.get("device"),
            weight_decay=custom_params.get("weight_decay"),
        )
        loaded_model.fitted = custom_params.get("fitted", False)
        return loaded_model
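

# A minimal usage sketch, assuming network access to download the pre-trained
# weights. The dataset column names ("text", "translation") and the sentences
# below are illustrative assumptions, not values required by the class.
if __name__ == "__main__":
    x = Dataset.from_list([{"text": "Hello, world!"}, {"text": "Good morning."}])
    y = Dataset.from_list(
        [{"translation": "¡Hola, mundo!"}, {"translation": "Buenos días."}]
    )

    transformer = OpusMtEnESTransformer(
        num_train_epochs=1,
        batch_size=2,
        learning_rate=2e-5,
        device=DEVICE_PLACEHOLDER,
        weight_decay=0.01,
    )
    transformer.fit(x, y)
    print(transformer.predict(Dataset.from_list([{"text": "How are you?"}])))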