Source code for DashAI.back.models.hugging_face.vit_transformer

"""DashAI implementation of DistilBERT model for image classification."""

import shutil
from typing import Optional

import numpy as np
from datasets import Dataset
from sklearn.exceptions import NotFittedError
from transformers import (
    Trainer,
    TrainingArguments,
    ViTFeatureExtractor,
    ViTForImageClassification,
)

from DashAI.back.core.schema_fields import (
    BaseSchema,
    enum_field,
    float_field,
    int_field,
    schema_field,
)
from DashAI.back.models.image_classification_model import ImageClassificationModel


class ViTTransformerSchema(BaseSchema):
    """ViT is a transformer that allows you to classify text in English."""

    num_train_epochs: schema_field(
        int_field(ge=1),
        placeholder=3,
        description="Total number of training epochs to perform.",
    )  # type: ignore
    batch_size: schema_field(
        int_field(ge=1),
        placeholder=8,
        description="The batch size per GPU/TPU core/CPU for training",
    )  # type: ignore
    learning_rate: schema_field(
        float_field(ge=0.0),
        placeholder=5e-5,
        description="The initial learning rate for AdamW optimizer",
    )  # type: ignore
    device: schema_field(
        enum_field(enum=["gpu", "cpu"]),
        placeholder="gpu",
        description="Hardware on which the training is run. If available, GPU is "
        "recommended for efficiency reasons. Otherwise, use CPU.",
    )  # type: ignore
    weight_decay: schema_field(
        float_field(ge=0.0),
        placeholder=0.0,
        description="Weight decay is a regularization technique used in training "
        "neural networks to prevent overfitting. In the context of the AdamW "
        "optimizer, the 'weight_decay' parameter is the rate at which the weights of "
        "all layers are reduced during training, provided that this rate is not zero.",
    )  # type: ignore
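
# A brief sketch (not part of the original module) of keyword arguments that
# should satisfy ViTTransformerSchema; the field names come from the schema
# above and the values mirror its placeholders. `ViTTransformer.__init__`
# passes such kwargs through `validate_and_transform` before use.
#
#   params = {
#       "num_train_epochs": 3,
#       "batch_size": 8,
#       "learning_rate": 5e-5,
#       "device": "gpu",
#       "weight_decay": 0.0,
#   }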


class ViTTransformer(ImageClassificationModel):
    """Pre-trained Vision Transformer (ViT) for image classification.

    Vision Transformer (ViT) is a transformer that is targeted at vision
    processing tasks such as image recognition [1].

    References
    ----------
    [1] https://en.wikipedia.org/wiki/Vision_transformer
    [2] https://huggingface.co/docs/transformers/model_doc/vit
    """

    SCHEMA = ViTTransformerSchema

    def __init__(self, model=None, **kwargs):
        """Initialize the transformer.

        This process includes the instantiation of the pre-trained model and
        the associated feature extractor.
        """
        kwargs = self.validate_and_transform(kwargs)
        self.model_name = "google/vit-base-patch16-224"
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(self.model_name)
        self.model = (
            model
            if model is not None
            else ViTForImageClassification.from_pretrained(self.model_name)
        )
        self.fitted = model is not None
        if model is None:
            self.training_args = kwargs
            # `training_args` aliases `kwargs`, so popping these training-only
            # options here also removes them from `self.training_args`.
            self.batch_size = kwargs.pop("batch_size", 8)
            self.device = kwargs.pop("device", "gpu")

    def preprocess_images(self, x: Dataset, y: Optional[Dataset] = None):
        """Preprocess images for model input.

        Parameters
        ----------
        x : Dataset
            Dataset with the input data to preprocess.
        y : Optional[Dataset]
            Dataset with the output data to preprocess.

        Returns
        -------
        Dataset
            Dataset with the processed data.
        """
        # If the output dataset is not given, create a placeholder label
        # dataset so that inputs and labels can be zipped together below.
        if y is None:
            y = Dataset.from_list([{"foo": 0}] * len(x))

        # Initialize useful variables
        dataset = []
        input_column_name = x.column_names[0]
        output_column_name = y.column_names[0]

        # Preprocess both datasets
        for input_sample, output_sample in zip(x, y):  # noqa
            preprocessed_input = self.feature_extractor(
                images=input_sample[input_column_name],
                return_tensors="pt",
                size=224,
            )
            # Drop the leading batch dimension: (1, C, H, W) -> (C, H, W)
            reshaped_image = preprocessed_input["pixel_values"].reshape(
                (
                    preprocessed_input["pixel_values"].shape[1],
                    preprocessed_input["pixel_values"].shape[2],
                    preprocessed_input["pixel_values"].shape[3],
                )
            )
            dataset.append(
                {
                    "pixel_values": reshaped_image,
                    "labels": output_sample[output_column_name],
                }
            )
        return Dataset.from_list(dataset)

    def fit(self, x_train: Dataset, y_train: Dataset):
        """Fine-tune the pre-trained model.

        Parameters
        ----------
        x_train : Dataset
            Dataset with input training data.
        y_train : Dataset
            Dataset with output training data.
        """
        dataset = self.preprocess_images(x_train, y_train)

        # Arguments for fine-tuning
        training_args = TrainingArguments(
            output_dir="DashAI/back/user_models/temp_checkpoints_vit",
            save_steps=1,
            save_total_limit=1,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=self.batch_size,
            no_cuda=self.device != "gpu",
            **self.training_args,
        )

        # The Trainer class is used for fine-tuning the model.
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
        )
        trainer.train()
        self.fitted = True

        # Remove the temporary checkpoints once training has finished.
        shutil.rmtree(
            "DashAI/back/user_models/temp_checkpoints_vit", ignore_errors=True
        )

    def predict(self, x_pred: Dataset) -> np.ndarray:
        """Make a prediction with the fine-tuned model.

        Parameters
        ----------
        x_pred : Dataset
            Dataset with image data.

        Returns
        -------
        np.ndarray
            Numpy array with the probabilities for each class.
        """
        if not self.fitted:
            raise NotFittedError(
                f"This {self.__class__.__name__} instance is not fitted yet. Call"
                " 'fit' with appropriate arguments before using this estimator."
            )

        dataset = self.preprocess_images(x_pred)
        dataset.set_format("torch", columns=["pixel_values", "labels"])

        probabilities = []

        # Iterate over each example in the dataset
        for i in range(len(dataset)):
            # Prepare a single example for the model
            batch = dataset[i]
            # Make sure that the tensors are on the correct device.
            batch = {k: v.to(self.model.device) for k, v in batch.items()}
            # Restore the batch dimension dropped during preprocessing:
            # (C, H, W) -> (1, C, H, W)
            if batch["pixel_values"].dim() == 3:
                batch["pixel_values"] = batch["pixel_values"].unsqueeze(0)

            outputs = self.model(**batch)

            # Take the class probabilities using softmax
            probs = outputs.logits.softmax(dim=-1)
            probabilities.extend(probs.detach().cpu().numpy())

        return np.array(probabilities)

    def save(self, filename=None):
        """Save the fine-tuned model weights to `filename`."""
        self.model.save_pretrained(filename)

    @classmethod
    def load(cls, filename):
        """Load a previously saved model from `filename`."""
        model = ViTForImageClassification.from_pretrained(filename)
        return cls(model=model)
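

# A minimal end-to-end usage sketch (illustrative, not part of the original
# module). It assumes `train_x` and `test_x` are `datasets.Dataset` objects
# whose first column holds the images, and that the first column of `train_y`
# holds integer class labels, which is the layout `preprocess_images` expects.
# The save path is hypothetical.
#
#   model = ViTTransformer(
#       num_train_epochs=1,
#       batch_size=4,
#       learning_rate=5e-5,
#       device="cpu",
#       weight_decay=0.0,
#   )
#   model.fit(train_x, train_y)
#   probs = model.predict(test_x)  # shape: (n_samples, n_classes)
#   model.save("saved_vit")
#   restored = ViTTransformer.load("saved_vit")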