"""OpusMtEnESTransformer model for english-spanish translation DashAI implementation."""
import shutil
from pathlib import Path
from typing import List, Optional, Union
from datasets import Dataset
from sklearn.exceptions import NotFittedError
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
)
from DashAI.back.core.schema_fields import (
BaseSchema,
enum_field,
float_field,
int_field,
schema_field,
)
from DashAI.back.models.translation_model import TranslationModel
class OpusMtEnESTransformerSchema(BaseSchema):
    """opus-mt-en-es is a pre-trained transformer model that translates text
    from English to Spanish.
    """
num_train_epochs: schema_field(
int_field(ge=1),
placeholder=1,
description="Total number of training epochs to perform.",
) # type: ignore
batch_size: schema_field(
int_field(ge=1),
placeholder=4,
description="The batch size per GPU/TPU core/CPU for training",
) # type: ignore
learning_rate: schema_field(
float_field(ge=0.0),
placeholder=2e-5,
description="The initial learning rate for AdamW optimizer",
) # type: ignore
device: schema_field(
enum_field(enum=["gpu", "cpu"]),
placeholder="gpu",
description="Hardware on which the training is run. If available, GPU is "
"recommended for efficiency reasons. Otherwise, use CPU.",
) # type: ignore
weight_decay: schema_field(
float_field(ge=0.0),
placeholder=0.01,
description="Weight decay is a regularization technique used in training "
"neural networks to prevent overfitting. In the context of the AdamW "
"optimizer, the 'weight_decay' parameter is the rate at which the weights of "
"all layers are reduced during training, provided that this rate is not zero.",
) # type: ignore
class OpusMtEnESTransformer(TranslationModel):
    """Pre-trained transformer for English-Spanish translation.

    This model fine-tunes the pre-trained opus-mt-en-es model.
    """
SCHEMA = OpusMtEnESTransformerSchema
    def __init__(self, model=None, **kwargs):
        """Initialize the transformer.

        This process includes the instantiation of the pre-trained model and the
        associated tokenizer.
        """
        kwargs = self.validate_and_transform(kwargs)
        self.model_name = "Helsinki-NLP/opus-mt-en-es"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # batch_size and device are popped because Seq2SeqTrainingArguments does
        # not accept them; the remaining kwargs are forwarded to it in fit().
        self.batch_size = kwargs.pop("batch_size", 16)
        self.device = kwargs.pop("device", "gpu")
        self.training_args = kwargs
        self.model = (
            model
            if model is not None
            else AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        )
        self.fitted = model is not None
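    # A minimal construction sketch: the kwargs mirror the schema fields above;
    # "cpu" is used here only to avoid assuming a GPU is available.
    #
    #   model = OpusMtEnESTransformer(
    #       num_train_epochs=1,
    #       batch_size=4,
    #       learning_rate=2e-5,
    #       device="cpu",
    #       weight_decay=0.01,
    #   )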
def tokenize_data(self, x: Dataset, y: Optional[Dataset] = None) -> Dataset:
"""Tokenize input and output.
Parameters
----------
x: Dataset
Dataset with the input data to preprocess.
y: Optional Dataset
Dataset with the output data to preprocess.
Returns
-------
Dataset
Dataset with the processed data.
"""
        # When no targets are given (prediction), build a dummy single-column
        # dataset so the zip below still yields one pair per input sample.
        is_y = bool(y)
        if not y:
            y = Dataset.from_list([{"foo": 0}] * len(x))
dataset = []
input_column_name = x.column_names[0]
output_column_name = y.column_names[0]
for input_sample, output_sample in zip(x, y):
tokenized_input = self.tokenizer(
input_sample[input_column_name],
truncation=True,
padding="max_length",
max_length=512,
)
tokenized_output = (
self.tokenizer(
output_sample[output_column_name],
truncation=True,
padding="max_length",
max_length=512,
)
if is_y
else None
)
            sample = {
                "input_ids": tokenized_input["input_ids"],
                "attention_mask": tokenized_input["attention_mask"],
                # At prediction time the per-sample dummy value is stored here and
                # later ignored, since predict() only selects input_ids and
                # attention_mask.
                "labels": (
                    tokenized_output["input_ids"]
                    if is_y
                    else output_sample[output_column_name]
                ),
            }
dataset.append(sample)
return Dataset.from_list(dataset)
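    # Sketch of the expected shapes ("text" and "translation" are illustrative
    # column names; any single-column datasets work):
    #
    #   x = Dataset.from_dict({"text": ["good morning"]})
    #   y = Dataset.from_dict({"translation": ["buenos días"]})
    #   tokenized = model.tokenize_data(x, y)
    #   tokenized.column_names  # ["input_ids", "attention_mask", "labels"]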
def fit(self, x_train: Dataset, y_train: Dataset):
"""Fine-tune the pre-trained model.
Parameters
----------
x_train : Dataset
Dataset with input training data.
y_train : Dataset
Dataset with output training data.
"""
dataset = self.tokenize_data(x_train, y_train)
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
        # Training arguments: batch_size and device come from the constructor;
        # the remaining schema kwargs (epochs, learning rate, weight decay) are
        # forwarded through self.training_args. Checkpoints are written to a
        # temporary directory that is removed after training.
        training_args = Seq2SeqTrainingArguments(
            output_dir="DashAI/back/user_models/temp_checkpoints_opus-mt-en-es",
            save_steps=1,
            save_total_limit=1,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            no_cuda=self.device != "gpu",
            **self.training_args,
        )
trainer = Seq2SeqTrainer(
model=self.model,
args=training_args,
train_dataset=dataset,
)
trainer.train()
self.fitted = True
shutil.rmtree(
"DashAI/back/user_models/temp_checkpoints_opus-mt-en-es", ignore_errors=True
)
return self
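    # Fine-tuning sketch (illustrative column names, tiny dataset):
    #
    #   x_train = Dataset.from_dict({"text": ["thank you", "good night"]})
    #   y_train = Dataset.from_dict({"translation": ["gracias", "buenas noches"]})
    #   model.fit(x_train, y_train)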
def predict(self, x_pred: Dataset) -> List:
"""Predict with the fine-tuned model.
Parameters
----------
x_pred : Dataset
Dataset with text data.
Returns
-------
List
list of translations made by the model.
"""
        if not self.fitted:
            raise NotFittedError(
                f"This {self.__class__.__name__} instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator."
            )
dataset = self.tokenize_data(x_pred)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
        translations = []
        for example in dataset:
            # Translate one sample at a time: unsqueeze adds the batch dimension
            # and the tensors are moved to the device the model lives on.
            inputs = {
                k: v.unsqueeze(0).to(self.model.device) for k, v in example.items()
            }
outputs = self.model.generate(**inputs)
translated_text = self.tokenizer.decode(
outputs[0], skip_special_tokens=True
)
translations.append(translated_text)
return translations
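    # Prediction sketch (assumes a fitted or loaded instance):
    #
    #   x_pred = Dataset.from_dict({"text": ["how are you?"]})
    #   translations = model.predict(x_pred)  # one Spanish string per input row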
    def save(self, filename: Union[str, Path]) -> None:
        """Save the fine-tuned model and its DashAI training parameters."""
        self.model.save_pretrained(filename)
        # Store the DashAI-specific parameters in the Hugging Face config so that
        # load() can restore them.
        config = AutoConfig.from_pretrained(filename)
config.custom_params = {
"num_train_epochs": self.training_args.get("num_train_epochs"),
"batch_size": self.batch_size,
"learning_rate": self.training_args.get("learning_rate"),
"device": self.device,
"weight_decay": self.training_args.get("weight_decay"),
"fitted": self.fitted,
}
config.save_pretrained(filename)
    @classmethod
    def load(cls, filename: Union[str, Path]):
        """Load a model saved with `save`, restoring its stored parameters."""
        model = AutoModelForSeq2SeqLM.from_pretrained(filename)
        config = AutoConfig.from_pretrained(filename)
        custom_params = getattr(config, "custom_params", {})
loaded_model = cls(
model=model,
num_train_epochs=custom_params.get("num_train_epochs"),
batch_size=custom_params.get("batch_size"),
learning_rate=custom_params.get("learning_rate"),
device=custom_params.get("device"),
weight_decay=custom_params.get("weight_decay"),
)
loaded_model.fitted = custom_params.get("fitted", False)
return loaded_model
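# Persistence sketch ("opus_mt_en_es_finetuned" is an illustrative path):
#
#   model.save("opus_mt_en_es_finetuned")
#   restored = OpusMtEnESTransformer.load("opus_mt_en_es_finetuned")
#   restored.fitted  # True if the saved instance had been fitted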