import os

# NOTE(review): installing packages at import time is fragile, but kept for
# compatibility with the original hosted-demo environment (e.g. HF Spaces).
os.system("pip install networkx")
os.system("pip install Cython")
os.system("pip install benepar")

import re

import benepar
import en_core_web_trf
import gradio as gr
import jax.numpy as jnp
import jraph
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import spacy
from datasets import load_dataset

# Transcript corpus browsed by the UI.
dataset = load_dataset("gigant/tib_transcripts")

# spaCy transformer pipeline, extended with the benepar constituency parser.
nlp = en_core_web_trf.load()
benepar.download('benepar_en3')
nlp.add_pipe('benepar', config={'model': 'benepar_en3'})


def parse_tree(sentence):
    """Parse an S-expression string (a benepar parse) into nested lists.

    Tokens are split on parentheses and whitespace; each "(" opens a new
    nested list, each ")" closes it.

    Raises:
        ValueError: if the parentheses are unbalanced.
    """
    stack = []  # or a `collections.deque()` object, which is a little faster
    top = items = []
    for token in filter(None, re.compile(r'(?:([()])|\s+)').split(sentence)):
        if token == '(':
            stack.append(items)
            items.append([])
            items = items[-1]
        elif token == ')':
            if not stack:
                raise ValueError("Unbalanced parentheses")
            items = stack.pop()
        else:
            items.append(token)
    if stack:
        raise ValueError("Unbalanced parentheses")
    return top


class Tree:
    """N-ary tree node holding a label and a pre-order integer id."""

    def __init__(self, name, children):
        self.children = children
        self.name = name
        self.id = None  # assigned later by set_all_ids()

    def set_id_rec(self, node_id=0):
        """Assign pre-order ids starting at ``node_id``; return the last id used."""
        self.id = node_id
        last_id = node_id
        for child in self.children:
            last_id = child.set_id_rec(node_id=last_id + 1)
        return last_id

    def set_all_ids(self):
        """Number every node of the tree in pre-order, starting at 0."""
        self.set_id_rec(0)

    def print_tree(self, level=0):
        """Return an indented multi-line rendering of the subtree."""
        to_print = f'|{"-" * level} {self.name} ({self.id})'
        for child in self.children:
            to_print += f"\n{child.print_tree(level + 1)}"
        return to_print

    def __str__(self):
        return self.print_tree(0)

    def get_list_nodes(self):
        """Return node names of the subtree in pre-order (matches the ids)."""
        return [self.name] + [name
                              for child in self.children
                              for name in child.get_list_nodes()]


def rec_const_parsing(list_nodes):
    """Recursively convert the nested-list parse into a ``Tree``.

    A list's head is the node label and its tail the children; a bare
    string becomes a leaf.
    """
    if isinstance(list_nodes, list):
        name, children = list_nodes[0], list_nodes[1:]
    else:
        name, children = list_nodes, []
    return Tree(name, [rec_const_parsing(child) for child in children])


def tree_to_graph(t):
    """Flatten a ``Tree`` into parallel (senders, receivers) edge lists.

    Each edge goes parent-id -> child-id; requires ids to be set first.
    """
    senders = []
    receivers = []
    for child in t.children:
        senders.append(t.id)
        receivers.append(child.id)
        s_rec, r_rec = tree_to_graph(child)
        senders.extend(s_rec)
        receivers.extend(r_rec)
    return senders, receivers


def construct_constituency_graph(docs):
    """Build a constituency graph dict for the first sentence of ``docs[0]``."""
    doc = docs[0]
    sent = list(doc.sents)[0]
    print(sent._.parse_string)
    t = rec_const_parsing(parse_tree(sent._.parse_string)[0])
    t.set_all_ids()
    senders, receivers = tree_to_graph(t)
    nodes = t.get_list_nodes()
    graphs = [{"nodes": nodes, "senders": senders, "receivers": receivers,
               "edge_labels": {}}]
    return graphs


def half_circle_layout(n_nodes, sentence_node=True):
    """Place the first n-1 nodes on a flattened half circle, the last below.

    The final node (the "Sentence" super-node) is pinned at (0, -0.25).
    ``sentence_node`` is currently unused — kept for interface compatibility.
    """
    pos = {}
    for i_node in range(n_nodes - 1):
        pos[i_node] = (
            -np.cos(i_node * np.pi / (n_nodes - 1)),
            0.5 * (-np.sin(i_node * np.pi / (n_nodes - 1))),
        )
    pos[n_nodes - 1] = (0, -0.25)
    return pos


