jgyasu committed
Commit ee305a4
Parent: 8b20c56

Upload folder using huggingface_hub

Files changed (8)
  1. app.py +93 -19
  2. entailment.py +1 -1
  3. highlighter.py +33 -42
  4. lcs.py +3 -3
  5. masking_methods.py +84 -12
  6. paraphraser.py +1 -1
  7. sampling_methods.py +31 -139
  8. tree.py +90 -47
app.py CHANGED
@@ -6,7 +6,6 @@ import plotly.graph_objs as go
 import textwrap
 from transformers import pipeline
 import re
-import time
 import requests
 from PIL import Image
 import itertools
@@ -20,10 +19,7 @@ import pandas as pd
 from pprint import pprint
 from tenacity import retry
 from tqdm import tqdm
-import scipy.stats
-import torch
 from transformers import GPT2LMHeadModel
-import seaborn as sns
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM
 import random
 from nltk.corpus import stopwords
@@ -31,22 +27,92 @@ from termcolor import colored
 from nltk.translate.bleu_score import sentence_bleu
 from transformers import BertTokenizer, BertModel
 import gradio as gr
-from tree import generate_plot
+from tree import generate_subplot
 from paraphraser import generate_paraphrase
 from lcs import find_common_subsequences
 from highlighter import highlight_common_words, highlight_common_words_dict
 from entailment import analyze_entailment
+from masking_methods import mask_non_stopword, mask_non_stopword_pseudorandom, high_entropy_words
+from sampling_methods import sample_word
+
 
 # Function for the Gradio interface
 def model(prompt):
-    sentence = prompt
-    paraphrased_sentences = generate_paraphrase(sentence)
-    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(sentence, paraphrased_sentences, 0.7)
-    common_grams = find_common_subsequences(sentence, selected_sentences)
-    highlighted_user_prompt = highlight_common_words(common_grams, [sentence], "User Prompt (Highlighted and Numbered)") # Pass the sentence as a list
-    highlighted_paraphrased_sentences = highlight_common_words_dict(common_grams, selected_sentences, discarded_sentences, "Sentences Generated by the Paraphraser")
-    tree = generate_plot(sentence, list(selected_sentences.keys()))
-    return highlighted_user_prompt, highlighted_paraphrased_sentences, tree
+    user_prompt = prompt
+    paraphrased_sentences = generate_paraphrase(user_prompt)
+    analyzed_paraphrased_sentences, selected_sentences, discarded_sentences = analyze_entailment(user_prompt, paraphrased_sentences, 0.7)
+    length_accepted_sentences = len(selected_sentences)
+    common_grams = find_common_subsequences(user_prompt, selected_sentences)
+
+    masked_sentences = []
+    masked_words = []
+    masked_logits = []
+    selected_sentences_list = list(selected_sentences.keys())
+
+    for sentence in selected_sentences_list:
+        # Mask non-stopword
+        masked_sent, logits, words = mask_non_stopword(sentence)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+        # Mask non-stopword pseudorandom
+        masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+        # High entropy words
+        masked_sent, logits, words = high_entropy_words(sentence, common_grams)
+        masked_sentences.append(masked_sent)
+        masked_words.append(words)
+        masked_logits.append(logits)
+
+    sampled_sentences = []
+    for masked_sent, words, logits in zip(masked_sentences, masked_words, masked_logits):
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='inverse_transform', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='exponential_minimum', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
+        sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))
+
+    # Predefined set of colors that are visible on a white background, excluding black
+    colors = ["red", "blue", "brown", "green"]
+
+    # Function to generate color from predefined set
+    def select_color():
+        return random.choice(colors)
+
+    # Create highlight_info with selected colors
+    highlight_info = [(word, select_color()) for _, word in common_grams]
+
+
+    highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "User Prompt (Highlighted and Numbered)")
+    highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
+    highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
+
+    # Initialize empty list to hold the trees
+    trees = []
+
+    # Initialize the indices for masked and sampled sentences
+    masked_index = 0
+    sampled_index = 0
+
+    for i, sentence in enumerate(selected_sentences):
+        # Generate the sublists of masked and sampled sentences based on current indices
+        next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
+        next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]
+
+        # Create the tree for the current sentence
+        tree = generate_subplot(sentence, next_masked_sentences, next_sampled_sentences, highlight_info)
+        trees.append(tree)
+
+        # Update the indices for the next iteration
+        masked_index += 3
+        sampled_index += 12
+
+
+    # Return all the outputs together
+    return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees
 
 
 with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
@@ -63,15 +129,23 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
         highlighted_user_prompt = gr.HTML()
 
     with gr.Row():
-        highlighted_paraphrased_sentences = gr.HTML()
+        with gr.Tabs():
+            with gr.TabItem("Paraphrased Sentences"):
+                highlighted_accepted_sentences = gr.HTML()
+            with gr.TabItem("Discarded Sentences"):
+                highlighted_discarded_sentences = gr.HTML()
 
     with gr.Row():
-        tree = gr.Plot()
+        with gr.Tabs():
+            tree_tabs = []
+            for i in range(3):  # Adjust this range according to the number of trees
+                with gr.TabItem(f"Tree {i+1}"):
+                    tree = gr.Plot()
+                    tree_tabs.append(tree)
 
-    submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt, highlighted_paraphrased_sentences, tree])
+    submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree_tabs)
     clear_button.click(lambda: "", inputs=None, outputs=user_input)
-    clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt, highlighted_paraphrased_sentences, tree])
+    clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree_tabs)
 
 # Launch the demo
-demo.launch(share=True)
-
+demo.launch(share=True)
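A minimal, self-contained sketch (not part of the commit) of the bookkeeping the reworked model() relies on: each accepted sentence contributes 3 masked variants (one per masking scheme) and 3 × 4 = 12 sampled variants (one per sampling technique per masked variant), which is why the loop slices masked_sentences in steps of 3 and sampled_sentences in steps of 12, and why the UI pre-allocates three tree tabs.

# Toy stand-ins for the real pipeline: 2 accepted sentences, 3 masking schemes, 4 sampling techniques.
accepted = ["sentence A", "sentence B"]
masked = [f"{s} | mask {m}" for s in accepted for m in range(3)]                                   # 3 per sentence
sampled = [f"{s} | mask {m} | sample {t}" for s in accepted for m in range(3) for t in range(4)]   # 12 per sentence

for i, s in enumerate(accepted):
    next_masked = masked[i * 3:(i + 1) * 3]        # mirrors masked_index += 3
    next_sampled = sampled[i * 12:(i + 1) * 12]    # mirrors sampled_index += 12
    print(s, len(next_masked), len(next_sampled))  # -> 3 masked and 12 sampled variants each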
 
entailment.py CHANGED
@@ -28,4 +28,4 @@ def analyze_entailment(original_sentence, paraphrased_sentences, threshold):
 
     return all_sentences, selected_sentences, discarded_sentences
 
-print(analyze_entailment("I love you", ["You're being loved by me"], 0.7))
+# print(analyze_entailment("I love you", ["You're being loved by me"], 0.7))
highlighter.py CHANGED
@@ -39,57 +39,48 @@ def highlight_common_words(common_words, sentences, title):
     '''
 
 
+
 import re
 
-def highlight_common_words_dict(common_words, selected_sentences, discarded_sentences, title):
+def highlight_common_words_dict(common_words, sentences, title):
     color_map = {}
     color_index = 0
     highlighted_html = []
 
-    def highlight_sentences(sentences, start_idx, section_title):
-        nonlocal color_index
-        nonlocal color_map
-        highlighted_sentences = [f'<h4 style="color: #374151; margin-bottom: 5px;">{section_title}</h4>']
+    for idx, (sentence, score) in enumerate(sentences.items(), start=1):
+        sentence_with_idx = f"{idx}. {sentence}"
+        highlighted_sentence = sentence_with_idx
 
-        for idx, (sentence, score) in enumerate(sentences.items(), start=start_idx):
-            sentence_with_idx = f"{idx}. {sentence}"
-            highlighted_sentence = sentence_with_idx
-
-            for index, word in common_words:
-                if word not in color_map:
-                    color_map[word] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
-                    color_index += 1
-                escaped_word = re.escape(word)
-                pattern = rf'\b{escaped_word}\b'
-                highlighted_sentence = re.sub(
-                    pattern,
-                    lambda m, idx=index, color=color_map[word]: (
-                        f'<span style="background-color: {color}; font-weight: bold;'
-                        f' padding: 1px 2px; border-radius: 2px; position: relative;">'
-                        f'<span style="background-color: black; color: white; border-radius: 50%;'
-                        f' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{idx}</span>'
-                        f'{m.group(0)}'
-                        f'</span>'
-                    ),
-                    highlighted_sentence,
-                    flags=re.IGNORECASE
-                )
-            highlighted_sentences.append(
-                f'<div style="margin-bottom: 5px;">'
-                f'{highlighted_sentence}'
-                f'<div style="display: inline-block; margin-left: 5px; border: 1px solid #ddd; padding: 3px 5px; border-radius: 3px; background-color: white; font-size: 0.9em;">'
-                f'Entailment Score: {score}</div></div>'
+        for index, word in common_words:
+            if word not in color_map:
+                color_map[word] = f'hsl({color_index * 60 % 360}, 70%, 80%)'
+                color_index += 1
+            escaped_word = re.escape(word)
+            pattern = rf'\b{escaped_word}\b'
+            highlighted_sentence = re.sub(
+                pattern,
+                lambda m, idx=index, color=color_map[word]: (
+                    f'<span style="background-color: {color}; font-weight: bold;'
+                    f' padding: 1px 2px; border-radius: 2px; position: relative;">'
+                    f'<span style="background-color: black; color: white; border-radius: 50%;'
+                    f' padding: 1px 3px; margin-right: 3px; font-size: 0.8em;">{idx}</span>'
+                    f'{m.group(0)}'
+                    f'</span>'
+                ),
+                highlighted_sentence,
+                flags=re.IGNORECASE
             )
-
-        return highlighted_sentences
-
-    selected_html = highlight_sentences(selected_sentences, 1, "Selected Sentences")
-    discarded_html = highlight_sentences(discarded_sentences, 1, "Discarded Sentences")
+        highlighted_html.append(
+            f'<div style="margin-bottom: 5px;">'
+            f'{highlighted_sentence}'
+            f'<div style="display: inline-block; margin-left: 5px; padding: 3px 5px; border-radius: 3px; background-color: white; font-size: 0.9em;">'
+            f'Entailment Score: {score}</div></div>'
+        )
 
-    final_html = "<br>".join(selected_html + discarded_html)
+    final_html = "<br>".join(highlighted_html)
     return f'''
-    <div style="border: solid 1px #; padding: 16px; background-color: #FFFFFF; color: #374151; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border-radius: 8px;">
-        <h3 style="margin-top: 0; font-size: 1em; color: #111827; margin-bottom: 10px;">{title}</h3>
+    <div style="background-color: #ffffff; color: #374151;">
+        <h3 style="margin-top: 0; font-size: 1em; color: #111827;">{title}</h3>
         <div style="background-color: #F5F5F5; line-height: 1.6; padding: 15px; border-radius: 8px;">{final_html}</div>
     </div>
-    '''
+    '''
lcs.py CHANGED
@@ -40,7 +40,7 @@ def find_common_subsequences(sentence, str_list):
     return indexed_common_grams
 
 # Example usage
-sentence = "Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."
-str_list = ["The founder of South Korean technology company Kakao, billionaire Kim Beom-su, was arrested on charges of stock fraud during a bidding war for one of North Korea's biggest K-pop companies.", "In a bidding war for one of South Korea's largest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "During a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "Kim Beom-su, the founder of South Korean technology giant Kakao's billionaire investor status, was arrested on charges of stock fraud during a bidding war for one of North Korea'S top K-pop agencies.", "A bidding war over one of South Korea's biggest K-pop agencies led to the arrest and apprehension charges of Kim Beom-Su, the billionaire who owns the technology giant Kakao.", "The billionaire who owns South Korean technology giant Kakao, Kim Beom-Su, was taken into custody for allegedly engaging in stock trading during a bidding war for one of North Korea's biggest K-pop media groups.", "Accused of stockpiling during a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-Su, the founder and owner of technology firm known as Kakao, was arrested on charges of manipulating stocks.", 'Kakao, the South Korean technology giant, was involved in a bidding war with Kim Beon-su, its founder, who was arrested on charges of manipulating stocks.', "South Korea's Kakao corporation'entrepreneur husband, Kim Beom-su (pictured), was arrested on suspicion of stock fraud during a bidding war for one of the country'S top K-pop companies.", 'Kim Beom-su, the billionaire who own a South Korean technology company called Kakaof, was arrested on charges of manipulating stocks in an ongoing bidding war over one million shares.']
+# sentence = "Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."
+# str_list = ["The founder of South Korean technology company Kakao, billionaire Kim Beom-su, was arrested on charges of stock fraud during a bidding war for one of North Korea's biggest K-pop companies.", "In a bidding war for one of South Korea's largest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "During a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "Kim Beom-su, the founder of South Korean technology giant Kakao's billionaire investor status, was arrested on charges of stock fraud during a bidding war for one of North Korea'S top K-pop agencies.", "A bidding war over one of South Korea's biggest K-pop agencies led to the arrest and apprehension charges of Kim Beom-Su, the billionaire who owns the technology giant Kakao.", "The billionaire who owns South Korean technology giant Kakao, Kim Beom-Su, was taken into custody for allegedly engaging in stock trading during a bidding war for one of North Korea's biggest K-pop media groups.", "Accused of stockpiling during a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-Su, the founder and owner of technology firm known as Kakao, was arrested on charges of manipulating stocks.", 'Kakao, the South Korean technology giant, was involved in a bidding war with Kim Beon-su, its founder, who was arrested on charges of manipulating stocks.', "South Korea's Kakao corporation'entrepreneur husband, Kim Beom-su (pictured), was arrested on suspicion of stock fraud during a bidding war for one of the country'S top K-pop companies.", 'Kim Beom-su, the billionaire who own a South Korean technology company called Kakaof, was arrested on charges of manipulating stocks in an ongoing bidding war over one million shares.']
 
-print(find_common_subsequences(sentence, str_list))
+# print(find_common_subsequences(sentence, str_list))
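A small sketch (values made up) of the shape find_common_subsequences appears to return, since the rest of the commit depends on it: a list of (index, phrase) pairs, iterated as `for _, word in common_grams` in app.py and passed as `non_melting_points` to high_entropy_words in masking_methods.py.

# Hypothetical output shape only; the actual phrases depend on the input sentences.
common_grams = [(1, "kim beom-su"), (2, "kakao"), (3, "k-pop")]
for _, phrase in common_grams:
    print(phrase)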
masking_methods.py CHANGED
@@ -1,3 +1,66 @@
+# from transformers import AutoTokenizer, AutoModelForMaskedLM
+# from transformers import pipeline
+# import random
+# from nltk.corpus import stopwords
+# import math
+
+# # Masking Model
+# def mask_non_stopword(sentence):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+#     non_stop_words = [word for word in words if word.lower() not in stop_words]
+#     if not non_stop_words:
+#         return sentence
+#     word_to_mask = random.choice(non_stop_words)
+#     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+#     return masked_sentence
+
+# def mask_non_stopword_pseudorandom(sentence):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+#     non_stop_words = [word for word in words if word.lower() not in stop_words]
+#     if not non_stop_words:
+#         return sentence
+#     random.seed(10)
+#     word_to_mask = random.choice(non_stop_words)
+#     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
+#     return masked_sentence
+
+# def high_entropy_words(sentence, non_melting_points):
+#     stop_words = set(stopwords.words('english'))
+#     words = sentence.split()
+
+#     non_melting_words = set()
+#     for _, point in non_melting_points:
+#         non_melting_words.update(point.lower().split())
+
+#     candidate_words = [word for word in words if word.lower() not in stop_words and word.lower() not in non_melting_words]
+
+#     if not candidate_words:
+#         return sentence
+
+#     max_entropy = -float('inf')
+#     max_entropy_word = None
+
+#     for word in candidate_words:
+#         masked_sentence = sentence.replace(word, '[MASK]', 1)
+#         predictions = fill_mask(masked_sentence)
+
+#         # Calculate entropy based on top 5 predictions
+#         entropy = -sum(pred['score'] * math.log(pred['score']) for pred in predictions[:5])
+
+#         if entropy > max_entropy:
+#             max_entropy = entropy
+#             max_entropy_word = word
+
+#     return sentence.replace(max_entropy_word, '[MASK]', 1)
+
+
+# # Load tokenizer and model for masked language model
+# tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
+# model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
+# fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
+
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 from transformers import pipeline
 import random
@@ -10,21 +73,27 @@ def mask_non_stopword(sentence):
     words = sentence.split()
     non_stop_words = [word for word in words if word.lower() not in stop_words]
     if not non_stop_words:
-        return sentence
+        return sentence, None, None
     word_to_mask = random.choice(non_stop_words)
     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
-    return masked_sentence
+    predictions = fill_mask(masked_sentence)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits
 
 def mask_non_stopword_pseudorandom(sentence):
     stop_words = set(stopwords.words('english'))
     words = sentence.split()
    non_stop_words = [word for word in words if word.lower() not in stop_words]
     if not non_stop_words:
-        return sentence
+        return sentence, None, None
     random.seed(10)
     word_to_mask = random.choice(non_stop_words)
     masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)
-    return masked_sentence
+    predictions = fill_mask(masked_sentence)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits
 
 def high_entropy_words(sentence, non_melting_points):
     stop_words = set(stopwords.words('english'))
@@ -37,10 +106,11 @@ def high_entropy_words(sentence, non_melting_points):
     candidate_words = [word for word in words if word.lower() not in stop_words and word.lower() not in non_melting_words]
 
     if not candidate_words:
-        return sentence
+        return sentence, None, None
 
     max_entropy = -float('inf')
     max_entropy_word = None
+    max_logits = None
 
     for word in candidate_words:
         masked_sentence = sentence.replace(word, '[MASK]', 1)
@@ -52,17 +122,19 @@ def high_entropy_words(sentence, non_melting_points):
         if entropy > max_entropy:
             max_entropy = entropy
             max_entropy_word = word
+            max_logits = [pred['score'] for pred in predictions]
 
-    return sentence.replace(max_entropy_word, '[MASK]', 1)
-
+    masked_sentence = sentence.replace(max_entropy_word, '[MASK]', 1)
+    words = [pred['score'] for pred in predictions]
+    logits = [pred['token_str'] for pred in predictions]
+    return masked_sentence, words, logits
 
 # Load tokenizer and model for masked language model
 tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
 model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
 fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)
 
-def mask(sentence):
-    predictions = fill_mask(sentence)
-    masked_sentences = [predictions[i]['sequence'] for i in range(len(predictions))]
-    return masked_sentences
-
+non_melting_points = [(1, 'Jewish'), (2, 'messages'), (3, 'stab')]
+a, b, c = high_entropy_words("A former Cornell University student was sentenced to 21 months in prison on Monday after admitting that he had posted a series of online messages last fall in which he threatened to stab, rape and behead Jewish people", non_melting_points)
+print(f"logits type: {type(b)}")
+print(f"logits content: {b}")
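A short, self-contained sketch (not from the commit) of the new return contract: each masking function now returns a 3-tuple instead of a bare string. As committed, the second element holds the fill-mask scores and the third the predicted token strings, and app.py unpacks the tuple as (masked_sent, logits, words).

# dummy_predictions stands in for the fill-mask pipeline output, a list of dicts
# with 'score' and 'token_str' keys (only these two keys are used by this commit).
dummy_predictions = [
    {"score": 0.45, "token_str": "messages"},
    {"score": 0.30, "token_str": "threats"},
    {"score": 0.25, "token_str": "posts"},
]

masked_sentence = "he had posted a series of online [MASK] last fall"
scores = [p["score"] for p in dummy_predictions]      # second element of the returned tuple
tokens = [p["token_str"] for p in dummy_predictions]  # third element of the returned tuple

# Callers in app.py unpack the tuple as (masked_sent, logits, words), so the
# scores land in `logits` and the token strings in `words`.
masked_sent, logits, words = masked_sentence, scores, tokens
print(masked_sent, logits, words)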
 
paraphraser.py CHANGED
@@ -28,4 +28,4 @@ def generate_paraphrase(question):
     res = paraphrase(question, para_tokenizer, para_model)
     return res
 
-print(generate_paraphrase("Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."))
+# print(generate_paraphrase("Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."))
sampling_methods.py CHANGED
@@ -1,145 +1,33 @@
-import re
-from nltk.corpus import stopwords
-import random
-from termcolor import colored
-
-# Function to Watermark a Word Take Randomly Between Each lcs Point (Random Sampling)
-def random_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            word_to_mark = random.choice(words_to_replace)
-            sentence = sentence.replace(word_to_mark, colored(word_to_mark, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Inverse Transform Sampling
-def inverse_transform_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            probabilities = [1 / len(words_to_replace)] * len(words_to_replace)
-            chosen_word = random.choices(words_to_replace, weights=probabilities)[0]
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'magenta'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Contextual Sampling
-def contextual_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            context = " ".join([word for word in sentence.split() if word not in common_words])
-            chosen_word = random.choice(words_to_replace)
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-# Function for Exponential Minimum Sampling
-def exponential_minimum_sampling(original_sentence, paraphrased_sentences):
-    stop_words = set(stopwords.words('english'))
-    original_sentence_lower = original_sentence.lower()
-    paraphrased_sentences_lower = [s.lower() for s in paraphrased_sentences]
-    paraphrased_sentences_no_stopwords = []
-
-    for sentence in paraphrased_sentences_lower:
-        words = re.findall(r'\b\w+\b', sentence)
-        filtered_sentence = ' '.join([word for word in words if word not in stop_words])
-        paraphrased_sentences_no_stopwords.append(filtered_sentence)
-
-    results = []
-    for idx, sentence in enumerate(paraphrased_sentences_no_stopwords):
-        common_words = set(original_sentence_lower.split()) & set(sentence.split())
-        common_substrings = ', '.join(sorted(common_words))
-
-        words_to_replace = [word for word in sentence.split() if word not in common_words]
-        if words_to_replace:
-            num_words = len(words_to_replace)
-            probabilities = [2 ** (-i) for i in range(num_words)]
-            chosen_word = random.choices(words_to_replace, weights=probabilities)[0]
-            sentence = sentence.replace(chosen_word, colored(chosen_word, 'red'))
-
-        for word in common_words:
-            sentence = sentence.replace(word, colored(word, 'green'))
-
-        results.append({
-            f"Paraphrased Sentence {idx+1}": sentence,
-            "Common Substrings": common_substrings
-        })
-    return results
-
-
-
-#---------------------------------------------------------------------------
-# aryans implementation please refactor it as you see fit
+# import torch
+# import random
+
+# def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
+#     if sampling_technique == 'inverse_transform':
+#         probs = torch.softmax(torch.tensor(logits), dim=-1)
+#         cumulative_probs = torch.cumsum(probs, dim=-1)
+#         random_prob = random.random()
+#         sampled_index = torch.where(cumulative_probs >= random_prob)[0][0]
+#     elif sampling_technique == 'exponential_minimum':
+#         probs = torch.softmax(torch.tensor(logits), dim=-1)
+#         exp_probs = torch.exp(-torch.log(probs))
+#         random_probs = torch.rand_like(exp_probs)
+#         sampled_index = torch.argmax(random_probs * exp_probs)
+#     elif sampling_technique == 'temperature':
+#         scaled_logits = torch.tensor(logits) / temperature
+#         probs = torch.softmax(scaled_logits, dim=-1)
+#         sampled_index = torch.multinomial(probs, 1).item()
+#     elif sampling_technique == 'greedy':
+#         sampled_index = torch.argmax(torch.tensor(logits)).item()
+#     else:
+#         raise ValueError("Invalid sampling technique. Choose 'inverse_transform', 'exponential_minimum', 'temperature', or 'greedy'.")
+
+#     sampled_word = words[sampled_index]
+#     return sampled_word
 
 import torch
 import random
 
-def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
+def sample_word(sentence, words, logits, sampling_technique='inverse_transform', temperature=1.0):
     if sampling_technique == 'inverse_transform':
         probs = torch.softmax(torch.tensor(logits), dim=-1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
@@ -160,4 +48,8 @@ def sample_word(words, logits, sampling_technique='inverse_transform', temperature=1.0):
         raise ValueError("Invalid sampling technique. Choose 'inverse_transform', 'exponential_minimum', 'temperature', or 'greedy'.")
 
     sampled_word = words[sampled_index]
-    return sampled_word
+
+    # Replace [MASK] with the sampled word
+    filled_sentence = sentence.replace('[MASK]', sampled_word)
+
+    return filled_sentence
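A quick usage sketch of the new signature (toy inputs, not from the commit; assumes this commit's sampling_methods.py is importable): sample_word now takes the masked sentence first and returns it with [MASK] filled in, rather than returning the bare sampled word.

from sampling_methods import sample_word

masked = "The cat [MASK] on the mat"
candidate_words = ["sat", "slept", "jumped"]   # e.g. fill-mask token strings
candidate_scores = [0.7, 0.2, 0.1]             # e.g. fill-mask scores, used as logits

# Greedy sampling picks the highest-scoring candidate, so this prints
# "The cat sat on the mat".
print(sample_word(masked, candidate_words, candidate_scores, sampling_technique='greedy'))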
tree.py CHANGED
@@ -1,29 +1,31 @@
-import plotly.graph_objs as go
+import plotly.graph_objects as go
 import textwrap
 import re
 from collections import defaultdict
-from paraphraser import generate_paraphrase
-from masking_methods import mask, mask_non_stopword
-
-def generate_plot(original_sentence, selected_sentences):
-    first_paraphrased_sentence = selected_sentences[0]
-    masked_sentence = mask_non_stopword(first_paraphrased_sentence)
-    masked_versions = mask(masked_sentence)
-
-    nodes = []
-    nodes.append(original_sentence)
-    nodes.extend(selected_sentences)
-    nodes.extend(masked_versions)
-    nodes[0] += ' L0'
-    para_len = len(selected_sentences)
-    for i in range(1, para_len+1):
-        nodes[i] += ' L1'
-    for i in range(para_len+1, len(nodes)):
-        nodes[i] += ' L2'
-
+
+def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, highlight_info):
+    # Combine nodes into one list with appropriate labels
+    nodes = [paraphrased_sentence] + scheme_sentences + sampled_sentence
+    nodes[0] += ' L0'  # Paraphrased sentence is level 0
+    para_len = len(scheme_sentences)
+    for i in range(1, para_len + 1):
+        nodes[i] += ' L1'  # Scheme sentences are level 1
+    for i in range(para_len + 1, len(nodes)):
+        nodes[i] += ' L2'  # Sampled sentences are level 2
+
+    # Define the highlight_words function
+    def highlight_words(sentence, color_map):
+        for word, color in color_map.items():
+            sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
+        return sentence
+
+    # Clean and wrap nodes, and highlight specified words globally
     cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
-    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in cleaned_nodes]
-
+    global_color_map = dict(highlight_info)
+    highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
+    wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in highlighted_nodes]
+
+    # Function to determine tree levels and create edges dynamically
     def get_levels_and_edges(nodes):
         levels = {}
         edges = []
@@ -37,58 +39,99 @@ def generate_plot(original_sentence, selected_sentences):
             if level == 1:
                 edges.append((root_node, i))
 
-        # Identify the first L1 node
-        first_l1_node = next(i for i, level in levels.items() if level == 1)
-        # Add edges from the first L1 node to all L2 nodes
-        for i, level in levels.items():
-            if level == 2:
-                edges.append((first_l1_node, i))
+        # Add edges from each L1 node to their corresponding L2 nodes
+        l1_indices = [i for i, level in levels.items() if level == 1]
+        l2_indices = [i for i, level in levels.items() if level == 2]
+
+        for i, l1_node in enumerate(l1_indices):
+            l2_start = i * 4
+            for j in range(4):
+                l2_index = l2_start + j
+                if l2_index < len(l2_indices):
+                    edges.append((l1_node, l2_indices[l2_index]))
+
+        # Add edges from each L2 node to their corresponding L3 nodes
+        l2_indices = [i for i, level in levels.items() if level == 2]
+        l3_indices = [i for i, level in levels.items() if level == 3]
+
+        l2_to_l3_map = {l2_node: [] for l2_node in l2_indices}
+
+        # Map L3 nodes to L2 nodes
+        for l3_node in l3_indices:
+            l2_node = l3_node % len(l2_indices)
+            l2_to_l3_map[l2_indices[l2_node]].append(l3_node)
+
+        for l2_node, l3_nodes in l2_to_l3_map.items():
+            for l3_node in l3_nodes:
+                edges.append((l2_node, l3_node))
 
         return levels, edges
 
     # Get levels and dynamic edges
     levels, edges = get_levels_and_edges(nodes)
-    max_level = max(levels.values())
+    max_level = max(levels.values(), default=0)
 
     # Calculate positions
     positions = {}
-    level_widths = defaultdict(int)
+    level_heights = defaultdict(int)
     for node, level in levels.items():
-        level_widths[level] += 1
+        level_heights[level] += 1
 
-    x_offsets = {level: - (width - 1) / 2 for level, width in level_widths.items()}
-    y_gap = 4
+    y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
+    x_gap = 2
+    l1_y_gap = 10
+    l2_y_gap = 6
 
     for node, level in levels.items():
-        positions[node] = (x_offsets[level], -level * y_gap)
-        x_offsets[level] += 1
+        if level == 1:
+            positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
+        elif level == 2:
+            positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
+        else:
+            positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
+        y_offsets[level] += 1
+
+    # Function to highlight words in a wrapped node string
+    def color_highlighted_words(node, color_map):
+        parts = re.split(r'(\{\{.*?\}\})', node)
+        colored_parts = []
+        for part in parts:
+            match = re.match(r'\{\{(.*?)\}\}', part)
+            if match:
+                word = match.group(1)
+                color = color_map.get(word, 'black')
+                colored_parts.append(f"<span style='color: {color};'>{word}</span>")
+            else:
+                colored_parts.append(part)
+        return ''.join(colored_parts)
 
     # Create figure
     fig = go.Figure()
 
     # Add nodes to the figure
     for i, node in enumerate(wrapped_nodes):
+        colored_node = color_highlighted_words(node, global_color_map)
        x, y = positions[i]
        fig.add_trace(go.Scatter(
-            x=[x],
+            x=[-x],  # Reflect the x coordinate
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none'
        ))
        fig.add_annotation(
-            x=x,
+            x=-x,  # Reflect the x coordinate
            y=y,
-            text=node,
+            text=colored_node,
            showarrow=False,
-            yshift=20,  # Adjust the y-shift value to avoid overlap
+            xshift=15,
            align="center",
-            font=dict(size=10),
+            font=dict(size=8),
            bordercolor='black',
            borderwidth=1,
-            borderpad=4,
+            borderpad=2,
            bgcolor='white',
-            width=200
+            width=150
        )
 
     # Add edges to the figure
@@ -96,19 +139,19 @@ def generate_plot(original_sentence, selected_sentences):
         x0, y0 = positions[edge[0]]
         x1, y1 = positions[edge[1]]
         fig.add_trace(go.Scatter(
-            x=[x0, x1],
+            x=[-x0, -x1],  # Reflect the x coordinates
             y=[y0, y1],
             mode='lines',
-            line=dict(color='black', width=2)
+            line=dict(color='black', width=1)
         ))
 
     fig.update_layout(
         showlegend=False,
-        margin=dict(t=50, b=50, l=50, r=50),
+        margin=dict(t=20, b=20, l=20, r=20),
         xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
         yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
-        width=1470,
-        height=800  # Increase height to provide more space
+        width=1200,  # Adjusted width to accommodate more levels
+        height=1000  # Adjusted height to accommodate more levels
     )
 
     return fig
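A hedged usage sketch for the new generate_subplot (toy strings, not from the commit; assumes this commit's tree.py imports cleanly): it takes one paraphrased sentence, the masked variants produced by the masking schemes, the sampled fills, and (word, color) pairs for highlighting, and returns a Plotly figure.

from tree import generate_subplot

paraphrased = "The cat sat on the mat"
scheme_sentences = [                  # one per masking scheme, as built in app.py
    "The cat [MASK] on the mat",
    "The [MASK] sat on the mat",
    "The cat sat on the [MASK]",
]
sampled_sentences = [                 # sampled fills for the masked variants
    "The cat sat on the mat", "The cat slept on the mat",
    "The dog sat on the mat", "The cat sat on the rug",
]
highlight_info = [("cat", "red"), ("mat", "blue")]

fig = generate_subplot(paraphrased, scheme_sentences, sampled_sentences, highlight_info)
fig.show()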