File size: 3,120 Bytes
11112c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import jax.numpy as jnp
import jraph
import matplotlib.pyplot as plt
import networkx as nx
import spacy
from datasets import load_dataset

# Hugging Face Hub dataset whose "train" split supplies the transcripts
# offered in the UI (each row is read through its "transcript" field).
_DATASET_ID = "gigant/tib_transcripts"
# spaCy pipeline used to tokenise and dependency-parse the chosen sentence.
_SPACY_MODEL = "en_core_web_sm"

dataset = load_dataset(_DATASET_ID)
nlp = spacy.load(_SPACY_MODEL)

def dependency_parser(sentences):
    """Parse each sentence with the module-level spaCy pipeline ``nlp``.

    Args:
        sentences: iterable of strings.

    Returns:
        A list of spaCy ``Doc`` objects, one per input sentence, in input order.
    """
    return list(map(nlp, sentences))

def construct_dependency_graph(docs):
    """Flatten spaCy dependency parses into plain edge-list dictionaries.

    Args:
        docs: iterable of outputs of the spaCy dependency parser (``Doc``
            objects, e.g. as returned by ``dependency_parser``).

    Returns:
        One dict per doc, in input order, with keys:
          - "nodes": the token texts in document order;
          - "senders": ``token.i`` of the head token of each arc;
          - "receivers": ``token.i`` of the matching child token.
        Arcs are listed head by head, each head's children in the order
        ``token.children`` yields them.
    """
    def _edge_list(doc):
        node_texts = [tok.text for tok in doc]
        # Every head -> child dependency arc becomes one directed edge.
        arcs = [(head.i, child.i) for head in doc for child in head.children]
        return {
            "nodes": node_texts,
            "senders": [src for src, _dst in arcs],
            "receivers": [dst for _src, dst in arcs],
        }

    return [_edge_list(doc) for doc in docs]

def to_jraph(graph):
    """Convert a plain dependency-graph dict into a single-graph ``jraph.GraphsTuple``.

    Args:
        graph: dict with keys "nodes" (list of token strings), "senders" and
            "receivers" (equal-length lists of integer token indices), as
            produced by ``construct_dependency_graph``.

    Returns:
        A ``GraphsTuple`` holding exactly one graph. Every node gets the
        placeholder integer feature 0 (the token text is not encoded);
        ``edges`` and ``globals`` are None.
    """
    nodes = graph["nodes"]
    senders_list = graph["senders"]
    receivers_list = graph["receivers"]

    # Placeholder feature per node; only the graph structure is used downstream.
    node_features = jnp.zeros(len(nodes), dtype=jnp.int32)

    # Explicit integer dtype: jnp.array([]) defaults to float32, so a sentence
    # with no dependency arcs (e.g. a single token) would otherwise produce
    # float-typed index arrays, while senders/receivers are node indices.
    senders = jnp.asarray(senders_list, dtype=jnp.int32)
    receivers = jnp.asarray(receivers_list, dtype=jnp.int32)

    # Per-graph node/edge counts; GraphsTuple uses these to delimit graphs
    # when several are batched together.
    n_node = jnp.array([len(nodes)])
    n_edge = jnp.array([len(senders_list)])

    return jraph.GraphsTuple(
        nodes=node_features,
        senders=senders,
        receivers=receivers,
        edges=None,
        n_node=n_node,
        n_edge=n_edge,
        globals=None,
    )

