jgyasu committed on
Commit
63b3783
1 Parent(s): 2493822

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. app.py +38 -56
  2. entailment.py +1 -1
  3. lcs.py +2 -2
  4. paraphraser.py +1 -1
  5. tree.py +430 -87
app.py CHANGED
@@ -3,31 +3,11 @@ nltk.download('stopwords')
3
  from transformers import AutoTokenizer
4
  from transformers import AutoModelForSeq2SeqLM
5
  import plotly.graph_objs as go
6
- import textwrap
7
  from transformers import pipeline
8
- import re
9
- import requests
10
- from PIL import Image
11
- import itertools
12
- import numpy as np
13
- import matplotlib.pyplot as plt
14
- import matplotlib
15
  from matplotlib.colors import ListedColormap, rgb2hex
16
- import ipywidgets as widgets
17
- from IPython.display import display, HTML
18
- import pandas as pd
19
- from pprint import pprint
20
- from tenacity import retry
21
- from tqdm import tqdm
22
- from transformers import GPT2LMHeadModel
23
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM
24
  import random
25
- from nltk.corpus import stopwords
26
- from termcolor import colored
27
- from nltk.translate.bleu_score import sentence_bleu
28
- from transformers import BertTokenizer, BertModel
29
  import gradio as gr
30
- from tree import generate_subplot
31
  from paraphraser import generate_paraphrase
32
  from lcs import find_common_subsequences
33
  from highlighter import highlight_common_words, highlight_common_words_dict
@@ -47,22 +27,18 @@ def model(prompt):
47
  masked_sentences = []
48
  masked_words = []
49
  masked_logits = []
50
- selected_sentences_list = list(selected_sentences.keys())
51
 
52
- for sentence in selected_sentences_list:
53
- # Mask non-stopword
54
  masked_sent, logits, words = mask_non_stopword(sentence)
55
  masked_sentences.append(masked_sent)
56
  masked_words.append(words)
57
  masked_logits.append(logits)
58
 
59
- # Mask non-stopword pseudorandom
60
  masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
61
  masked_sentences.append(masked_sent)
62
  masked_words.append(words)
63
  masked_logits.append(logits)
64
 
65
- # High entropy words
66
  masked_sent, logits, words = high_entropy_words(sentence, common_grams)
67
  masked_sentences.append(masked_sent)
68
  masked_words.append(words)
@@ -75,45 +51,39 @@ def model(prompt):
75
  sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
76
  sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))
77
 
78
- # Predefined set of colors that are visible on a white background, excluding black
 
79
  colors = ["red", "blue", "brown", "green"]
80
 
81
- # Function to generate color from predefined set
82
  def select_color():
83
  return random.choice(colors)
84
 
85
- # Create highlight_info with selected colors
86
  highlight_info = [(word, select_color()) for _, word in common_grams]
87
 
88
-
89
- highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "User Prompt (Highlighted and Numbered)")
90
  highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
91
  highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
92
 
 
 
93
 
94
- # Initialize empty list to hold the trees
95
- trees = []
96
-
97
- # Initialize the indices for masked and sampled sentences
98
  masked_index = 0
99
  sampled_index = 0
100
 
101
- for i, sentence in enumerate(selected_sentences):
102
- # Generate the sublists of masked and sampled sentences based on current indices
103
  next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
104
  next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]
105
-
106
- # Create the tree for the current sentence
107
- tree = generate_subplot(sentence, next_masked_sentences, next_sampled_sentences, highlight_info)
108
- trees.append(tree)
109
-
110
- # Update the indices for the next iteration
111
- masked_index += 3
112
- sampled_index += 12
113
 
 
 
 
 
 
114
 
115
- # Return all the outputs together
116
- return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees
 
 
117
 
118
 
119
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
@@ -136,17 +106,29 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
136
  with gr.TabItem("Discarded Sentences"):
137
  highlighted_discarded_sentences = gr.HTML()
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  with gr.Row():
140
  with gr.Tabs():
141
- tree_tabs = []
142
- for i in range(3): # Adjust this range according to the number of trees
143
- with gr.TabItem(f"Tree {i+1}"):
144
- tree = gr.Plot()
145
- tree_tabs.append(tree)
146
 
147
- submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree_tabs)
148
  clear_button.click(lambda: "", inputs=None, outputs=user_input)
149
- clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree_tabs)
150
 
151
- # Launch the demo
152
- demo.launch(share=True)
 
3
  from transformers import AutoTokenizer
4
  from transformers import AutoModelForSeq2SeqLM
5
  import plotly.graph_objs as go
 
6
  from transformers import pipeline
 
 
 
 
 
 
 
7
  from matplotlib.colors import ListedColormap, rgb2hex
 
 
 
 
 
 
 
 
8
  import random
 
 
 
 
9
  import gradio as gr
10
+ from tree import generate_subplot1, generate_subplot2
11
  from paraphraser import generate_paraphrase
12
  from lcs import find_common_subsequences
13
  from highlighter import highlight_common_words, highlight_common_words_dict
 
27
  masked_sentences = []
28
  masked_words = []
29
  masked_logits = []
 
30
 
31
+ for sentence in paraphrased_sentences:
 
32
  masked_sent, logits, words = mask_non_stopword(sentence)
33
  masked_sentences.append(masked_sent)
34
  masked_words.append(words)
35
  masked_logits.append(logits)
36
 
 
37
  masked_sent, logits, words = mask_non_stopword_pseudorandom(sentence)
38
  masked_sentences.append(masked_sent)
39
  masked_words.append(words)
40
  masked_logits.append(logits)
41
 
 
42
  masked_sent, logits, words = high_entropy_words(sentence, common_grams)
43
  masked_sentences.append(masked_sent)
44
  masked_words.append(words)
 
51
  sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='temperature', temperature=1.0))
52
  sampled_sentences.append(sample_word(masked_sent, words, logits, sampling_technique='greedy', temperature=1.0))
53
 
54
+ print(len(sampled_sentences))
55
+
56
  colors = ["red", "blue", "brown", "green"]
57
 
 
58
  def select_color():
