gigant committed on
Commit
11112c6
1 Parent(s): 77a319b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import jax.numpy as jnp
import jraph
import matplotlib.pyplot as plt
import networkx as nx
import spacy
from datasets import load_dataset
7
+
8
# Transcript dataset from the Hugging Face Hub; loaded once at import time and
# read by the UI callbacks through its "train" split.
dataset = load_dataset("gigant/tib_transcripts")

# spaCy English pipeline shared by every dependency-parsing call.
nlp = spacy.load("en_core_web_sm")
12
+
13
def dependency_parser(sentences):
    """Run the module-level spaCy pipeline ``nlp`` on every sentence.

    Args:
        sentences: Iterable of sentence strings.

    Returns:
        A list with one ``nlp(...)`` result per input sentence, in input order.
    """
    parsed = []
    for text in sentences:
        parsed.append(nlp(text))
    return parsed
15
+
16
def construct_dependency_graph(docs):
    """Turn parsed documents into plain-dict dependency graphs.

    Args:
        docs: List of outputs of the SpaCy dependency parser. Each one is
            iterated for tokens, and every token must expose ``text``, ``i``
            and ``children``.

    Returns:
        One dict per document with the keys ``"nodes"`` (token texts),
        ``"senders"`` and ``"receivers"``. The last two are parallel lists of
        token indices: one ``token.i -> child.i`` pair per child relation.
    """

    def _as_graph(doc):
        # Node labels are the token texts, in document order.
        labels = [token.text for token in doc]
        # One directed edge from each token to each of its children.
        edges = [(head.i, child.i) for head in doc for child in head.children]
        return {
            "nodes": labels,
            "senders": [head_idx for head_idx, _ in edges],
            "receivers": [child_idx for _, child_idx in edges],
        }

    return [_as_graph(doc) for doc in docs]
31
+
32
def to_jraph(graph):
    """Convert a dependency-graph dict into a single-graph ``jraph.GraphsTuple``.

    Args:
        graph: Mapping with the keys ``"nodes"`` (one entry per token),
            ``"senders"`` and ``"receivers"`` (parallel lists of node indices,
            one pair per directed edge).

    Returns:
        A ``jraph.GraphsTuple`` describing exactly one graph. Every node gets
        the placeholder integer feature ``0``; ``edges`` and ``globals`` are
        ``None``.
    """
    nodes = graph["nodes"]
    s = graph["senders"]
    r = graph["receivers"]

    # The token texts are not encoded here: each node only carries a
    # placeholder integer feature so that the node count is represented.
    node_features = jnp.zeros(len(nodes), dtype=jnp.int32)

    # Directed edges go from `senders` (source nodes) to `receivers`
    # (destination nodes). The dtype is pinned because `jnp.array([])` would
    # otherwise produce a float array for a graph without edges, and these
    # arrays are node indices.
    senders = jnp.array(s, dtype=jnp.int32)
    receivers = jnp.array(r, dtype=jnp.int32)

    # Per-graph node and edge counts; a GraphsTuple can batch several graphs,
    # so these are length-1 arrays for our single graph.
    n_node = jnp.array([len(nodes)])
    n_edge = jnp.array([len(s)])

    return jraph.GraphsTuple(
        nodes=node_features,
        senders=senders,
        receivers=receivers,
        edges=None,
        n_node=n_node,
        n_edge=n_edge,
        globals=None,
    )
55
+
56
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Build a networkx directed graph from the first graph of a GraphsTuple.

    Only ``n_node[0]`` nodes and ``n_edge[0]`` edges are read. Node features,
    when present, are stored under the ``node_feature`` attribute; edge
    features, when present, under ``edge_feature``.

    Args:
        jraph_graph: The graph tuple to convert.

    Returns:
        An ``nx.DiGraph`` whose nodes are numbered from 0 and whose edges run
        from sender index to receiver index.
    """
    node_feats, edge_feats, dst, src, _, _, _ = jraph_graph
    digraph = nx.DiGraph()

    for node_id in range(jraph_graph.n_node[0]):
        if node_feats is None:
            digraph.add_node(node_id)
        else:
            digraph.add_node(node_id, node_feature=node_feats[node_id])

    for edge_id in range(jraph_graph.n_edge[0]):
        # networkx expects plain Python ints as node keys.
        source = int(src[edge_id])
        target = int(dst[edge_id])
        if edge_feats is None:
            digraph.add_edge(source, target)
        else:
            digraph.add_edge(source, target, edge_feature=edge_feats[edge_id])

    return digraph
73
+
74
def plot_graph_sentence(sentence):
    """Draw the dependency graph of one sentence.

    The sentence is parsed, converted to a jraph graph and then to a networkx
    graph, which is laid out with ``nx.spring_layout`` and drawn with the
    token texts as node labels.

    Args:
        sentence: The sentence to plot. May be ``None`` or blank, for example
            when nothing is selected.

    Returns:
        The matplotlib figure holding the drawing, or ``None`` when
        ``sentence`` is ``None``, empty or only whitespace.
    """
    # Nothing to parse: return no plot instead of failing inside the parser.
    if sentence is None or not sentence.strip():
        return None

    docs = dependency_parser([sentence])
    graphs = construct_dependency_graph(docs)
    g = to_jraph(graphs[0])
    nx_graph = convert_jraph_to_networkx_graph(g)
    pos = nx.spring_layout(nx_graph)

    plot = plt.figure(figsize=(6, 6))
    # Draw on an explicit full-figure axes rather than on pyplot's implicit
    # "current figure", so the drawing cannot land on another figure.
    ax = plot.add_axes((0, 0, 1, 1))
    nx.draw(
        nx_graph,
        pos=pos,
        ax=ax,
        labels=dict(enumerate(graphs[0]["nodes"])),
        with_labels=True,
        node_size=500,
        font_color="black",
        node_color="yellow",
    )
    # Unregister the figure from pyplot so repeated calls do not accumulate
    # open figures; the returned Figure object itself stays usable.
    plt.close(plot)
    return plot
84
+
85
def get_list_sentences(id):
    """Return a dropdown update listing the sentences of one transcript.

    Args:
        id: Row index into ``dataset["train"]``. It is coerced with ``int``
            because a slider may deliver the value as a float. The name
            shadows the ``id`` builtin and is kept only for compatibility.

    Returns:
        A ``gr.update`` whose ``choices`` are the transcript's pieces after
        splitting on ".", stripped of surrounding whitespace and with empty
        pieces removed.
    """
    transcript = dataset["train"][int(id)]["transcript"]
    # Splitting on "." leaves blank pieces (e.g. after the final period) and
    # leading spaces; neither is useful as a selectable sentence.
    pieces = (piece.strip() for piece in transcript.split("."))
    return gr.update(choices=[piece for piece in pieces if piece])
87
+
88
with gr.Blocks() as demo:
    # The slider selects a row of dataset["train"]. `step=1` is explicit
    # because, when it is omitted, Gradio derives the step from the slider
    # range, which is not guaranteed to be 1.
    id = gr.Slider(minimum=0, maximum=len(dataset["train"]) - 1, step=1)
    # Initial choices: sentences of the first transcript, without the blank
    # pieces that splitting on "." leaves behind.
    sentence = gr.Dropdown(
        choices=[
            piece.strip()
            for piece in dataset["train"][0]["transcript"].split(".")
            if piece.strip()
        ],
        interactive=True,
    )
    plot = gr.Plot()
    # Moving the slider reloads the sentence list; picking a sentence redraws
    # its dependency graph.
    id.change(get_list_sentences, id, sentence)
    sentence.change(plot_graph_sentence, sentence, plot)

demo.launch()