import logging
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed

logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)  # without an explicit level, the logger.info() calls below are silently dropped

HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu")  # e.g. "cuda:0"
if DEVICE != "cpu" and not torch.cuda.is_available():
    DEVICE = "cpu"
logger.info(f"DEVICE {DEVICE}")
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))

HEADER_INFO = """
# BERTIN GPT-J-6B Spanish
BERTIN GPT-J-6B Model.
""".strip()
LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
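# The app is configured entirely through the environment variables read above.
# A typical invocation might look like this (values are illustrative, not required):
#
#   export MODEL_NAME="bertin-project/bertin-gpt-j-6B"
#   export DEVICE="cuda:0"       # falls back to "cpu" when CUDA is unavailable
#   export MAX_LENGTH=1024
#   export HF_AUTH_TOKEN=...     # only needed to access private model repositories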
HEADER = f"""
![BERTIN logo]({LOGO})

# BERTIN GPT-J-6B

BERTIN proporciona una serie de modelos de lenguaje en español entrenados en abierto.

Este modelo ha sido entrenado con [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) en TPUs proporcionadas por Google a través del programa TPU Research Cloud, a partir del modelo [GPT-J de EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B) con el corpus [mC4-es-sampled (gaussian)](https://huggingface.co/datasets/bertin-project/mc4-es-sampled).

Esta demo funciona sobre una GPU proporcionada por Hugging Face.
"""
""" FOOTER = """ Para más información, visite el [repositorio del modelo](https://huggingface.co/bertin-project/bertin-gpt-j-6B). """.strip() class Normalizer: def remove_repetitions(self, text): """Remove repetitions""" first_ocurrences = [] for sentence in text.split("."): if sentence not in first_ocurrences: first_ocurrences.append(sentence) return '.'.join(first_ocurrences) def trim_last_sentence(self, text): """Trim last sentence if incomplete""" return text[:text.rfind(".") + 1] def clean_txt(self, text): return self.trim_last_sentence(self.remove_repetitions(text)) class TextGeneration: def __init__(self): self.tokenizer = None self.generator = None self.task = "text-generation" self.model_name_or_path = MODEL_NAME set_seed(42) def load(self): logger.info("Loading model...") self.tokenizer = AutoTokenizer.from_pretrained( self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None, ) self.model = AutoModelForCausalLM.from_pretrained( self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None, pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id, torch_dtype=DTYPE, low_cpu_mem_usage=False if DEVICE == "cpu" else True ).to(device=DEVICE, non_blocking=False) _ = self.model.eval() device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1]) self.generator = pipeline(self.task, model=self.model, tokenizer=self.tokenizer, device=device_number) logger.info("Loading model done.") # with torch.no_grad(): # tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True) # gen_tokens = self.model.generate(tokens, do_sample=True, temperature=0.8, max_length=128) # generated = tokenizer.batch_decode(gen_tokens)[0] # return generated def generate(self, text, generation_kwargs): max_length = len(self.tokenizer(text)["input_ids"]) + generation_kwargs["max_length"] generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions) # generation_kwargs["num_return_sequences"] = 1 # generation_kwargs["return_full_text"] = False generated_text = None if text: for _ in range(10): generated_text = self.generator( text, **generation_kwargs, )[0]["generated_text"] if generation_kwargs["do_clean"]: generated_text = cleaner.clean_txt(generated_text) if generated_text.strip().startswith(text): generated_text = generated_text.replace(text, "", 1).strip() if generated_text: return ( text + " " + generated_text, [(text, None), (generated_text, "BERTIN")] ) if not generated_text: return ( "", [("Tras 10 intentos BERTIN no generó nada. Pruebe cambiando las opciones", "ERROR")] ) # return (text + " " + generated_text, # f'
def load_text_generator():
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator


cleaner = Normalizer()
generator = load_text_generator()


def complete_with_gpt(text, max_length, top_k, top_p, temperature, do_sample, do_clean):
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    return generator.generate(text, generation_kwargs)


with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Group():
            with gr.Box():
                gr.Markdown("Opciones")
                # Approximate maximum number of words to generate.
                max_length = gr.Slider(
                    label="Longitud máxima",
                    minimum=1, maximum=MAX_LENGTH, value=50, step=1,
                )
                # Number of highest-probability tokens to keep for top-k filtering.
                top_k = gr.Slider(
                    label="Top-k",
                    minimum=40, maximum=80, value=50, step=1,
                )
                # Only the most probable tokens whose probabilities add up to top_p are kept.
                top_p = gr.Slider(
                    label="Top-p",
                    minimum=0.0, maximum=1.0, value=0.95, step=0.01,
                )
                # Value used to modulate the probabilities of the next generated tokens.
                temperature = gr.Slider(
                    label="Temperatura",
                    minimum=0.1, maximum=10.0, value=0.8, step=0.05,
                )
                # When sampling is off, greedy decoding is used instead.
                do_sample = gr.Checkbox(label="¿Muestrear?", value=True)
                # Whether to remove repeated sentences and trim the last unfinished one.
                do_clean = gr.Checkbox(label="¿Limpiar texto?", value=True)
        with gr.Column():
            textbox = gr.Textbox(
                label="Texto",
                placeholder="Escriba algo y pulse 'Generar'...",
                lines=8,
            )
            hidden = gr.Textbox(visible=False, show_label=False)
            with gr.Box():
                # output = gr.Markdown()
                output = gr.HighlightedText(
                    label="Resultado",
                    combine_adjacent=True,
                    color_map={"BERTIN": "green", "ERROR": "red"},
                )
            with gr.Row():
                btn = gr.Button("Generar")
                btn.click(
                    complete_with_gpt,
                    inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean],
                    outputs=[hidden, output],
                )
                edit_btn = gr.Button("Editar", variant="secondary")
                edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
                clean_btn = gr.Button("Limpiar", variant="secondary")
                clean_btn.click(lambda: ("", "", []), inputs=[], outputs=[textbox, hidden, output])
    gr.Markdown(FOOTER)

demo.launch()
# Alternative single-call UI:
# gr.Interface(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output]).launch()
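# To run the demo locally (a sketch; the file name and package list are assumptions):
#
#   pip install torch transformers gradio
#   MODEL_NAME="bertin-project/bertin-gpt-j-6B" DEVICE=cpu python app.py
#
# Note that the 6B-parameter model needs roughly 24 GB of memory in float32 on
# CPU, and about half of that in float16 on GPU.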