59
  return random.choice(colors)
60
 
 
61
  highlight_info = [(word, select_color()) for _, word in common_grams]
62
 
63
+ highlighted_user_prompt = highlight_common_words(common_grams, [user_prompt], "Non-melting Points in the User Prompt")
 
64
  highlighted_accepted_sentences = highlight_common_words_dict(common_grams, selected_sentences, "Paraphrased Sentences")
65
  highlighted_discarded_sentences = highlight_common_words_dict(common_grams, discarded_sentences, "Discarded Sentences")
66
 
67
+ trees1 = []
68
+ trees2 = []
69
 
 
 
 
 
70
  masked_index = 0
71
  sampled_index = 0
72
 
73
+ for i, sentence in enumerate(paraphrased_sentences):
 
74
  next_masked_sentences = masked_sentences[masked_index:masked_index + 3]
75
  next_sampled_sentences = sampled_sentences[sampled_index:sampled_index + 12]
 
 
 
 
 
 
 
 
76
 
77
+ tree1 = generate_subplot1(sentence, next_masked_sentences, highlight_info, common_grams)
78
+ trees1.append(tree1)
79
+
80
+ tree2 = generate_subplot2(next_masked_sentences, next_sampled_sentences, highlight_info, common_grams)
81
+ trees2.append(tree2)
82
 
