# Source: Hugging Face Space upload (folder uploaded via huggingface_hub, commit 5d9cd0b).
# The lines below were page chrome from the HF "raw" view, preserved here as a comment.
# jgyasu / Upload folder using huggingface_hub / 5d9cd0b verified / 5.65 kB
import nltk
nltk.download('stopwords')
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
import plotly.graph_objs as go
from transformers import pipeline
from matplotlib.colors import ListedColormap, rgb2hex
import random
import gradio as gr
from tree import generate_subplot1, generate_subplot2
from paraphraser import generate_paraphrase
from lcs import find_common_subsequences
from highlighter import highlight_common_words, highlight_common_words_dict
from entailment import analyze_entailment
from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
from sampling_methods import sample_word
# Function for the Gradio interface
def model(prompt):
    """Run the full watermarking pipeline for one user prompt.

    Pipeline: paraphrase the prompt, keep/discard paraphrases by entailment
    score, find the common subsequences ("non-melting points"), produce three
    masked variants per paraphrase, sample four re-fillings per masked variant,
    and build the highlight HTML plus the two families of tree plots.

    Args:
        prompt: The raw user prompt string.

    Returns:
        A list of 3 HTML strings followed by the "where to watermark" tree
        figures and the "how to watermark" tree figures, in the order the
        Gradio outputs expect.
    """
    user_prompt = prompt
    paraphrased_sentences = generate_paraphrase(user_prompt)
    # Entailment threshold 0.7 splits paraphrases into accepted vs discarded.
    _analyzed, selected_sentences, discarded_sentences = analyze_entailment(
        user_prompt, paraphrased_sentences, 0.7
    )
    common_grams = find_common_subsequences(user_prompt, selected_sentences)

    # Three masking strategies per paraphrased sentence, appended in a fixed
    # order so that masked_sentences[i*3:(i+1)*3] belongs to sentence i.
    masked_sentences = []
    masked_words = []
    masked_logits = []
    for sentence in paraphrased_sentences:
        for masked_sent, logits, words in (
            mask_non_stopword(sentence),
            mask_non_stopword_pseudorandom(sentence),
            high_entropy_words(sentence, common_grams),
        ):
            masked_sentences.append(masked_sent)
            masked_words.append(words)
            masked_logits.append(logits)

    # Four sampling strategies per masked sentence -> 12 samples per original
    # paraphrase, so sampled_sentences[i*12:(i+1)*12] belongs to sentence i.
    sampling_techniques = ("inverse_transform", "exponential_minimum", "temperature", "greedy")
    sampled_sentences = []
    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
        for technique in sampling_techniques:
            sampled_sentences.append(
                sample_word(
                    masked_sent, words, logits,
                    sampling_technique=technique, temperature=1.0,
                )
            )

    # Assign each common n-gram a (non-unique) random highlight color.
    colors = ["red", "blue", "brown", "green"]
    highlight_info = [(word, random.choice(colors)) for _, word in common_grams]

    highlighted_user_prompt = highlight_common_words(
        common_grams, [user_prompt], "Non-melting Points in the User Prompt"
    )
    highlighted_accepted_sentences = highlight_common_words_dict(
        common_grams, selected_sentences, "Paraphrased Sentences"
    )
    highlighted_discarded_sentences = highlight_common_words_dict(
        common_grams, discarded_sentences, "Discarded Sentences"
    )

    # One "where" tree (3 masked variants) and one "how" tree (12 samples)
    # per paraphrased sentence, using the slice layout established above.
    trees1 = []
    trees2 = []
    for i, sentence in enumerate(paraphrased_sentences):
        next_masked_sentences = masked_sentences[i * 3:(i + 1) * 3]
        next_sampled_sentences = sampled_sentences[i * 12:(i + 1) * 12]
        trees1.append(
            generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams)
        )
        trees2.append(
            generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams)
        )

    return [
        highlighted_user_prompt,
        highlighted_accepted_sentences,
        highlighted_discarded_sentences,
    ] + trees1 + trees2
# Gradio UI: one input prompt, highlighted HTML panels, and two tabbed
# families of tree plots ("where" and "how" to watermark).
NUM_TREES = 10  # Adjust according to the number of trees model() returns.

with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# **AIISC Watermarking Model**")

    with gr.Row():
        user_input = gr.Textbox(label="User Prompt")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    with gr.Row():
        highlighted_user_prompt = gr.HTML()

    with gr.Row():
        with gr.Tabs():
            with gr.TabItem("Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Sentences"):
                highlighted_discarded_sentences = gr.HTML()

    # Label for masked-sentence trees.
    with gr.Row():
        gr.Markdown("### Where to Watermark?")
    with gr.Row():
        with gr.Tabs():
            tree1_tabs = []
            for i in range(NUM_TREES):
                with gr.TabItem(f"Sentence {i+1}"):
                    tree1_tabs.append(gr.Plot())

    # Label for sampled-sentence trees.
    with gr.Row():
        gr.Markdown("### How to Watermark?")
    with gr.Row():
        with gr.Tabs():
            tree2_tabs = []
            for i in range(NUM_TREES):
                with gr.TabItem(f"Sentence {i+1}"):
                    tree2_tabs.append(gr.Plot())

    # All components model() writes to, in the order model() returns them.
    all_outputs = [
        highlighted_user_prompt,
        highlighted_accepted_sentences,
        highlighted_discarded_sentences,
    ] + tree1_tabs + tree2_tabs

    submit_button.click(model, inputs=user_input, outputs=all_outputs)
    clear_button.click(lambda: "", inputs=None, outputs=user_input)
    # BUG FIX: Gradio requires one return value per output component; the
    # original `lambda: ""` returned a single string for 23 outputs and
    # would raise when Clear was pressed. Clear HTML to "" and plots to None.
    clear_button.click(
        lambda: ["", "", ""] + [None] * (2 * NUM_TREES),
        inputs=None,
        outputs=all_outputs,
    )

demo.launch(share=True)