import nltk
nltk.download('stopwords')

import random

import gradio as gr
import plotly.graph_objs as go
from matplotlib.colors import ListedColormap, rgb2hex
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
from lcs import find_common_subsequences
from highlighter import highlight_common_words, highlight_common_words_dict
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from sampling_methods import sample_word


# Function for the Gradio interface: paraphrase the prompt, filter the paraphrases
# by entailment, mask candidate watermark positions, and re-sample the masked words.
def model(prompt):
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(
        user_prompt, paraphrased_sentences, 0.7
    )
    length_accepted_sentences = len(selected_sentences)

    # "Non-melting points": subsequences shared between the prompt and the accepted paraphrases.
    common_grams = find_common_subsequences(user_prompt, selected_sentences)

    # Apply the three masking strategies to every paraphrased sentence
    # (3 masked variants per sentence).
    masked_sentences = []
    masked_words = []
    masked_logits = []
    for sentence in paraphrased_sentences:
        masked_sent, logits, words = mask_non_stopword(sentence)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

        masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

        masked_sent, logits, words = high_entropy_words(sentence, common_grams)
        masked_sentences.append(masked_sent)
        masked_words.append(words)
        masked_logits.append(logits)

    # Fill each masked sentence using four sampling strategies
    # (4 sampled variants per masked sentence, i.e. 12 per original sentence).
    sampled_sentences = []
    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='inverse_transform', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='exponential_minimum', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))

    print(len(sampled_sentences))

    # Assign a highlight colour to each common subsequence.
    colors = ["red", "blue", "brown", "green"]

    def select_color():
        return random.choice(colors)

    highlight_info = [(word, select_color()) for _, word in common_grams]

    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")

    # Build one pair of plots per paraphrased sentence: trees1 shows where the
    # sentence was masked, trees2 shows how the masked slots were re-sampled.
    trees1 = []
    trees2 = []
    masked_index = 0
    sampled_index = 0
    for sentence in paraphrased_sentences:
        next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
        next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]

        tree1 = generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams)
        trees1.append(tree1)

        tree2 = generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams)
        trees2.append(tree2)

        masked_index += 3
        sampled_index += 12

    return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees1 + trees2


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# **AIISC Watermarking Model**")

    with gr.Row():
        user_input = gr.Textbox(label="User Prompt")
    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    with gr.Row():
        highlighted_user_prompt = gr.HTML()

    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Sentences"):
                highlighted_discarded_sentences = gr.HTML()

    # Label for the masked-sentence trees
    with gr.Row():
        gr.Markdown("### Where to Watermark?")
    with gr.Row():
        with gr.Tabs():
            tree1_tabs = []
            for i in range(10):  # Adjust this range to match the number of trees returned by model()
                with gr.TabItem(f"Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_tabs.append(tree1)

    # Label for the sampled-sentence trees
    with gr.Row():
        gr.Markdown("### How to Watermark?")
    with gr.Row():
        with gr.Tabs():
            tree2_tabs = []
            for i in range(10):  # Adjust this range to match the number of trees returned by model()
                with gr.TabItem(f"Sentence {i+1}"):
                    tree2 = gr.Plot()
                    tree2_tabs.append(tree2)

    output_components = (
        [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences]
        + tree1_tabs
        + tree2_tabs
    )

    submit_button.click(model, inputs=user_input, outputs=output_components)

    # The clear callback must return one value per output component:
    # empty strings for the HTML components and None for the plots.
    clear_button.click(lambda: "", inputs=None, outputs=user_input)
    clear_button.click(
        lambda: [""] * 3 + [None] * (len(tree1_tabs) + len(tree2_tabs)),
        inputs=None,
        outputs=output_components,
    )

demo.launch(share=True)