Source code for DashAI.back.models.hugging_face.distilbert_transformer
"""DashAI implementation of DistilBERT model for english classification."""
import shutil
from pathlib import Path
from typing import Any, Union
import torch
from datasets import Dataset
from sklearn.exceptions import NotFittedError
from torch.utils.data import DataLoader
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
)
from DashAI.back.core.schema_fields import (
BaseSchema,
enum_field,
float_field,
int_field,
schema_field,
)
from DashAI.back.models.text_classification_model import TextClassificationModel
class DistilBertTransformerSchema(BaseSchema):
"""Distilbert is a transformer that allows you to classify text in English.
The implementation is based on huggingface distilbert-base in the case of
the uncased model, i.e. distilbert-base-uncased.
"""
num_train_epochs: schema_field(
int_field(ge=1),
placeholder=2,
description="Total number of training epochs to perform.",
) # type: ignore
batch_size: schema_field(
int_field(ge=1),
placeholder=16,
description="The batch size per GPU/TPU core/CPU for training",
) # type: ignore
learning_rate: schema_field(
float_field(ge=0.0),
placeholder=5e-5,
description="The initial learning rate for AdamW optimizer",
) # type: ignore
device: schema_field(
enum_field(enum=["gpu", "cpu"]),
placeholder="gpu",
description="Hardware on which the training is run. If available, GPU is "
"recommended for efficiency reasons. Otherwise, use CPU.",
) # type: ignore
weight_decay: schema_field(
float_field(ge=0.0),
placeholder=0.01,
description="Weight decay is a regularization technique used in training "
"neural networks to prevent overfitting. In the context of the AdamW "
"optimizer, the 'weight_decay' parameter is the rate at which the weights of "
"all layers are reduced during training, provided that this rate is not zero.",
) # type: ignore
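# A minimal usage sketch (hedged): the schema fields above surface as keyword
# arguments of DistilBertTransformer below, for example
#
#     model = DistilBertTransformer(
#         num_train_epochs=2,
#         batch_size=16,
#         learning_rate=5e-5,
#         device="cpu",
#         weight_decay=0.01,
#     )
#
# Validation of these values is handled by validate_and_transform in __init__.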
class DistilBertTransformer(TextClassificationModel):
"""Pre-trained transformer DistilBERT allowing English text classification.
DistilBERT is a small, fast, cheap and light Transformer model trained by
distilling BERT base.
It has 40% less parameters than bert-base-uncased, runs 60% faster while preserving
over 95% of BERT's performances as measured on the GLUE language understanding
benchmark [1].
References
----------
[1] https://huggingface.co/docs/transformers/model_doc/distilbert
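
    Examples
    --------
    A minimal end-to-end sketch (hedged; assumes in-memory ``datasets.Dataset``
    objects with a ``text`` column for the inputs and a single label column
    for the outputs):

    >>> from datasets import Dataset
    >>> x = Dataset.from_dict({"text": ["great film", "terrible film"]})
    >>> y = Dataset.from_dict({"label": [1, 0]})
    >>> clf = DistilBertTransformer(num_train_epochs=1, device="cpu")
    >>> clf.fit(x, y)  # doctest: +SKIP
    >>> probabilities = clf.predict(x)  # doctest: +SKIP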
"""
SCHEMA = DistilBertTransformerSchema
    def __init__(self, model=None, **kwargs):
"""Initialize the transformer model.
        The process includes instantiating the pre-trained model and its
        associated tokenizer.
"""
self.num_labels = kwargs.get("num_labels")
kwargs = self.validate_and_transform(kwargs)
self.model_name = "distilbert-base-uncased"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.training_args = {
"num_train_epochs": kwargs.get("num_train_epochs", 2),
"learning_rate": kwargs.get("learning_rate", 5e-5),
"weight_decay": kwargs.get("weight_decay", 0.01),
}
self.batch_size = kwargs.get("batch_size", 16)
self.device = kwargs.get("device", "gpu")
self.model = (
model
if model is not None
else AutoModelForSequenceClassification.from_pretrained(self.model_name)
)
self.fitted = False
def tokenize_data(self, dataset: Dataset) -> Dataset:
"""Tokenize the input data.
Parameters
----------
dataset : Dataset
Dataset with the input data to preprocess.
Returns
-------
Dataset
Dataset with the tokenized input data.
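
        Examples
        --------
        A hedged sketch of the expected effect, where ``model`` is a
        DistilBertTransformer instance: the Hugging Face tokenizer adds
        ``input_ids`` and ``attention_mask`` columns next to ``text``.

        >>> from datasets import Dataset
        >>> ds = Dataset.from_dict({"text": ["hello world"]})
        >>> tokenized = model.tokenize_data(ds)  # doctest: +SKIP
        >>> sorted(tokenized.column_names)  # doctest: +SKIP
        ['attention_mask', 'input_ids', 'text']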
"""
return dataset.map(
lambda examples: self.tokenizer(
examples["text"], truncation=True, padding=True, max_length=512
),
batched=True,
)
def fit(self, x_train: Dataset, y_train: Dataset):
"""Fine-tune the pre-trained model.
Parameters
----------
x_train : Dataset
Dataset with input training data.
y_train : Dataset
Dataset with output training data.
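
        Returns
        -------
        DistilBertTransformer
            The fitted model (``self``).

        Examples
        --------
        A hedged sketch of the expected input layout; the first column of
        ``y_train`` is used as the label column, whatever its name:

        >>> from datasets import Dataset
        >>> x_train = Dataset.from_dict({"text": ["good", "bad", "fine"]})
        >>> y_train = Dataset.from_dict({"sentiment": [1, 0, 1]})
        >>> model.fit(x_train, y_train)  # doctest: +SKIP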
"""
        # Infer the number of classes from the labels when it was not provided.
        output_column_name = y_train.column_names[0]
        if self.num_labels is None:
            self.num_labels = len(set(y_train[output_column_name]))
        self.model.config.num_labels = self.num_labels
train_dataset = self.tokenize_data(x_train)
train_dataset = train_dataset.add_column("label", y_train[output_column_name])
can_use_fp16 = torch.cuda.is_available() and self.device == "gpu"
training_args = TrainingArguments(
output_dir="DashAI/back/user_models/temp_checkpoints_distilbert",
logging_strategy="steps",
logging_steps=50,
save_strategy="epoch", # Guarda checkpoints al final de cada época
per_device_train_batch_size=self.batch_size,
no_cuda=self.device != "gpu",
fp16=can_use_fp16,
**self.training_args,
)
data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
tokenizer=self.tokenizer,
)
trainer.train()
self.fitted = True
shutil.rmtree(
"DashAI/back/user_models/temp_checkpoints_distilbert", ignore_errors=True
)
return self
def predict(self, x_pred: Dataset):
"""Predict with the fine-tuned model.
Parameters
----------
x_pred : Dataset
Dataset with text data.
Returns
-------
        List
            List of per-class probability arrays, one per input example.
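
        Examples
        --------
        A hedged sketch, where ``x_test`` is a Dataset with a ``text`` column;
        each returned entry is an array of per-class probabilities:

        >>> probabilities = model.predict(x_test)  # doctest: +SKIP
        >>> predictions = [p.argmax() for p in probabilities]  # doctest: +SKIP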
"""
if not self.fitted:
raise NotFittedError(
f"This {self.__class__.__name__} instance is not fitted yet. Call 'fit'"
" with appropriate arguments before using this estimator."
)
pred_dataset = self.tokenize_data(x_pred)
data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
pred_loader = DataLoader(
pred_dataset.remove_columns(["text"]),
batch_size=self.batch_size,
collate_fn=data_collator,
)
        self.model.eval()  # make sure dropout is disabled after fine-tuning
        probabilities = []
        with torch.no_grad():  # gradients are not needed for inference
            for batch in pred_loader:
                inputs = {
                    k: v.to(self.model.device)
                    for k, v in batch.items()
                    if k != "labels"
                }
                outputs = self.model(**inputs)
                probs = outputs.logits.softmax(dim=-1)
                probabilities.extend(probs.cpu().numpy())
return probabilities
    def save(self, filename: Union[str, Path]) -> None:
        """Save the fine-tuned model and its custom hyperparameters to disk."""
        self.model.save_pretrained(filename)
config = AutoConfig.from_pretrained(filename)
config.custom_params = {
"num_train_epochs": self.training_args.get("num_train_epochs"),
"batch_size": self.batch_size,
"learning_rate": self.training_args.get("learning_rate"),
"device": self.device,
"weight_decay": self.training_args.get("weight_decay"),
"num_labels": self.num_labels,
"fitted": self.fitted,
}
config.save_pretrained(filename)
@classmethod
    def load(cls, filename: Union[str, Path]) -> Any:
        """Load a model stored with ``save`` and restore its custom parameters."""
        config = AutoConfig.from_pretrained(filename)
custom_params = getattr(config, "custom_params", {})
model = AutoModelForSequenceClassification.from_pretrained(
filename, num_labels=custom_params.get("num_labels")
)
loaded_model = cls(
model=model,
model_name=config.model_type,
num_labels=custom_params.get("num_labels"),
num_train_epochs=custom_params.get("num_train_epochs"),
batch_size=custom_params.get("batch_size"),
learning_rate=custom_params.get("learning_rate"),
device=custom_params.get("device"),
weight_decay=custom_params.get("weight_decay"),
)
loaded_model.fitted = custom_params.get("fitted")
return loaded_model
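
# A hedged sketch of the save/load round trip defined above: `save` writes the
# Hugging Face weights plus the custom hyperparameters (stored under
# `config.custom_params`), and `load` rebuilds the wrapper from that directory.
# The path below is purely illustrative.
#
#     model.save("user_models/my_distilbert")
#     restored = DistilBertTransformer.load("user_models/my_distilbert")
#     probabilities = restored.predict(x_test)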