TAPASxHF2 / embedding_generator.py
jskinner215's picture
Upload 17 files
25fc3a2
raw
history blame
682 Bytes
from transformers import AutoTokenizer, AutoModel
import torch
class EmbeddingGenerator:
    """Encode tables as embedding vectors using a Hugging Face encoder.

    Each table is rendered to plain text with ``to_string(index=False)``,
    tokenized, passed through the encoder, and the resulting
    ``last_hidden_state`` is averaged over dimension 1 to give one vector
    per table.
    """

    # Checkpoint used when the caller does not pick one explicitly.
    DEFAULT_MODEL_NAME = "deepset/all-mpnet-base-v2-table"

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Load the tokenizer and encoder weights for *model_name*.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                tokenizer and the model. Defaults to the checkpoint this
                class previously hard-coded, so existing callers that pass
                no argument behave as before.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)

    def generate_embeddings(self, dataframes):
        """Return one mean-pooled embedding per input table.

        Args:
            dataframes: Iterable of objects exposing
                ``to_string(index=False)`` (e.g. pandas DataFrames).

        Returns:
            A list of NumPy arrays in input order, one per table. Each array
            is the encoder's ``last_hidden_state`` averaged over dimension 1
            (presumably the token axis, giving shape ``(1, hidden_size)``).
            An empty input yields an empty list.
        """
        embeddings = []
        # Inference only: disabling autograd avoids building and holding a
        # computation graph for a backward pass that never happens.
        with torch.no_grad():
            for df in dataframes:
                table_text = df.to_string(index=False)
                # One table per call, so padding=True adds no pad tokens;
                # truncation=True guards against over-long tables.
                inputs = self.tokenizer(
                    table_text,
                    return_tensors='pt',
                    truncation=True,
                    padding=True,
                )
                outputs = self.model(**inputs)
                pooled = outputs.last_hidden_state.mean(dim=1)
                embeddings.append(pooled.detach().numpy())
        return embeddings