83
+ masked_index += 3
84
+ sampled_index += 12
85
+
86
+ return [highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + trees1 + trees2
87
 
88
 
89
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
 
106
  with gr.TabItem("Discarded Sentences"):
107
  highlighted_discarded_sentences = gr.HTML()
108
 
109
+ # Adding labels before the tree plots
110
+ with gr.Row():
111
+ gr.Markdown("### Where to Mask?") # Label for masked sentences trees
112
+ with gr.Row():
113
+ with gr.Tabs():
114
+ tree1_tabs = []
115
+ for i in range(10): # Adjust this range according to the number of trees
116
+ with gr.TabItem(f"Sentence {i+1}"):
117
+ tree1 = gr.Plot()
118
+ tree1_tabs.append(tree1)
119
+
120
+ with gr.Row():
121
+ gr.Markdown("### How to Mask?") # Label for sampled sentences trees
122
  with gr.Row():
123
  with gr.Tabs():
124
+ tree2_tabs = []
125
+ for i in range(10): # Adjust this range according to the number of trees
126
+ with gr.TabItem(f"Sentence {i+1}"):
127
+ tree2 = gr.Plot()
128
+ tree2_tabs.append(tree2)
129
 
130
+ submit_button.click(model, inputs=user_input, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs)
131
  clear_button.click(lambda: "", inputs=None, outputs=user_input)
132
+ clear_button.click(lambda: "", inputs=None, outputs=[highlighted_user_prompt, highlighted_accepted_sentences, highlighted_discarded_sentences] + tree1_tabs + tree2_tabs)
133
 
134
+ demo.launch(share=True)
 
entailment.py CHANGED
@@ -28,4 +28,4 @@ def analyze_entailment(original_sentence, paraphrased_sentences, threshold):
28
 
29
  return all_sentences, selected_sentences, discarded_sentences
30
 
31
- # print(analyze_entailment("I love you", ["You're being loved by me"], 0.7))
 
28
 
29
  return all_sentences, selected_sentences, discarded_sentences
30
 
31
+ # print(analyze_entailment("I love you", [""], 0.7))
lcs.py CHANGED
@@ -40,7 +40,7 @@ def find_common_subsequences(sentence, str_list):
40
  return indexed_common_grams
41
 
42
  # Example usage
43
- # sentence = "Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."
44
- # str_list = ["The founder of South Korean technology company Kakao, billionaire Kim Beom-su, was arrested on charges of stock fraud during a bidding war for one of North Korea's biggest K-pop companies.", "In a bidding war for one of South Korea's largest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "During a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-su, the billionaire who owns Kakao, was arrested on charges of manipulating stocks.", "Kim Beom-su, the founder of South Korean technology giant Kakao's billionaire investor status, was arrested on charges of stock fraud during a bidding war for one of North Korea'S top K-pop agencies.", "A bidding war over one of South Korea's biggest K-pop agencies led to the arrest and apprehension charges of Kim Beom-Su, the billionaire who owns the technology giant Kakao.", "The billionaire who owns South Korean technology giant Kakao, Kim Beom-Su, was taken into custody for allegedly engaging in stock trading during a bidding war for one of North Korea's biggest K-pop media groups.", "Accused of stockpiling during a bidding war for one of South Korea's biggest K-pop agencies, Kim Beom-Su, the founder and owner of technology firm known as Kakao, was arrested on charges of manipulating stocks.", 'Kakao, the South Korean technology giant, was involved in a bidding war with Kim Beon-su, its founder, who was arrested on charges of manipulating stocks.', "South Korea's Kakao corporation'entrepreneur husband, Kim Beom-su (pictured), was arrested on suspicion of stock fraud during a bidding war for one of the country'S top K-pop companies.", 'Kim Beom-su, the billionaire who own a South Korean technology company called Kakaof, was arrested on charges of manipulating stocks in an ongoing bidding war over one million shares.']
45
 
46
  # print(find_common_subsequences(sentence, str_list))
 
40
  return indexed_common_grams
41
 
42
  # Example usage
43
+ # sentence = "Donald Trump said at a campaign rally event in Wilkes-Barre, Pennsylvania, that there has “never been a more dangerous time since the Holocaust” to be Jewish in the United States."
44
+ # str_list = ['']
45
 
46
  # print(find_common_subsequences(sentence, str_list))
paraphraser.py CHANGED
@@ -28,4 +28,4 @@ def generate_paraphrase(question):
28
  res = paraphrase(question, para_tokenizer, para_model)
29
  return res
30
 
31
- # print(generate_paraphrase("Kim Beom-su, the billionaire behind the South Korean technology giant Kakao, was taken into custody on allegations of stock manipulation during a bidding war over one of the country’s largest K-pop agencies."))
 
28
  res = paraphrase(question, para_tokenizer, para_model)
29
  return res
30
 
31
+ # print(generate_paraphrase("Donald Trump said at a campaign rally event in Wilkes-Barre, Pennsylvania, that there has “never been a more dangerous time since the Holocaust” to be Jewish in the United States."))
tree.py CHANGED
@@ -3,15 +3,12 @@
3
  # import re
4
  # from collections import defaultdict
5
 
6
- # def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, highlight_info):
7
  # # Combine nodes into one list with appropriate labels
8
- # nodes = [paraphrased_sentence] + scheme_sentences + sampled_sentence
9
  # nodes[0] += ' L0' # Paraphrased sentence is level 0
10
- # para_len = len(scheme_sentences)
11
- # for i in range(1, para_len + 1):
12
  # nodes[i] += ' L1' # Scheme sentences are level 1
13
- # for i in range(para_len + 1, len(nodes)):
14
- # nodes[i] += ' L2' # Sampled sentences are level 2
15
 
16
  # # Define the highlight_words function
17
  # def highlight_words(sentence, color_map):
@@ -23,7 +20,7 @@
23
  # cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
24
  # global_color_map = dict(highlight_info)
25
  # highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
26
- # wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in highlighted_nodes]
27
 
28
  # # Function to determine tree levels and create edges dynamically
29
  # def get_levels_and_edges(nodes):
@@ -39,37 +36,185 @@
39
  # if level == 1:
40
  # edges.append((root_node, i))
41
 
42
- # # Add edges from each L1 node to their corresponding L2 nodes
43
- # l1_indices = [i for i, level in levels.items() if level == 1]
44
- # l2_indices = [i for i, level in levels.items() if level == 2]
45
 
46
- # for i, l1_node in enumerate(l1_indices):
47
- # l2_start = i * 4
48
- # for j in range(4):
49
- # l2_index = l2_start + j
50
- # if l2_index < len(l2_indices):
51
- # edges.append((l1_node, l2_indices[l2_index]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # # Add edges from each L2 node to their corresponding L3 nodes
54
- # l2_indices = [i for i, level in levels.items() if level == 2]
55
- # l3_indices = [i for i, level in levels.items() if level == 3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # l2_to_l3_map = {l2_node: [] for l2_node in l2_indices}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # # Map L3 nodes to L2 nodes
60
- # for l3_node in l3_indices:
61
- # l2_node = l3_node % len(l2_indices)
62
- # l2_to_l3_map[l2_indices[l2_node]].append(l3_node)
 
 
 
 
 
 
63
 
64
- # for l2_node, l3_nodes in l2_to_l3_map.items():
65
- # for l3_node in l3_nodes:
66
- # edges.append((l2_node, l3_node))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # return levels, edges
69
 
70
  # # Get levels and dynamic edges
71
  # levels, edges = get_levels_and_edges(nodes)
72
- # max_level = max(levels.values(), default=0)
73
 
74
  # # Calculate positions
75
  # positions = {}
@@ -80,15 +225,12 @@
80
  # y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
81
  # x_gap = 2
82
  # l1_y_gap = 10
83
- # l2_y_gap = 6
84
 
85
  # for node, level in levels.items():
86
  # if level == 1:
87
  # positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
88
- # elif level == 2:
89
- # positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
90
  # else:
91
- # positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
92
  # y_offsets[level] += 1
93
 
94
  # # Function to highlight words in a wrapped node string
@@ -105,71 +247,116 @@
105
  # colored_parts.append(part)
106
  # return ''.join(colored_parts)
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # # Create figure
109
- # fig = go.Figure()
110
 
111
  # # Add nodes to the figure
112
  # for i, node in enumerate(wrapped_nodes):
113
  # colored_node = color_highlighted_words(node, global_color_map)
114
  # x, y = positions[i]
115
- # fig.add_trace(go.Scatter(
116
  # x=[-x], # Reflect the x coordinate
117
  # y=[y],
118
  # mode='markers',
119
  # marker=dict(size=10, color='blue'),
120
  # hoverinfo='none'
121
  # ))
122
- # fig.add_annotation(
123
  # x=-x, # Reflect the x coordinate
124
  # y=y,
125
  # text=colored_node,
126
  # showarrow=False,
127
  # xshift=15,
128
  # align="center",
129
- # font=dict(size=8),
130
  # bordercolor='black',
131
  # borderwidth=1,
132
  # borderpad=2,
133
  # bgcolor='white',
134
- # width=150
 
135
  # )
136
 
137
- # # Add edges to the figure
138
- # for edge in edges:
139
  # x0, y0 = positions[edge[0]]
140
  # x1, y1 = positions[edge[1]]
141
- # fig.add_trace(go.Scatter(
142
  # x=[-x0, -x1], # Reflect the x coordinates
143
  # y=[y0, y1],
144
  # mode='lines',
145
  # line=dict(color='black', width=1)
146
  # ))
147
 
148
- # fig.update_layout(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  # showlegend=False,
150
  # margin=dict(t=20, b=20, l=20, r=20),
151
  # xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
152
  # yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
153
- # width=1200, # Adjusted width to accommodate more levels
154
  # height=1000 # Adjusted height to accommodate more levels
155
  # )
156
 
157
- # return fig
 
158
 
159
  import plotly.graph_objects as go
160
  import textwrap
161
  import re
162
  from collections import defaultdict
163
 
164
- def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, highlight_info):
165
  # Combine nodes into one list with appropriate labels
166
- nodes = [paraphrased_sentence] + scheme_sentences + sampled_sentence
167
  nodes[0] += ' L0' # Paraphrased sentence is level 0
168
- para_len = len(scheme_sentences)
169
- for i in range(1, para_len + 1):
170
  nodes[i] += ' L1' # Scheme sentences are level 1
171
- for i in range(para_len + 1, len(nodes)):
172
- nodes[i] += ' L2' # Sampled sentences are level 2
 
 
 
 
 
 
 
 
173
 
174
  # Define the highlight_words function
175
  def highlight_words(sentence, color_map):
@@ -181,7 +368,7 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
181
  cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
182
  global_color_map = dict(highlight_info)