def get_adjacency_matrix(jraph_graph: jraph.GraphsTuple):
    """Return the dense (n_nodes, n_nodes) 0/1 adjacency matrix of the graph."""
    nodes, edges, receivers, senders, _, _, _ = jraph_graph
    adj_mat = jnp.zeros((len(nodes), len(nodes)))
    for i in range(len(receivers)):
        adj_mat = adj_mat.at[senders[i], receivers[i]].set(1)
    return adj_mat


def dependency_parser(sentences):
    """Run the spaCy pipeline over a list of raw sentence strings."""
    return [nlp(sentence) for sentence in sentences]


def construct_dependency_graph(docs):
    """docs is a list of outputs of the SpaCy dependency parser.

    Edges go child -> head; NOTE(review): edge_labels are keyed
    (head, child), i.e. reversed w.r.t. the edges added — confirm this is
    the intended rendering before changing it.
    """
    graphs = []
    for doc in docs:
        nodes = [token.text for token in doc]
        senders = []
        receivers = []
        edge_labels = {}
        for token in doc:
            for child in token.children:
                senders.append(child.i)
                receivers.append(token.i)
                edge_labels[(token.i, child.i)] = child.dep_
        graphs.append({"nodes": nodes, "senders": senders,
                       "receivers": receivers, "edge_labels": edge_labels})
    return graphs


def construct_both_graph(docs):
    """docs is a list of outputs of the SpaCy dependency parser.

    Combines the structural graph (next/previous/in edges plus a "Sentence"
    super-node) with the dependency edges.
    """
    graphs = []
    for doc in docs:
        nodes = [token.text for token in doc]
        nodes.append("Sentence")
        token_ids = [token.i for token in doc]
        # Linear chain: token -> next token and back.
        senders = token_ids[:-1] + token_ids[1:]
        receivers = token_ids[1:] + token_ids[:-1]
        edge_labels = {(token.i, token.i + 1): "next" for token in doc[:-1]}
        for token in doc[:-1]:
            edge_labels[(token.i + 1, token.i)] = "previous"
        # Every word points into the "Sentence" super-node (last index).
        for node in range(len(nodes) - 1):
            senders.append(node)
            receivers.append(len(nodes) - 1)
            edge_labels[(node, len(nodes) - 1)] = "in"
        # Dependency edges (child -> head); see NOTE in
        # construct_dependency_graph about the label-key orientation.
        for token in doc:
            for child in token.children:
                senders.append(child.i)
                receivers.append(token.i)
                edge_labels[(token.i, child.i)] = child.dep_
        graphs.append({"nodes": nodes, "senders": senders,
                       "receivers": receivers, "edge_labels": edge_labels})
    return graphs


def construct_structural_graph(docs):
    """Build next/previous chain edges plus a "Sentence" super-node per doc."""
    graphs = []
    for doc in docs:
        nodes = [token.text for token in doc]
        nodes.append("Sentence")
        token_ids = [token.i for token in doc]
        senders = token_ids[:-1] + token_ids[1:]
        receivers = token_ids[1:] + token_ids[:-1]
        edge_labels = {(token.i, token.i + 1): "next" for token in doc[:-1]}
        for token in doc[:-1]:
            edge_labels[(token.i + 1, token.i)] = "previous"
        for node in range(len(nodes) - 1):
            senders.append(node)
            receivers.append(len(nodes) - 1)
            edge_labels[(node, len(nodes) - 1)] = "in"
        graphs.append({"nodes": nodes, "senders": senders,
                       "receivers": receivers, "edge_labels": edge_labels})
    return graphs


def to_jraph(graph):
    """Convert a graph dict (nodes/senders/receivers) to a jraph.GraphsTuple."""
    nodes = graph["nodes"]
    s = graph["senders"]
    r = graph["receivers"]

    # Each node gets an integer placeholder feature.
    node_features = jnp.array([0] * len(nodes))
    # Directed edges are defined by `senders` (source nodes) and
    # `receivers` (destination nodes).
    senders = jnp.array(s)
    receivers = jnp.array(r)
    # Node/edge counts make running GNNs over batched GraphsTuples possible.
    n_node = jnp.array([len(nodes)])
    n_edge = jnp.array([len(s)])
    return jraph.GraphsTuple(nodes=node_features, senders=senders,
                             receivers=receivers, edges=None,
                             n_node=n_node, n_edge=n_edge, globals=None)


