rosetta / generator.py
yhavinga's picture
Rename app to rosetta. Make two-column. Add some texts.
cdb537e
raw
history blame
No virus
5.44 kB
import os
import re
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# CUDA device index used for inference: the highest-numbered visible GPU, or -1
# when torch reports no GPUs (the `device != -1` checks elsewhere then keep the
# model and inputs off CUDA).
device = -1 + torch.cuda.device_count()
def get_access_token():
    """Return the Hugging Face access token, or None if none is configured.

    When ``.streamlit/secrets.toml`` exists in the working directory, the
    Streamlit secret ``babel`` is used. If that file is absent, reading the
    secrets raises FileNotFoundError, or the ``babel`` secret is missing, the
    ``HF_ACCESS_TOKEN`` environment variable is used instead (None if unset).
    """
    access_token = None
    if os.path.exists(".streamlit/secrets.toml"):
        try:
            access_token = st.secrets.get("babel")
        except FileNotFoundError:
            # st.secrets can itself raise when it fails to locate a secrets file.
            access_token = None
    if access_token is None:
        # Previously a secrets file without a "babel" entry yielded None here
        # and the environment variable was never consulted.
        access_token = os.environ.get("HF_ACCESS_TOKEN", None)
    return access_token
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model(model_name):
    """Load (and cache via ``st.cache``) the tokenizer and seq2seq model.

    Model weights are tried in this order: the default ``from_pretrained``
    load, then ``from_flax=True``, then ``from_tf=True``; each fallback is only
    attempted when the previous one raises EnvironmentError. When ``device`` is
    not -1 the model is moved to ``cuda:{device}``.

    Args:
        model_name: Identifier or path passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Any error from the final ``from_tf=True`` attempt propagates, as do
        non-EnvironmentError errors from the earlier attempts.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Resolve the token once instead of re-checking secrets/env for every call.
    access_token = get_access_token()
    # NOTE(review): from_flax has no evident meaning for a tokenizer; presumably
    # it is ignored by from_pretrained here — confirm before relying on it.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, from_flax=True, use_auth_token=access_token
    )
    if tokenizer.pad_token is None:
        print("Adding pad_token to the tokenizer")
        tokenizer.pad_token = tokenizer.eos_token
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name, use_auth_token=access_token
        )
    except EnvironmentError:
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, from_flax=True, use_auth_token=access_token
            )
        except EnvironmentError:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, from_tf=True, use_auth_token=access_token
            )
    if device != -1:
        model.to(f"cuda:{device}")
    return tokenizer, model
class Generator:
    """A seq2seq model plus tokenizer bound to one task (e.g. translation).

    Generation settings start from the defaults set in ``__init__``. They are
    overridden by same-named instance attributes of the model config, then by
    ``task_specific_params[task]`` from that config (which may also supply an
    input ``prefix``). Keyword arguments passed to ``generate`` override all
    of these for that call.
    """

    def __init__(self, model_name, task, desc, split_sentences):
        """Store the settings and load the model immediately.

        Args:
            model_name: Model identifier passed to ``load_model``.
            task: Key looked up in the model config's ``task_specific_params``.
            desc: Human-readable description of this generator.
            split_sentences: If true, multi-line input is generated line by line.
        """
        self.model_name = model_name
        self.task = task
        self.desc = desc
        self.split_sentences = split_sentences
        self.tokenizer = None
        self.model = None
        self.prefix = ""
        self.gen_kwargs = {
            "max_length": 128,
            "num_beams": 6,
            "num_beam_groups": 3,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "num_return_sequences": 1,
            "length_penalty": 1.0,
        }
        self.load()

    def load(self):
        """Load the model once and adopt generation settings from its config.

        Does nothing if the model is already loaded. Otherwise loads the
        tokenizer/model via ``load_model``. Then, for every key already present
        in ``self.gen_kwargs``, it takes the value from the config's instance
        attributes, and afterwards from ``task_specific_params[self.task]``.
        """
        if self.model is not None:
            return
        print(f"Loading model {self.model_name}")
        self.tokenizer, self.model = load_model(self.model_name)

        config = self.model.config
        config_values = vars(config)  # instance attributes only, like __dict__
        for key in self.gen_kwargs:
            if key in config_values:
                self.gen_kwargs[key] = config_values[key]

        # task_specific_params (or this task's entry in it) may be None; treat
        # that explicitly as "no overrides" instead of catching TypeError, which
        # would also hide unrelated errors.
        all_task_params = getattr(config, "task_specific_params", None) or {}
        task_params = all_task_params.get(self.task) or {}
        if "prefix" in task_params:
            self.prefix = task_params["prefix"]
        for key in self.gen_kwargs:
            if key in task_params:
                self.gen_kwargs[key] = task_params[key]

    def generate(self, text: str, **generate_kwargs) -> "tuple[str, dict]":
        """Generate output (e.g. a translation) for ``text``.

        Args:
            text: Input text. Runs of two or more newlines are collapsed to one.
            **generate_kwargs: Overrides for ``self.gen_kwargs``; the merged
                settings are forwarded to ``model.generate``.

        Returns:
            ``(output_text, merged_generate_kwargs)``. Blank input yields ``""``
            without calling the tokenizer or model.
        """
        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
        # Replace two or more newlines with a single newline in text
        text = re.sub(r"\n{2,}", "\n", text)
        # Nothing to generate from: skip the model call on prefix-only input.
        if not text.strip():
            return "", generate_kwargs
        # If the model needs line-splitting, generate each line separately.
        # Blank lines (e.g. from "\r\n\r\n" or whitespace-only lines, which the
        # regex above does not collapse) are dropped instead of being sent to
        # the model.
        if self.split_sentences and "\n" in text:
            lines = [line for line in text.splitlines() if line.strip()]
            translated = [self.generate(line, **generate_kwargs)[0] for line in lines]
            return "\n".join(translated), generate_kwargs

        batch_encoded = self.tokenizer(
            self.prefix + text,
            max_length=generate_kwargs["max_length"],
            padding=False,
            truncation=False,
            return_tensors="pt",
        )
        if device != -1:
            batch_encoded.to(f"cuda:{device}")
        output_ids = self.model.generate(
            batch_encoded["input_ids"],
            attention_mask=batch_encoded["attention_mask"],
            **generate_kwargs,
        )
        decoded_preds = self.tokenizer.batch_decode(
            output_ids.cpu().numpy(), skip_special_tokens=False
        )
        # Special tokens are kept by batch_decode above; strip the pad and
        # end-of-sequence markers by hand.
        decoded_preds = [
            pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
            for pred in decoded_preds
        ]
        return decoded_preds[0], generate_kwargs

    def __str__(self):
        return self.model_name
class GeneratorFactory:
    """Build and hold ``Generator`` instances and look them up by attribute."""

    def __init__(self, generator_list):
        """Create a generator for every config dict in ``generator_list``.

        Each dict is passed as keyword arguments to ``add_generator``, so it
        must provide ``model_name``, ``task``, ``desc`` and ``split_sentences``.
        """
        self.generators = []
        for g in generator_list:
            with st.spinner(text=f"Loading the model {g['desc']} ..."):
                self.add_generator(**g)

    def add_generator(self, model_name, task, desc, split_sentences):
        """Add a generator unless one with the same model, task and desc exists.

        ``split_sentences`` is not part of the duplicate check.
        """
        if self.get_generator(model_name=model_name, task=task, desc=desc) is None:
            g = Generator(model_name, task, desc, split_sentences)
            # Generator.__init__ already calls load(); kept as an explicit
            # guarantee that the model is loaded before it is registered.
            g.load()
            self.generators.append(g)

    @staticmethod
    def _matches(generator, criteria):
        """Return True if each criteria value equals the generator's attribute.

        Attributes missing from the generator's ``__dict__`` compare as None.
        """
        return all(generator.__dict__.get(k) == v for k, v in criteria.items())

    def get_generator(self, **kwargs):
        """Return the first generator whose attributes match ``kwargs``, or None."""
        return next((g for g in self.generators if self._matches(g, kwargs)), None)

    def __iter__(self):
        return iter(self.generators)

    def filter(self, **kwargs):
        """Return all generators whose attributes match ``kwargs``, in order."""
        return [g for g in self.generators if self._matches(g, kwargs)]