183
  highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
184
- wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=30)) for node in highlighted_nodes]
185
 
186
  # Function to determine tree levels and create edges dynamically
187
  def get_levels_and_edges(nodes):
@@ -197,31 +384,188 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
197
  if level == 1:
198
  edges.append((root_node, i))
199
 
200
- # Add edges from each L1 node to their corresponding L2 nodes
201
- l1_indices = [i for i, level in levels.items() if level == 1]
202
- l2_indices = [i for i, level in levels.items() if level == 2]
203
 
204
- for i, l1_node in enumerate(l1_indices):
205
- l2_start = i * 4
206
- for j in range(4):
207
- l2_index = l2_start + j
208
- if l2_index < len(l2_indices):
209
- edges.append((l1_node, l2_indices[l2_index]))
210
 
211
- # Add edges from each L2 node to their corresponding L3 nodes
212
- l2_indices = [i for i, level in levels.items() if level == 2]
213
- l3_indices = [i for i, level in levels.items() if level == 3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- l2_to_l3_map = {l2_node: [] for l2_node in l2_indices}
 
 
 
 
 
 
 
 
216
 
217
- # Map L3 nodes to L2 nodes
218
- for l3_node in l3_indices:
219
- l2_node = l3_node % len(l2_indices)
220
- l2_to_l3_map[l2_indices[l2_node]].append(l3_node)
 
 
 
 
221
 
222
- for l2_node, l3_nodes in l2_to_l3_map.items():
223
- for l3_node in l3_nodes:
224
- edges.append((l2_node, l3_node))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  return levels, edges
227
 
@@ -238,15 +582,12 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
238
  y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
239
  x_gap = 2
240
  l1_y_gap = 10
241
- l2_y_gap = 6
242
 
243
  for node, level in levels.items():
244
  if level == 1:
245
  positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
246
- elif level == 2:
247
- positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
248
  else:
249
- positions[node] = (-level * x_gap, y_offsets[level] * l2_y_gap)
250
  y_offsets[level] += 1
251
 
252
  # Function to highlight words in a wrapped node string
@@ -283,39 +624,40 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
283
  ]
284
 
285
  # Create figure
286
- fig = go.Figure()
287
 
288
  # Add nodes to the figure
289
  for i, node in enumerate(wrapped_nodes):
290
  colored_node = color_highlighted_words(node, global_color_map)
291
  x, y = positions[i]
292
- fig.add_trace(go.Scatter(
293
  x=[-x], # Reflect the x coordinate
294
  y=[y],
295
  mode='markers',
296
  marker=dict(size=10, color='blue'),
297
  hoverinfo='none'
298
  ))
299
- fig.add_annotation(
300
  x=-x, # Reflect the x coordinate
301
  y=y,
302
  text=colored_node,
303
  showarrow=False,
304
  xshift=15,
305
  align="center",
306
- font=dict(size=8),
307
  bordercolor='black',
308
  borderwidth=1,
309
  borderpad=2,
310
  bgcolor='white',
311
- width=150
 
312
  )
313
 
314
  # Add edges and text above each edge
315
  for i, edge in enumerate(edges):
316
  x0, y0 = positions[edge[0]]
317
  x1, y1 = positions[edge[1]]
