import os

import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

# Index of the last visible GPU, or -1 when no GPU is available (stay on CPU).
device = torch.cuda.device_count() - 1

# Task identifier used to look up task-specific parameters in model configs.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name, task):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Prefer the token from Streamlit secrets; fall back to the environment
    # when no secrets file is present (e.g. when running outside Streamlit).
    if os.path.exists(".streamlit/secrets.toml"):
        access_token = st.secrets.get("babel")
    else:
        access_token = os.environ.get("HF_ACCESS_TOKEN", None)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, use_auth_token=access_token
    )
    if tokenizer.pad_token is None:
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    # Translation tasks need an encoder-decoder model; everything else is
    # treated as causal language modeling.
    auto_model_class = (
        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
    )
    model = auto_model_class.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
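

# A minimal sketch of calling load_model directly (the checkpoint name below
# is a hypothetical placeholder, not a model shipped with this app). Repeated
# calls with the same arguments are served from the st.cache entry:
#
#     tokenizer, model = load_model("some-org/t5-dutch-english", TRANSLATION_EN_TO_NL)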


class Generator:
    def __init__(self, model_name, task, desc):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        self.load()

    def load(self):
        if not self.model:
            print(f"Loading model {self.model_name}")
            self.tokenizer, self.model = load_model(self.model_name, self.task)

        # Adopt the task's prompt prefix (e.g. "translate English to Dutch: ")
        # when the model config defines one; task_specific_params can be None.
        task_specific_params = (self.model.config.task_specific_params or {}).get(
            self.task, {}
        )
        if "prefix" in task_specific_params:
            self.prefix = task_specific_params["prefix"]

    def generate(self, text: str, **generate_kwargs) -> list[str]:
        # Encode a single example, with the task prefix prepended. Callers
        # must supply max_length among the generation keyword arguments.
        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded = batch_encoded.to(f"cuda:{device}")
        # model.generate returns token ids, not logits.
        generated_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            generated_ids.cpu().numpy(), skip_special_tokens=False
        )
        # Strip pad and end-of-sequence markers by hand; skip_special_tokens
        # stays False so other special tokens the model emits remain visible.
        decoded_preds = [
            pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            for pred in decoded_preds
        ]
        return decoded_preds

    def __str__(self):
        return self.desc
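

# A minimal sketch of using Generator directly (the checkpoint name is a
# hypothetical placeholder); str(g) yields the human-readable description:
#
#     g = Generator("some-org/t5-dutch-english", TRANSLATION_EN_TO_NL, "EN->NL")
#     print(g)                                             # "EN->NL"
#     print(g.generate("How are you?", max_length=64)[0])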


class GeneratorFactory:
    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc):
        # Only create a new Generator when no identical one exists;
        # Generator.__init__ already loads the model and tokenizer.
        if not self.get_generator(model_name=model_name, task=task, desc=desc):
            g = Generator(model_name, task, desc)
            self.generators.append(g)

    def get_generator(self, **kwargs):
        # Return the first generator whose attributes match all given criteria.
        for g in self.generators:
            if all(getattr(g, k, None) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def translation_descs(self):
        return [g.desc for g in self.generators if g.task == TRANSLATION_EN_TO_NL]
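

if __name__ == "__main__":
    # Minimal end-to-end usage sketch. The checkpoint name below is a
    # hypothetical placeholder; a real app would supply its own generator list.
    generator_list = [
        {
            "model_name": "some-org/t5-dutch-english",  # hypothetical placeholder
            "task": TRANSLATION_EN_TO_NL,
            "desc": "English to Dutch translator",
        },
    ]
    factory = GeneratorFactory(generator_list)
    translator = factory.get_generator(task=TRANSLATION_EN_TO_NL)
    if translator is not None:
        print(translator.generate("How are you today?", max_length=64, num_beams=4)[0])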