def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Rebuild a networkx DiGraph from a jraph.GraphsTuple (features optional)."""
    nodes, edges, receivers, senders, _, _, _ = jraph_graph
    nx_graph = nx.DiGraph()
    if nodes is None:
        for n in range(jraph_graph.n_node[0]):
            nx_graph.add_node(n)
    else:
        for n in range(jraph_graph.n_node[0]):
            nx_graph.add_node(n, node_feature=nodes[n])
    if edges is None:
        for e in range(jraph_graph.n_edge[0]):
            nx_graph.add_edge(int(senders[e]), int(receivers[e]))
    else:
        for e in range(jraph_graph.n_edge[0]):
            nx_graph.add_edge(int(senders[e]), int(receivers[e]),
                              edge_feature=edges[e])
    return nx_graph


def plot_graph_sentence(sentence, graph_type="constituency"):
    """Parse ``sentence``, build the requested graph, and return two plots.

    Returns gradio updates for (graph figure, adjacency-matrix figure).

    Raises:
        ValueError: if ``graph_type`` is not one of the supported kinds
            (previously this path crashed with UnboundLocalError).
    """
    docs = dependency_parser([sentence])
    if graph_type == "dependency":
        graphs = construct_dependency_graph(docs)
    elif graph_type == "structural":
        graphs = construct_structural_graph(docs)
    elif graph_type == "structural+dependency":
        graphs = construct_both_graph(docs)
    elif graph_type == "constituency":
        graphs = construct_constituency_graph(docs)
    else:
        raise ValueError(f"Unknown graph_type: {graph_type!r}")
    g = to_jraph(graphs[0])
    adj_mat = get_adjacency_matrix(g)
    nx_graph = convert_jraph_to_networkx_graph(g)
    pos = half_circle_layout(len(graphs[0]["nodes"]))
    if graph_type == "constituency":
        # Constituency trees have no "Sentence" super-node; a planar layout
        # is more readable than the half circle.
        pos = nx.planar_layout(nx_graph)
    plot = plt.figure(figsize=(12, 6))
    nx.draw(nx_graph, pos=pos,
            labels={i: e for i, e in enumerate(graphs[0]["nodes"])},
            with_labels=True,
            edge_color="blue",
            # connectionstyle="arc3,rad=0.1",
            node_size=1000,
            font_color='black',
            node_color="yellow")
    nx.draw_networkx_edge_labels(
        nx_graph, pos=pos,
        edge_labels=graphs[0]["edge_labels"],
        font_color='red',
    )
    adj_mat_plot, ax = plt.subplots(figsize=(6, 6))
    ax.matshow(adj_mat)
    return [gr.update(value=plot), gr.update(value=adj_mat_plot)]


def get_list_sentences(transcript_id):
    """Return a dropdown update listing the sentences of one transcript.

    The requested index is clamped into [0, len(dataset) - 1] so that
    out-of-range (including negative) ids cannot mis-index the dataset.
    """
    transcript_id = int(max(0, min(transcript_id, len(dataset["train"]) - 1)))
    return gr.update(
        choices=dataset["train"][transcript_id]["transcript"].split("."))


with gr.Blocks() as demo:
    with gr.Row():
        graph_type = gr.Dropdown(
            label="Graph type",
            choices=["structural", "dependency", "structural+dependency",
                     "constituency"],
            value="structural+dependency", interactive=True)
    with gr.Tab("From transcript"):
        with gr.Row():
            with gr.Column():
                id = gr.Number(label="Transcript")
            with gr.Column(scale=3):
                sentence_transcript = gr.Dropdown(
                    label="Sentence",
                    choices=dataset["train"][0]["transcript"].split(".")[1:],
                    interactive=True)
    with gr.Tab("Type sentence"):
        with gr.Row():
            sentence_typed = gr.Textbox(label="Sentence", interactive=True)
    with gr.Row():
        with gr.Column(scale=2):
            plot_graph = gr.Plot(label="Word graph")
        with gr.Column():
            plot_adj = gr.Plot(label="Word graph adjacency matrix")
    id.change(get_list_sentences, id, sentence_transcript)
    sentence_transcript.change(plot_graph_sentence,
                               [sentence_transcript, graph_type],
                               [plot_graph, plot_adj])
    sentence_typed.change(plot_graph_sentence,
                          [sentence_typed, graph_type],
                          [plot_graph, plot_adj])

demo.launch()