318
- fig.add_trace(go.Scatter(
319
  x=[-x0, -x1], # Reflect the x coordinates
320
  y=[y0, y1],
321
  mode='lines',
@@ -330,23 +672,24 @@ def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence, h
330
  text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
331
 
332
  # Add text annotation above the edge
333
- fig.add_annotation(
 
 
334
  x=mid_x,
335
  y=text_y_position,
336
- text=edge_texts[i], # Use the text specific to this edge
337
  showarrow=False,
338
- font=dict(size=10),
339
  align="center"
340
  )
341
 
342
- fig.update_layout(
343
  showlegend=False,
344
  margin=dict(t=20, b=20, l=20, r=20),
345
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
346
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
347
- width=1200, # Adjusted width to accommodate more levels
348
  height=1000 # Adjusted height to accommodate more levels
349
  )
350
 
351
- return fig
352
-
 
3
  # import re
4
  # from collections import defaultdict
5
 
6
+ # def generate_subplot1(paraphrased_sentence, scheme_sentences, highlight_info):
7
  # # Combine nodes into one list with appropriate labels
8
+ # nodes = [paraphrased_sentence] + scheme_sentences
9
  # nodes[0] += ' L0' # Paraphrased sentence is level 0
10
+ # for i in range(1, len(nodes)):
 
11
  # nodes[i] += ' L1' # Scheme sentences are level 1
 
 
12
 
13
  # # Define the highlight_words function
14
  # def highlight_words(sentence, color_map):
 
20
  # cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
21
  # global_color_map = dict(highlight_info)
22
  # highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
23
+ # wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=50)) for node in highlighted_nodes]
24
 
25
  # # Function to determine tree levels and create edges dynamically
26
  # def get_levels_and_edges(nodes):
 
36
  # if level == 1:
37
  # edges.append((root_node, i))
38
 
39
+ # return levels, edges
 
 
40
 
41
+ # # Get levels and dynamic edges
42
+ # levels, edges = get_levels_and_edges(nodes)
43
+ # max_level = max(levels.values(), default=0)
44
+
45
+ # # Calculate positions
46
+ # positions = {}
47
+ # level_heights = defaultdict(int)
48
+ # for node, level in levels.items():
49
+ # level_heights[level] += 1
50
+
51
+ # y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
52
+ # x_gap = 2
53
+ # l1_y_gap = 10
54
+
55
+ # for node, level in levels.items():
56
+ # if level == 1:
57
+ # positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
58
+ # else:
59
+ # positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
60
+ # y_offsets[level] += 1
61
+
62
+ # # Function to highlight words in a wrapped node string
63
+ # def color_highlighted_words(node, color_map):
64
+ # parts = re.split(r'(\{\{.*?\}\})', node)
65
+ # colored_parts = []
66
+ # for part in parts:
67
+ # match = re.match(r'\{\{(.*?)\}\}', part)
68
+ # if match:
69
+ # word = match.group(1)
70
+ # color = color_map.get(word, 'black')
71
+ # colored_parts.append(f"<span style='color: {color};'>{word}</span>")
72
+ # else:
73
+ # colored_parts.append(part)
74
+ # return ''.join(colored_parts)
75
 
76
+ # # Define the text for each edge
77
+ # edge_texts = [
78
+ # "Highest Entropy Masking",
79
+ # "Pseudo-random Masking",
80
+ # "Random Masking",
81
+ # "Greedy Sampling",
82
+ # "Temperature Sampling",
83
+ # "Exponential Minimum Sampling",
84
+ # "Inverse Transform Sampling",
85
+ # "Greedy Sampling",
86
+ # "Temperature Sampling",
87
+ # "Exponential Minimum Sampling",
88
+ # "Inverse Transform Sampling",
89
+ # "Greedy Sampling",
90
+ # "Temperature Sampling",
91
+ # "Exponential Minimum Sampling",
92
+ # "Inverse Transform Sampling"
93
+ # ]
94
 
95
+ # # Create figure
96
+ # fig1 = go.Figure()
97
+
98
+ # # Add nodes to the figure
99
+ # for i, node in enumerate(wrapped_nodes):
100
+ # colored_node = color_highlighted_words(node, global_color_map)
101
+ # x, y = positions[i]
102
+ # fig1.add_trace(go.Scatter(
103
+ # x=[-x], # Reflect the x coordinate
104
+ # y=[y],
105
+ # mode='markers',
106
+ # marker=dict(size=10, color='blue'),
107
+ # hoverinfo='none'
108
+ # ))
109
+ # fig1.add_annotation(
110
+ # x=-x, # Reflect the x coordinate
111
+ # y=y,
112
+ # text=colored_node,
113
+ # showarrow=False,
114
+ # xshift=15,
115
+ # align="center",
116
+ # font=dict(size=12),
117
+ # bordercolor='black',
118
+ # borderwidth=1,
119
+ # borderpad=2,
120
+ # bgcolor='white',
121
+ # width=300,
122
+ # height=120
123
+ # )
124
 
125
+ # # Add edges and text above each edge
126
+ # for i, edge in enumerate(edges):
127
+ # x0, y0 = positions[edge[0]]
128
+ # x1, y1 = positions[edge[1]]
129
+ # fig1.add_trace(go.Scatter(
130
+ # x=[-x0, -x1], # Reflect the x coordinates
131
+ # y=[y0, y1],
132
+ # mode='lines',
133
+ # line=dict(color='black', width=1)
134
+ # ))
135
 
136
+ # # Calculate the midpoint of the edge
137
+ # mid_x = (-x0 + -x1) / 2
138
+ # mid_y = (y0 + y1) / 2
139
+
140
+ # # Adjust y position to shift text upwards
141
+ # text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
142
+
143
+ # # Add text annotation above the edge
144
+ # fig1.add_annotation(
145
+ # x=mid_x,
146
+ # y=text_y_position,
147
+ # text=edge_texts[i], # Use the text specific to this edge
148
+ # showarrow=False,
149
+ # font=dict(size=12),
150
+ # align="center"
151
+ # )
152
+
153
+ # fig1.update_layout(
154
+ # showlegend=False,
155
+ # margin=dict(t=20, b=20, l=20, r=20),
156
+ # xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
157
+ # yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
158
+ # width=1435, # Adjusted width to accommodate more levels
159
+ # height=1000 # Adjusted height to accommodate more levels
160
+ # )
161
+
162
+ # return fig1
163
+
164
+
165
+
166
+ # def generate_subplot2(scheme_sentences, sampled_sentence, highlight_info):
167
+ # # Combine nodes into one list with appropriate labels
168
+ # nodes = scheme_sentences + sampled_sentence
169
+ # para_len = len(scheme_sentences)
170
+
171
+ # # Reassign levels: L1 -> L0, L2 -> L1
172
+ # for i in range(para_len):
173
+ # nodes[i] += ' L0' # Scheme sentences are now level 0
174
+ # for i in range(para_len, len(nodes)):
175
+ # nodes[i] += ' L1' # Sampled sentences are now level 1
176
+
177
+ # # Define the highlight_words function
178
+ # def highlight_words(sentence, color_map):
179
+ # for word, color in color_map.items():
180
+ # sentence = re.sub(f"\\b{word}\\b", f"{{{{{word}}}}}", sentence, flags=re.IGNORECASE)
181
+ # return sentence
182
+
183
+ # # Clean and wrap nodes, and highlight specified words globally
184
+ # cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
185
+ # global_color_map = dict(highlight_info)
186
+ # highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
187
+ # wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=80)) for node in highlighted_nodes]
188
+
189
+ # # Function to determine tree levels and create edges dynamically
190
+ # def get_levels_and_edges(nodes):
191
+ # levels = {}
192
+ # edges = []
193
+ # for i, node in enumerate(nodes):
194
+ # level = int(node.split()[-1][1])
195
+ # levels[i] = level
196
+
197
+ # # Add edges from L0 to all L1 nodes
198
+ # l0_indices = [i for i, level in levels.items() if level == 0]
199
+ # l1_indices = [i for i, level in levels.items() if level == 1]
200
+
201
+ # # Ensure there are exactly 3 L0 nodes
202
+ # if len(l0_indices) < 3:
203
+ # raise ValueError("There should be exactly 3 L0 nodes to attach edges correctly.")
204
+
205
+ # # Split L1 nodes into 3 groups of 4 for attaching to L0 nodes
206
+ # for i, l1_node in enumerate(l1_indices):
207
+ # if i < 4:
208
+ # edges.append((l0_indices[0], l1_node)) # Connect to the first L0 node
209
+ # elif i < 8:
210
+ # edges.append((l0_indices[1], l1_node)) # Connect to the second L0 node
211
+ # else:
212
+ # edges.append((l0_indices[2], l1_node)) # Connect to the third L0 node
213
 
214
  # return levels, edges
215
 
216
  # # Get levels and dynamic edges
217
  # levels, edges = get_levels_and_edges(nodes)
 
218
 
219
  # # Calculate positions
220
  # positions = {}
 
225
  # y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
226
  # x_gap = 2
227
  # l1_y_gap = 10
 
228
 
229
  # for node, level in levels.items():
230
  # if level == 1:
231
  # positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
 
 
232
  # else:
233
+ # positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
234
  # y_offsets[level] += 1
235
 
236
  # # Function to highlight words in a wrapped node string
 
247
  # colored_parts.append(part)
248
  # return ''.join(colored_parts)
249
 
250
+ # # Define the text for each edge
251
+ # edge_texts = [
252
+ # "Highest Entropy Masking",
253
+ # "Pseudo-random Masking",
254
+ # "Random Masking",
255
+ # "Greedy Sampling",
256
+ # "Temperature Sampling",
257
+ # "Exponential Minimum Sampling",
258
+ # "Inverse Transform Sampling",
259
+ # "Greedy Sampling",
260
+ # "Temperature Sampling",
261
+ # "Exponential Minimum Sampling",
262
+ # "Inverse Transform Sampling",
263
+ # "Greedy Sampling",
264
+ # "Temperature Sampling",
265
+ # "Exponential Minimum Sampling",
266
+ # "Inverse Transform Sampling"
267
+ # ]
268
+
269
  # # Create figure
270
+ # fig2 = go.Figure()
271
 
272
  # # Add nodes to the figure
273
  # for i, node in enumerate(wrapped_nodes):
274
  # colored_node = color_highlighted_words(node, global_color_map)
275
  # x, y = positions[i]
276
+ # fig2.add_trace(go.Scatter(
277
  # x=[-x], # Reflect the x coordinate
278
  # y=[y],
279
  # mode='markers',
280
  # marker=dict(size=10, color='blue'),
281
  # hoverinfo='none'
282
  # ))
283
+ # fig2.add_annotation(
284
  # x=-x, # Reflect the x coordinate
285
  # y=y,
286
  # text=colored_node,
287
  # showarrow=False,
288
  # xshift=15,
289
  # align="center",
290
+ # font=dict(size=12),
291
  # bordercolor='black',
292
  # borderwidth=1,
293
  # borderpad=2,
294
  # bgcolor='white',
295
+ # width=450,
296
+ # height=65
297
  # )
298
 
299
+ # # Add edges and text above each edge
300
+ # for i, edge in enumerate(edges):
301
  # x0, y0 = positions[edge[0]]
302
  # x1, y1 = positions[edge[1]]
303
+ # fig2.add_trace(go.Scatter(
304
  # x=[-x0, -x1], # Reflect the x coordinates
305
  # y=[y0, y1],
306
  # mode='lines',
307
  # line=dict(color='black', width=1)
308
  # ))
309
 
310
+ # # Calculate the midpoint of the edge
311
+ # mid_x = (-x0 + -x1) / 2
312
+ # mid_y = (y0 + y1) / 2
313
+
314
+ # # Adjust y position to shift text upwards
315
+ # text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
316
+
317
+ # # Add text annotation above the edge
318
+ # fig2.add_annotation(
319
+ # x=mid_x,
320
+ # y=text_y_position,
321
+ # text=edge_texts[i], # Use the text specific to this edge
322
+ # showarrow=False,
323
+ # font=dict(size=12),
324
+ # align="center"
325
+ # )
326
+
327
+ # fig2.update_layout(
328
  # showlegend=False,
329
  # margin=dict(t=20, b=20, l=20, r=20),
330
  # xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
331
  # yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
332
+ # width=1435, # Adjusted width to accommodate more levels
333
  # height=1000 # Adjusted height to accommodate more levels
334
  # )
335
 
336
+ # return fig2
337
+
338
 
339
  import plotly.graph_objects as go
340
  import textwrap
341
  import re
342
  from collections import defaultdict
343
 
344
+ def generate_subplot1(paraphrased_sentence, scheme_sentences, highlight_info, common_grams):
345
  # Combine nodes into one list with appropriate labels
346
+ nodes = [paraphrased_sentence] + scheme_sentences
347
  nodes[0] += ' L0' # Paraphrased sentence is level 0
348
+ for i in range(1, len(nodes)):
 
349
  nodes[i] += ' L1' # Scheme sentences are level 1
350
+
351
+ # Function to apply LCS numbering based on common_grams
352
def apply_lcs_numbering(sentence, common_grams):
    """Prefix every whole-word occurrence of each common gram with its index.

    Args:
        sentence: Text to annotate.
        common_grams: Iterable of ``(idx, lcs)`` pairs; every occurrence of
            ``lcs`` in ``sentence`` is rewritten as ``"(idx)lcs"``.

    Returns:
        The annotated sentence. Grams are applied in order on the running
        result, so a later gram can also match text produced by an earlier one.
    """
    for idx, lcs in common_grams:
        # An empty gram would match at every word boundary; nothing to number.
        if not lcs:
            continue
        # Escape the gram so punctuation in it is matched literally instead of
        # being interpreted as regex syntax (or raising re.error).
        pattern = rf"\b{re.escape(lcs)}\b"
        numbered = f"({idx}){lcs}"
        # A callable replacement keeps backslashes in the gram from being
        # treated as replacement-template escapes.
        sentence = re.sub(pattern, lambda _match, text=numbered: text, sentence)
    return sentence
357
+
358
+ # Apply LCS numbering
359
+ nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
360
 
361
  # Define the highlight_words function
362
  def highlight_words(sentence, color_map):
 
368
  cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
369
  global_color_map = dict(highlight_info)
370
  highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
371
+ wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=55)) for node in highlighted_nodes]
372
 
