|
import os |
|
import re |
|
|
|
import streamlit as st |
|
import torch |
|
from transformers import ( |
|
AutoModelForSeq2SeqLM, |
|
AutoTokenizer, |
|
) |
|
|
|
# Target GPU index: the highest-numbered CUDA device torch can see, or -1 when
# no CUDA device is available. Code in this module checks `device != -1`
# before moving the model / encoded inputs to f"cuda:{device}".
device = -1 + torch.cuda.device_count()
|
|
|
|
|
def get_access_token():
    """Return the Hugging Face access token used when downloading models.

    If ``.streamlit/secrets.toml`` exists in the current working directory,
    the ``babel`` entry of ``st.secrets`` is used. When that file is absent,
    cannot be read (``FileNotFoundError``), or has no ``babel`` entry, the
    ``HF_ACCESS_TOKEN`` environment variable is used instead.

    Returns:
        The token string, or ``None`` when no token is configured anywhere.
    """
    access_token = None
    if os.path.exists(".streamlit/secrets.toml"):
        try:
            access_token = st.secrets.get("babel")
        except FileNotFoundError:
            # Treat an unloadable secrets file the same as a missing one.
            access_token = None
    if access_token is None:
        # Fall back to the environment even when secrets.toml exists but
        # does not define the token.
        access_token = os.environ.get("HF_ACCESS_TOKEN")
    return access_token
|
|
|
|
|
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer and seq2seq model for ``model_name``.

    Both are loaded with ``from_flax=True`` and the token returned by
    ``get_access_token()``. The result is cached by ``st.cache``, so repeated
    calls with the same ``model_name`` reuse the loaded objects;
    ``allow_output_mutation=True`` skips Streamlit's check that the cached
    return value was not mutated.

    Args:
        model_name: Model identifier or path accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is moved to
        ``cuda:{device}`` when a GPU is available (``device != -1``).
    """
    # Disable the Rust tokenizers' internal parallelism (presumably to avoid
    # its fork-related warning under Streamlit — TODO confirm intent).
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Resolve the token once and reuse it for both downloads.
    access_token = get_access_token()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; reuse EOS for padding.
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
|
|
|
|
|
class Generator:
    """A described seq2seq generation task backed by a cached model.

    Args:
        model_name: Model identifier passed to ``load_model``.
        task: Key looked up in the model config's ``task_specific_params``
            to obtain an input prefix and generation overrides.
        desc: Human-readable description; also what ``str()`` returns.
        split_sentences: Stored as ``self.split_sentences``; this class does
            not use it itself (presumably consumed by callers — TODO confirm).
    """

    def __init__(self, model_name, task, desc, split_sentences):
        self.model_name = model_name
        self.task = task
        self.desc = desc
        # Kept so callers can read the flag back instead of it being dropped.
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        # Text prepended to every input; may come from task_specific_params.
        self.prefix = ""
        # Default generation settings; load() may override them from the
        # model config and then from the task's task_specific_params.
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        """Load the tokenizer/model once and derive generation settings.

        Does nothing if a model is already set. Otherwise every key of
        ``gen_kwargs`` that is also an attribute of the model config takes the
        config's value. Then any matching key in
        ``config.task_specific_params[self.task]`` overrides it again, and
        that entry's ``"prefix"`` (default ``""``) becomes ``self.prefix``.
        """
        if self.model is not None:
            return
        print(f"Loading model {self.model_name}")
        self.tokenizer, self.model = load_model(self.model_name)

        config_values = vars(self.model.config)
        for key in self.gen_kwargs:
            if key in config_values:
                self.gen_kwargs[key] = config_values[key]
                print(
                    "Setting",
                    key,
                    "to",
                    self.gen_kwargs[key],
                    "for model",
                    self.model_name,
                )

        # task_specific_params may be None when the config defines none;
        # handle that explicitly instead of swallowing a TypeError.
        all_task_params = self.model.config.task_specific_params or {}
        if self.task not in all_task_params:
            return
        task_params = all_task_params[self.task] or {}
        self.prefix = task_params.get("prefix", "")
        for key in self.gen_kwargs:
            if key in task_params:
                self.gen_kwargs[key] = task_params[key]

    def generate(self, text: str, **generate_kwargs) -> "tuple[list[str], dict]":
        """Run the model on ``text`` and return the decoded outputs.

        Runs of two or more newlines in ``text`` are collapsed to one, and
        ``self.prefix`` is prepended. ``self.gen_kwargs`` merged with
        ``generate_kwargs`` (the latter wins) is passed to ``model.generate``.

        Returns:
            ``(predictions, kwargs)``: the decoded strings with ``<pad>`` and
            ``</s>`` markers removed, and the generation kwargs actually used.
        """
        text = re.sub(r"\n{2,}", "\n", text)

        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
        # NOTE(review): with padding and truncation disabled, max_length
        # likely has no effect on tokenization here — confirm intent.
        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded = batch_encoded.to(f"cuda:{device}")
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        # Keep special tokens while decoding, then strip only pad/eos markers
        # so any other special tokens remain in the output.
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )
        decoded_preds = [
            pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            for pred in decoded_preds
        ]
        return decoded_preds, generate_kwargs

    def __str__(self):
        return self.desc
|
|
|
|
|
class GeneratorFactory:
    """Build and look up ``Generator`` instances from config dicts.

    Args:
        generator_list: Iterable of dicts whose keys match the parameters of
            ``add_generator`` (``model_name``, ``task``, ``desc``,
            ``split_sentences``).
    """

    def __init__(self, generator_list):
        self.generators = []
        for g in generator_list:
            # Show a Streamlit spinner while each model is being loaded.
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc, split_sentences):
        """Create and register a ``Generator`` unless an equivalent exists.

        Generators are considered equivalent when ``model_name``, ``task``
        and ``desc`` all match; ``split_sentences`` is not compared.
        """
        if self.get_generator(model_name=model_name, task=task, desc=desc) is None:
            # Generator.__init__ already calls load(), so no second call here.
            self.generators.append(
                Generator(model_name, task, desc, split_sentences)
            )

    def get_generator(self, **kwargs):
        """Return the first generator matching every ``kwargs`` item.

        Each key is looked up in the generator's instance ``__dict__``; a
        missing attribute compares as ``None``. With no kwargs, the first
        generator (if any) is returned.

        Returns:
            The matching generator, or ``None`` if there is none.
        """
        for g in self.generators:
            attrs = vars(g)
            if all(attrs.get(k) == v for k, v in kwargs.items()):
                return g
        return None

    def __iter__(self):
        return iter(self.generators)

    def get_descriptions(self, task=None):
        """Return the ``desc`` of every generator, optionally filtered by ``task``."""
        return [g.desc for g in self.generators if task is None or task == g.task]
|
|