"""Gradio demo for the OCRonos-Vintage GPT-2 model.

The app generates text from a user prompt, shows the generated text split into
tokenizer tokens, and renders spaCy dependency parses for both the prompt and
the generated text.
"""

import importlib
import os
import subprocess
import sys

# ``spaces`` is kept ahead of ``torch`` exactly as in the original file.
# NOTE(review): presumably ZeroGPU wants to be imported before torch - confirm.
import spaces
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import spacy
from spacy import displacy

# Markdown shown at the top of the page. Line breaks matter here: each heading
# has to sit on its own line, otherwise the whole text renders as one heading.
title = """
# 🙋🏻‍♂️Welcome to 🌟Tonic's 🎅🏻⌚OCRonos Vintage Text Gen
This app generates historical-style text using the OCRonos-Vintage model. You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse.
You can see a tokenized visualisation of the output and your input, and learn english using the visualization for the output text!
### Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""

model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 tokenizers ship without a pad token; reuse EOS so padding works.
tokenizer.pad_token = tokenizer.eos_token

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

SPACY_MODEL = "en_core_web_sm"


def _load_spacy_pipeline(name=SPACY_MODEL):
    """Return the spaCy pipeline *name*, downloading it first if it is missing.

    Loading is attempted before downloading so that an already-installed model
    does not trigger a download on every start. The download runs with the
    current interpreter (``sys.executable``) and ``check=True`` so a failed
    download raises ``subprocess.CalledProcessError`` instead of surfacing
    later as a confusing load error.
    """
    try:
        return spacy.load(name)
    except OSError:
        # spaCy raises OSError when the model package cannot be found.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", name],
            check=True,
        )
        # The package was installed after the failed import attempt above;
        # drop stale import-system caches before retrying.
        importlib.invalidate_caches()
        return spacy.load(name)


nlp = _load_spacy_pipeline()


@spaces.GPU
def historical_generation(
    prompt: str,
    max_new_tokens: int = 600,
    top_k: int = 50,
    temperature: float = 0.7,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
) -> tuple:
    """Generate text from *prompt* and return it with a per-token breakdown.

    Args:
        prompt: User text; it is prefixed with the ``### Text ###`` header
            before being fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_k: Top-k sampling cutoff.
        temperature: Sampling temperature.
        top_p: Nucleus sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        ``(highlighted_text, generated_text)`` where ``highlighted_text`` is a
        list of ``(token, label)`` string pairs for ``gr.HighlightedText`` and
        ``generated_text`` is the decoded output. If the output contains a
        ``### Correction ###`` marker, only the text after it is kept.
    """
    # Gradio sliders may deliver floats; these two must be integers.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    # Prompt tokens plus generated tokens must fit in the model's position
    # table, so the prompt budget shrinks as max_new_tokens grows. The
    # original 1024-token cap is kept as an upper bound.
    context_window = getattr(model.config, "n_positions", 1024)
    prompt_budget = max(1, min(1024, context_window - max_new_tokens))

    prompt = f"### Text ###\n{prompt}"
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=prompt_budget,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()

    # Each token is labelled with its own vocabulary entry (token -> id ->
    # token), done once for the whole list instead of once per token. "Ġ" is
    # stripped from both the token and its label.
    tokens = tokenizer.tokenize(generated_text)
    labels = tokenizer.convert_ids_to_tokens(
        tokenizer.convert_tokens_to_ids(tokens)
    )
    highlighted_text = [
        (token.replace("Ġ", ""), label.replace("Ġ", ""))
        for token, label in zip(tokens, labels)
    ]

    # Drop tensor references before asking CUDA to release cached memory.
    del inputs, input_ids, attention_mask, output, tokens, labels
    torch.cuda.empty_cache()

    return highlighted_text, generated_text


# NOTE(review): this function only calls spaCy/displaCy; confirm whether it
# really needs the GPU decorator.
@spaces.GPU
def text_analysis(text: str) -> tuple:
    """Run spaCy on *text* and return POS pairs, simple counts and parse HTML.

    Returns:
        ``(pos_tokens, pos_count, html)`` where ``pos_tokens`` is a list of
        ``(token text, coarse POS tag)`` pairs, ``pos_count`` is a dict with
        ``char_count`` and ``token_count``, and ``html`` is the displaCy
        dependency rendering wrapped in a container element.
    """
    doc = nlp(text)
    html = displacy.render(doc, style="dep", page=True)
    # NOTE(review): the wrapper markup was missing from the source as received
    # (the two string literals were empty and broken across lines). A
    # scrollable container is restored here - confirm against the deployed app.
    html = (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + html
        + "</div>"
    )
    pos_count = {
        "char_count": len(text),
        "token_count": len(list(doc)),
    }
    pos_tokens = [(token.text, token.pos_) for token in doc]
    return pos_tokens, pos_count, html


def generate_dependency_parse(generated_text: str) -> str:
    """Return only the dependency-parse HTML for *generated_text*."""
    _pos_tokens, _pos_count, parse_html = text_analysis(generated_text)
    return parse_html


def display_dependency_parse(generated_text: str) -> str:
    """Click handler for the visualization button; delegates to the parser."""
    return generate_dependency_parse(generated_text)


def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Click handler for the generate button.

    Generates text, analyses the prompt, parses the generated text and returns
    one value per component listed in ``generate_button.click(outputs=...)``,
    in that order.
    """
    # Generate historical-style text and tokenized output.
    generated_highlight, generated_text = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )

    # Analyze the input text (counts + dependency parse visualization).
    _tokens_input, pos_count_input, html_input = text_analysis(prompt)

    # Dependency parse for the generated text.
    dependency_parse_generated_html = generate_dependency_parse(generated_text)

    # NOTE(review): reset_button starts hidden and is kept hidden here, so no
    # handler in this file ever shows it - confirm whether that is intended.
    return (
        generated_text,                   # generated_text_output
        generated_highlight,              # highlighted_text
        pos_count_input,                  # tokenizer_info
        html_input,                       # dependency_parse_input
        gr.update(visible=True),          # send_button
        dependency_parse_generated_html,  # dependency_parse_generated
        gr.update(visible=True),          # generate_button
        gr.update(visible=False),         # reset_button
    )