373
  # Function to determine tree levels and create edges dynamically
374
  def get_levels_and_edges(nodes):
 
384
  if level == 1:
385
  edges.append((root_node, i))
386
 
387
+ return levels, edges
 
 
388
 
389
+ # Get levels and dynamic edges
390
+ levels, edges = get_levels_and_edges(nodes)
391
+ max_level = max(levels.values(), default=0)
 
 
 
392
 
393
+ # Calculate positions
394
+ positions = {}
395
+ level_heights = defaultdict(int)
396
+ for node, level in levels.items():
397
+ level_heights[level] += 1
398
+
399
+ y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
400
+ x_gap = 2
401
+ l1_y_gap = 10
402
+
403
+ for node, level in levels.items():
404
+ if level == 1:
405
+ positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
406
+ else:
407
+ positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
408
+ y_offsets[level] += 1
409
+
410
+ # Function to highlight words in a wrapped node string
411
def color_highlighted_words(node, color_map):
    """Turn ``{{word}}`` placeholders in ``node`` into coloured HTML spans.

    Args:
        node: Node text in which highlighted words are wrapped as ``{{word}}``.
            The text may already contain ``<br>`` line breaks.
        color_map: Mapping from word (or phrase) to a CSS colour string.

    Returns:
        ``node`` with each placeholder replaced by
        ``<span style='color: COLOR;'>word</span>``; words with no entry in
        ``color_map`` are rendered in black.
    """
    def resolve_color(word):
        # A phrase broken across wrapped lines carries a "<br>" where the
        # break happened; also try the phrase with that break undone so the
        # lookup still finds the colour registered for the unwrapped phrase.
        candidates = (word, word.replace('<br>', ' '), word.replace('<br>', ''))
        for key in candidates:
            if key in color_map:
                return color_map[key]
        return 'black'

    def to_span(match):
        word = match.group(1)
        return f"<span style='color: {resolve_color(word)};'>{word}</span>"

    return re.sub(r'\{\{(.*?)\}\}', to_span, node)
