Source code for DashAI.back.converters.hugging_face.embedding

import torch
from datasets import Dataset, concatenate_datasets
from transformers import AutoModel, AutoTokenizer

from DashAI.back.converters.hugging_face_wrapper import HuggingFaceWrapper
from DashAI.back.core.schema_fields import (
    enum_field,
    int_field,
    schema_field,
)
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset


class EmbeddingSchema(BaseSchema):
    model_name: schema_field(
        enum_field(
            [
                # Sentence Transformers Models
                "sentence-transformers/all-MiniLM-L6-v2",
                "sentence-transformers/all-mpnet-base-v2",
                "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                "sentence-transformers/all-distilroberta-v1",
                # BERT Models
                "bert-base-uncased",
                "bert-large-uncased",
                "bert-base-multilingual-cased",
                "distilbert-base-uncased",
                # RoBERTa Models
                "roberta-base",
                "roberta-large",
                "distilroberta-base",
            ]
        ),
        "sentence-transformers/all-MiniLM-L6-v2",
        "Name of the pre-trained model to use",
    )  # type: ignore

    max_length: schema_field(
        int_field(ge=1), 512, "Maximum sequence length for tokenization"
    )  # type: ignore

    batch_size: schema_field(
        int_field(ge=1), 32, "Number of samples to process at once"
    )  # type: ignore

    device: schema_field(
        enum_field(["cuda", "cpu"]),
        "cpu",
        "Device to use for computation",
    )  # type: ignore

    pooling_strategy: schema_field(
        enum_field(["mean", "cls", "max"]),
        "mean",
        "Strategy to pool token embeddings into sentence embedding",
    )  # type: ignore


class Embedding(HuggingFaceWrapper):
    """HuggingFace embedding converter."""

    SCHEMA = EmbeddingSchema
    DESCRIPTION = "Convert text to embeddings using HuggingFace transformer models."

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.pooling_strategy = kwargs.get("pooling_strategy", "mean")
        self.model_name = kwargs.get(
            "model_name", "sentence-transformers/all-MiniLM-L6-v2"
        )
        self.device = kwargs.get("device", "cpu")
        self.max_length = kwargs.get("max_length", 512)
        self.batch_size = kwargs.get("batch_size", 32)
        self.model = None
        self.tokenizer = None

    def _load_model(self):
        """Load the embedding model and tokenizer."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
        self.model.eval()

    def _process_batch(self, batch: DashAIDataset) -> DashAIDataset:
        """Process a batch of text into embeddings."""
        all_column_embeddings = []

        for column in batch.column_names:
            # Get the text data for this column
            texts = [row[column] for row in batch]

            # Tokenize
            encoded = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )

            # Move tensors to the configured device
            encoded = {k: v.to(self.device) for k, v in encoded.items()}

            # Compute token embeddings without tracking gradients
            with torch.no_grad():
                outputs = self.model(**encoded)
                hidden_states = outputs.last_hidden_state

            # Apply the pooling strategy to obtain one vector per text
            if self.pooling_strategy == "mean":
                embeddings = torch.mean(hidden_states, dim=1)
            elif self.pooling_strategy == "cls":
                embeddings = hidden_states[:, 0]
            else:  # max pooling
                embeddings = torch.max(hidden_states, dim=1)[0]

            embeddings_np = embeddings.cpu().numpy()

            # Create a dictionary with one column per embedding dimension
            embedding_dict = {
                f"{column}_embedding_{i}": embeddings_np[:, i].tolist()
                for i in range(embeddings_np.shape[1])
            }

            # Create a HuggingFace Dataset and convert it to a PyArrow table
            hf_dataset = Dataset.from_dict(embedding_dict)
            arrow_table = hf_dataset.data.table

            # Create a new dataset for this column's embeddings
            column_dataset = DashAIDataset(arrow_table)
            all_column_embeddings.append(column_dataset)

        # Concatenate the per-column embeddings side by side (axis=1) so the
        # result keeps one row per input sample; the default axis=0 would try
        # to stack datasets with different features and fail
        concatenated_dataset = concatenate_datasets(all_column_embeddings, axis=1)
        return DashAIDataset(concatenated_dataset.data.table)
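

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes the
# converter can be exercised directly through the _load_model and
# _process_batch helpers defined above; in DashAI, HuggingFaceWrapper normally
# drives model loading and batching, so calling these private methods directly
# is only for illustration.
if __name__ == "__main__":
    converter = Embedding(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        pooling_strategy="mean",
        device="cpu",
    )
    converter._load_model()

    # Build a small DashAIDataset from an in-memory PyArrow table
    sample = Dataset.from_dict({"text": ["hello world", "embeddings demo"]})
    input_dataset = DashAIDataset(sample.data.table)

    # Each text column is expanded into <hidden_size> numeric columns,
    # e.g. text_embedding_0 ... text_embedding_383 for all-MiniLM-L6-v2
    embedded = converter._process_batch(input_dataset)
    print(embedded.column_names[:5])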