def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Rebuild a ``jraph.GraphsTuple`` as a directed NetworkX graph.

    Only the first graph in the tuple is converted: the node and edge counts
    come from ``n_node[0]`` and ``n_edge[0]``. Node ``k`` is the integer ``k``;
    when node features are present, ``nodes[k]`` is stored as the
    ``node_feature`` attribute. Edge ``k`` runs from ``senders[k]`` to
    ``receivers[k]`` (both cast to Python ints) and, when edge features are
    present, carries ``edges[k]`` as ``edge_feature``.

    Returns:
        A ``nx.DiGraph`` (annotated as its base class ``nx.Graph``).
    """
    node_feats = jraph_graph.nodes
    edge_feats = jraph_graph.edges
    src = jraph_graph.senders
    dst = jraph_graph.receivers

    digraph = nx.DiGraph()
    for k in range(jraph_graph.n_node[0]):
        if node_feats is None:
            digraph.add_node(k)
        else:
            digraph.add_node(k, node_feature=node_feats[k])
    for k in range(jraph_graph.n_edge[0]):
        head, tail = int(src[k]), int(dst[k])
        if edge_feats is None:
            digraph.add_edge(head, tail)
        else:
            digraph.add_edge(head, tail, edge_feature=edge_feats[k])
    return digraph

def plot_graph_sentence(sentence):
    """Parse one sentence and return a matplotlib figure of its dependency graph.

    Args:
        sentence: the sentence selected in the dropdown; may be None or blank
            (e.g. the empty piece left after a transcript's final ".").

    Returns:
        A 6x6-inch ``matplotlib.figure.Figure`` showing the head -> child
        dependency arcs, with nodes labelled by token text. Returns None when
        ``sentence`` is None or only whitespace, so the plot is cleared
        instead of the handler raising.
    """
    # Local import: building the Figure directly (not via pyplot) keeps it out
    # of pyplot's global figure registry. With plt.figure() every dropdown
    # change leaked one never-closed figure for the lifetime of the server.
    from matplotlib.figure import Figure

    if sentence is None or not sentence.strip():
        return None

    # Strip so leading/trailing whitespace does not become extra token nodes.
    graph = construct_dependency_graph(dependency_parser([sentence.strip()]))[0]
    nx_graph = convert_jraph_to_networkx_graph(to_jraph(graph))
    pos = nx.spring_layout(nx_graph)

    fig = Figure(figsize=(6, 6))
    # Full-figure axes, matching what nx.draw creates on an empty pyplot figure.
    ax = fig.add_axes((0, 0, 1, 1))
    nx.draw(
        nx_graph,
        pos=pos,
        ax=ax,
        labels=dict(enumerate(graph["nodes"])),
        with_labels=True,
        node_size=500,
        font_color="black",
        node_color="yellow",
    )
    return fig
    
def get_list_sentences(id):
    """Build a Gradio update that repopulates the sentence dropdown.

    Args:
        id: row index into ``dataset["train"]`` as delivered by the transcript
            slider. It is cast to int in case the slider yields a float. The
            parameter keeps its original name (it shadows the builtin) so
            existing callers are unaffected.

    Returns:
        ``gr.update`` whose ``choices`` are the transcript's sentences: the
        text split on ".", each piece stripped, empty pieces dropped.
    """
    transcript = dataset["train"][int(id)]["transcript"]
    # Raw split(".") leaves a leading space on most pieces (which the parser
    # would presumably turn into a whitespace token) and an empty piece after
    # the final period.
    pieces = [piece.strip() for piece in transcript.split(".")]
    return gr.update(choices=[p for p in pieces if p])

with gr.Blocks() as demo:
    # step=1 so every row index is reachable and indices stay integral;
    # depending on the Gradio version, an unset step may be derived from the
    # slider's range instead.
    transcript_id = gr.Slider(
        minimum=0,
        maximum=len(dataset["train"]) - 1,
        step=1,
        label="Transcript index",
    )
    # Initial choices: sentences of row 0, split on "." with surrounding
    # whitespace stripped and empty pieces dropped.
    first_transcript = dataset["train"][0]["transcript"]
    sentence = gr.Dropdown(
        choices=[s.strip() for s in first_transcript.split(".") if s.strip()],
        interactive=True,
        label="Sentence",
    )
    plot = gr.Plot()
    # Moving the slider swaps the dropdown's choices; picking a sentence redraws.
    transcript_id.change(get_list_sentences, transcript_id, sentence)
    sentence.change(plot_graph_sentence, sentence, plot)

demo.launch()