423
+
424
+ # Define the text for each edge
425
+ edge_texts = [
426
+ "Highest Entropy Masking",
427
+ "Pseudo-random Masking",
428
+ "Random Masking",
429
+ "Greedy Sampling",
430
+ "Temperature Sampling",
431
+ "Exponential Minimum Sampling",
432
+ "Inverse Transform Sampling",
433
+ "Greedy Sampling",
434
+ "Temperature Sampling",
435
+ "Exponential Minimum Sampling",
436
+ "Inverse Transform Sampling",
437
+ "Greedy Sampling",
438
+ "Temperature Sampling",
439
+ "Exponential Minimum Sampling",
440
+ "Inverse Transform Sampling"
441
+ ]
442
+
443
+ # Create figure
444
+ fig1 = go.Figure()
445
+
446
+ # Add nodes to the figure
447
+ for i, node in enumerate(wrapped_nodes):
448
+ colored_node = color_highlighted_words(node, global_color_map)
449
+ x, y = positions[i]
450
+ fig1.add_trace(go.Scatter(
451
+ x=[-x], # Reflect the x coordinate
452
+ y=[y],
453
+ mode='markers',
454
+ marker=dict(size=10, color='blue'),
455
+ hoverinfo='none'
456
+ ))
457
+ fig1.add_annotation(
458
+ x=-x, # Reflect the x coordinate
459
+ y=y,
460
+ text=colored_node,
461
+ showarrow=False,
462
+ xshift=15,
463
+ align="center",
464
+ font=dict(size=12),
465
+ bordercolor='black',
466
+ borderwidth=1,
467
+ borderpad=2,
468
+ bgcolor='white',
469
+ width=300,
470
+ height=120
471
+ )
472
+
473
+ # Add edges and text above each edge
474
+ for i, edge in enumerate(edges):
475
+ x0, y0 = positions[edge[0]]
476
+ x1, y1 = positions[edge[1]]
477
+ fig1.add_trace(go.Scatter(
478
+ x=[-x0, -x1], # Reflect the x coordinates
479
+ y=[y0, y1],
480
+ mode='lines',
481
+ line=dict(color='black', width=1)
482
+ ))
483
+
484
+ # Calculate the midpoint of the edge
485
+ mid_x = (-x0 + -x1) / 2
486
+ mid_y = (y0 + y1) / 2
487
+
488
+ # Adjust y position to shift text upwards
489
+ text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
490
 
491
+ # Add text annotation above the edge
492
+ fig1.add_annotation(
493
+ x=mid_x,
494
+ y=text_y_position,
495
+ text=edge_texts[i], # Use the text specific to this edge
496
+ showarrow=False,
497
+ font=dict(size=12),
498
+ align="center"
499
+ )
500
 
501
+ fig1.update_layout(
502
+ showlegend=False,
503
+ margin=dict(t=20, b=20, l=20, r=20),
504
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
505
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
506
+ width=1435, # Adjusted width to accommodate more levels
507
+ height=1000 # Adjusted height to accommodate more levels
508
+ )
509
 
510
+ return fig1
511
+
512
+ def generate_subplot2(scheme_sentences, sampled_sentence, highlight_info, common_grams):
513
+ # Combine nodes into one list with appropriate labels
514
+ nodes = scheme_sentences + sampled_sentence
515
+ para_len = len(scheme_sentences)
516
+
517
+ # Reassign levels: L1 -> L0, L2 -> L1
518
+ for i in range(para_len):
519
+ nodes[i] += ' L0' # Scheme sentences are now level 0
520
+ for i in range(para_len, len(nodes)):
521
+ nodes[i] += ' L1' # Sampled sentences are now level 1
522
+
523
+ # Function to apply LCS numbering based on common_grams
524
def apply_lcs_numbering(sentence, common_grams):
    """Prefix every whole-word occurrence of each common gram with its index.

    Args:
        sentence: Text to annotate.
        common_grams: Iterable of ``(idx, lcs)`` pairs; every occurrence of
            ``lcs`` in ``sentence`` is rewritten as ``"(idx)lcs"``.

    Returns:
        The annotated sentence. Grams are applied in order on the running
        result, so a later gram can also match text produced by an earlier one.
    """
    for idx, lcs in common_grams:
        # An empty gram would match at every word boundary; nothing to number.
        if not lcs:
            continue
        # Escape the gram so punctuation in it is matched literally instead of
        # being interpreted as regex syntax (or raising re.error).
        pattern = rf"\b{re.escape(lcs)}\b"
        numbered = f"({idx}){lcs}"
        # A callable replacement keeps backslashes in the gram from being
        # treated as replacement-template escapes.
        sentence = re.sub(pattern, lambda _match, text=numbered: text, sentence)
    return sentence