def reset_interface():
    """Click handler for the reset button: restore the initial UI state.

    Returns exactly one update per component in ``reset_button.click``'s
    ``outputs`` list (eight in total); Gradio rejects a handler that returns
    fewer values than it has outputs. Output components are cleared rather
    than hidden, because no other handler would make them visible again.
    """
    return (
        gr.update(visible=True),   # generate_button
        gr.update(visible=False),  # send_button
        gr.update(visible=False),  # reset_button
        gr.update(value=""),       # generated_text_output
        gr.update(value=None),     # highlighted_text
        gr.update(value=None),     # tokenizer_info
        gr.update(value=""),       # dependency_parse_input
        gr.update(value=""),       # dependency_parse_generated
    )


with gr.Blocks(theme=gr.themes.Base()) as iface:
    gr.Markdown(title)

    # Inputs: the prompt and the sampling controls.
    prompt = gr.Textbox(
        label="Add a passage in the style of historical texts",
        placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine' he said",
        lines=2,
    )
    max_new_tokens = gr.Slider(label="📏Length", minimum=50, maximum=1000, step=5, value=320)
    top_k = gr.Slider(label="🧪Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="🎨Creativity", minimum=0.1, maximum=1, step=0.05, value=0.3)
    top_p = gr.Slider(label="👌🏻Quality", minimum=0.1, maximum=0.99, step=0.01, value=0.97)
    repetition_penalty = gr.Slider(label="🔴Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.3)

    # Outputs.
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage")
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")
    dependency_parse_generated = gr.HTML(label="🎅🏻⌚Dependency Parse Visualization (Generated Text)")

    # Buttons; the first two start hidden.
    send_button = gr.Button(value="🎅🏻⌚OCRonos-Vintage 👁️Visualization", visible=False)
    reset_button = gr.Button(value="♻️Start Again", visible=False)
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")

    generate_button.click(
        full_interface,
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
        outputs=[
            generated_text_output,
            highlighted_text,
            tokenizer_info,
            dependency_parse_input,
            send_button,
            dependency_parse_generated,
            generate_button,
            reset_button,
        ],
    )

    send_button.click(
        display_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated],
    )

    reset_button.click(
        reset_interface,
        inputs=None,
        outputs=[
            generate_button,
            send_button,
            reset_button,
            generated_text_output,
            highlighted_text,
            tokenizer_info,
            dependency_parse_input,
            dependency_parse_generated,
        ],
    )

if __name__ == "__main__":
    iface.launch()