529
+
530
+ # Apply LCS numbering
531
+ nodes = [apply_lcs_numbering(node, common_grams) for node in nodes]
532
+
533
+ # Define the highlight_words function
534
def highlight_words(sentence, color_map):
    """Wrap whole-word occurrences of the map's keys in ``{{...}}`` markers.

    Args:
        sentence: Text to mark up.
        color_map: Mapping whose keys are the words to highlight; the values
            are not used here.

    Returns:
        ``sentence`` with each case-insensitive whole-word match of a key
        replaced by ``{{key}}``. The key's own spelling is written, not the
        matched text's (presumably so a later colour lookup by key succeeds;
        verify against the consumer of the markers).
    """
    for word in color_map:
        # An empty key would match at every word boundary.
        if not word:
            continue
        marker = "{{" + word + "}}"
        # Escape the key so it is matched literally, and use a callable so
        # backslashes in the key are not read as replacement-template escapes.
        sentence = re.sub(
            rf"\b{re.escape(word)}\b",
            lambda _match, text=marker: text,
            sentence,
            flags=re.IGNORECASE,
        )
    return sentence
538
+
539
+ # Clean and wrap nodes, and highlight specified words globally
540
+ cleaned_nodes = [re.sub(r'\sL[0-9]$', '', node) for node in nodes]
541
+ global_color_map = dict(highlight_info)
542
+ highlighted_nodes = [highlight_words(node, global_color_map) for node in cleaned_nodes]
543
+ wrapped_nodes = ['<br>'.join(textwrap.wrap(node, width=80)) for node in highlighted_nodes]
544
+
545
+ # Function to determine tree levels and create edges dynamically
546
def get_levels_and_edges(nodes):
    """Derive each node's tree level and the L0 -> L1 edges.

    Args:
        nodes: Node strings whose last whitespace-separated token is a level
            tag: one character followed by the level number (e.g. ``"L0"``).

    Returns:
        A ``(levels, edges)`` tuple. ``levels`` maps node index to its level;
        ``edges`` is a list of ``(l0_index, l1_index)`` pairs.

    Raises:
        ValueError: If fewer than three level-0 nodes are present.
    """
    levels = {}
    for i, node in enumerate(nodes):
        token = node.split()[-1]
        # Read the whole digit run after the tag letter so multi-digit levels
        # such as "L12" are not truncated to their first digit.
        digits = re.match(r".(\d+)", token)
        levels[i] = int(digits.group(1)) if digits else int(token[1])

    l0_indices = [i for i, level in levels.items() if level == 0]
    l1_indices = [i for i, level in levels.items() if level == 1]

    # Only the first three L0 nodes are used as parents, so at least three
    # must exist; any further L0 nodes simply receive no edges.
    if len(l0_indices) < 3:
        raise ValueError("There should be at least 3 L0 nodes to attach edges correctly.")

    # L1 nodes are attached in order: the first four to the first L0 node, the
    # next four to the second, and every remaining one to the third.
    group_size = 4
    last_parent = 2
    edges = [
        (l0_indices[min(position // group_size, last_parent)], l1_node)
        for position, l1_node in enumerate(l1_indices)
    ]

    return levels, edges
571
 
 
582
  y_offsets = {level: - (height - 1) / 2 for level, height in level_heights.items()}
583
  x_gap = 2
584
  l1_y_gap = 10
 
585
 
586
  for node, level in levels.items():
587
  if level == 1:
588
  positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
 
 
589
  else:
590
+ positions[node] = (-level * x_gap, y_offsets[level] * l1_y_gap)
591
  y_offsets[level] += 1
592
 
593
  # Function to highlight words in a wrapped node string
 
624
  ]
625
 
626
  # Create figure
627
+ fig2 = go.Figure()
628
 
629
  # Add nodes to the figure
630
  for i, node in enumerate(wrapped_nodes):
631
  colored_node = color_highlighted_words(node, global_color_map)
632
  x, y = positions[i]
633
+ fig2.add_trace(go.Scatter(
634
  x=[-x], # Reflect the x coordinate
635
  y=[y],
636
  mode='markers',
637
  marker=dict(size=10, color='blue'),
638
  hoverinfo='none'
639
  ))
640
+ fig2.add_annotation(
641
  x=-x, # Reflect the x coordinate
642
  y=y,
643
  text=colored_node,
644
  showarrow=False,
645
  xshift=15,
646
  align="center",
647
+ font=dict(size=12),
648
  bordercolor='black',
649
  borderwidth=1,
650
  borderpad=2,
651
  bgcolor='white',
652
+ width=450,
653
+ height=65
654
  )
655
 
656
  # Add edges and text above each edge
657
  for i, edge in enumerate(edges):
658
  x0, y0 = positions[edge[0]]
659
  x1, y1 = positions[edge[1]]
660
+ fig2.add_trace(go.Scatter(
661
  x=[-x0, -x1], # Reflect the x coordinates
662
  y=[y0, y1],
663
  mode='lines',
 
672
  text_y_position = mid_y + 0.8 # Increase this value to shift the text further upwards
673
 
674
  # Add text annotation above the edge
675
+ # Use a fallback text if we exceed the length of edge_texts
676
+ text = edge_texts[i] if i < len(edge_texts) else f"Edge {i+1}"
677
+ fig2.add_annotation(
678
  x=mid_x,
679
  y=text_y_position,
680
+ text=text, # Use the text specific to this edge
681
  showarrow=False,
682
+ font=dict(size=12),
683
  align="center"
684
  )
685
 
686
+ fig2.update_layout(
687
  showlegend=False,
688
  margin=dict(t=20, b=20, l=20, r=20),
689
  xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
690
  yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
691
+ width=1435, # Adjusted width to accommodate more levels
692
  height=1000 # Adjusted height to accommodate more levels
693
  )
694
 
695
+ return fig2