cryptocalypse commited on
Commit
6840d6b
1 Parent(s): a4844a1

Update psychohistory.py

Browse files
Files changed (1) hide show
  1. psychohistory.py +76 -46
psychohistory.py CHANGED
@@ -1,16 +1,16 @@
1
  import matplotlib.pyplot as plt
2
  from mpl_toolkits.mplot3d import Axes3D
3
  import networkx as nx
4
- import random
5
  import numpy as np
6
  import json
7
  import sys
 
8
 
9
  def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G, parent=None, node_count_per_depth=None):
10
- """Generates a tree of nodes with positions adjusted on the x-axis, and the number of nodes on the z-axis."""
11
  if node_count_per_depth is None:
12
  node_count_per_depth = {}
13
-
14
  if depth not in node_count_per_depth:
15
  node_count_per_depth[depth] = 0
16
 
@@ -19,13 +19,13 @@ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G,
19
 
20
  num_children = random.randint(1, max_nodes)
21
  x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
22
-
23
  for x in x_positions:
24
  # Add node to the graph
25
  node_id = len(G.nodes)
26
  node_count_per_depth[depth] += 1
27
  prob = random.uniform(0, 1) # Assign random probability
28
- G.add_node(node_id, pos=(x, prob, depth)) # Use `depth` for the z position
29
  if parent is not None:
30
  G.add_edge(parent, node_id)
31
  # Recursively add child nodes
@@ -33,27 +33,39 @@ def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G,
33
 
34
  return node_count_per_depth
35
 
 
 
36
  def build_graph_from_json(json_data, G):
37
  """Builds a graph from JSON data."""
38
- def add_event(parent_id, event_data, prob_level):
39
- for key, value in event_data.get('events', {}).items():
40
- # Add node
41
- node_id = len(G.nodes)
42
- prob = {'high_probability': 0.9, 'medium_probability': 0.5, 'low_probability': 0.1}[prob_level]
43
- G.add_node(node_id, pos=(len(G.nodes), prob, len(G.nodes))) # Ensure each node has 'pos'
 
 
 
44
  G.add_edge(parent_id, node_id)
45
- # Add child events
46
- add_event(node_id, {'events': value}, key)
47
 
48
- root_id = len(G.nodes)
49
- G.add_node(root_id, pos=(0, 0.5, 0)) # Root node with default medium probability
50
- if len(G.nodes) > 1:
51
- G.add_edge(-1, root_id) # Root node without a parent
 
 
 
 
52
  data = json.loads(json_data)
53
- add_event(root_id, data, 'medium_probability')
 
 
 
 
 
54
 
55
  def find_paths(G):
56
- """Finds the paths with the highest and lowest average probability, and the maximum and minimum duration in graph G."""
57
  best_path = None
58
  worst_path = None
59
  longest_duration_path = None
@@ -72,30 +84,30 @@ def find_paths(G):
72
  if not all('pos' in G.nodes[node] for node in path):
73
  continue # Skip paths with nodes missing the 'pos' attribute
74
 
75
- # Calculate the average probability of the path
76
- probabilities = [G.nodes[node]['pos'][1] for node in path] # Get probabilities of the nodes in the path
77
  mean_prob = np.mean(probabilities)
78
 
79
- # Evaluate the path with the highest average probability
80
  if mean_prob > best_mean_prob:
81
  best_mean_prob = mean_prob
82
  best_path = path
83
 
84
- # Evaluate the path with the lowest average probability
85
  if mean_prob < worst_mean_prob:
86
  worst_mean_prob = mean_prob
87
  worst_path = path
88
 
89
- # Calculate the duration of the path
90
  x_positions = [G.nodes[node]['pos'][0] for node in path]
91
  duration = max(x_positions) - min(x_positions)
92
 
93
- # Evaluate the path with the maximum duration
94
  if duration > max_duration:
95
  max_duration = duration
96
  longest_duration_path = path
97
 
98
- # Evaluate the path with the minimum duration
99
  if duration < min_duration:
100
  min_duration = duration
101
  shortest_duration_path = path
@@ -115,7 +127,7 @@ def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
115
  fig = plt.figure(figsize=(16, 12))
116
  ax = fig.add_subplot(111, projection='3d')
117
 
118
- # Assign colors to the nodes based on probability
119
  node_colors = []
120
  for node in path:
121
  prob = G.nodes[node]['pos'][1]
@@ -135,31 +147,38 @@ def draw_path_3d(G, path, filename='path_plot_3d.png', highlight_color='blue'):
135
  x_end, y_end, z_end = pos[edge[1]]
136
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
137
 
138
- # Add labels to the nodes
139
  for node, (x, y, z) in pos.items():
140
  if node in path:
141
  ax.text(x, y, z, str(node), fontsize=12, color='black')
142
 
143
- # Adjust labels and title
144
  ax.set_xlabel('Time (weeks)')
145
  ax.set_ylabel('Event Probability')
146
  ax.set_zlabel('Event Number')
147
- ax.set_title('Event Tree in 3D - Path')
 
 
 
148
 
149
- plt.savefig(filename, bbox_inches='tight') # Save to a file with adjusted margins
150
- plt.close() # Close the figure to free up resources
151
 
152
  def draw_global_tree_3d(G, filename='global_tree.png'):
153
  """Draws the entire graph in 3D using networkx and matplotlib and saves the figure to a file."""
154
  pos = nx.get_node_attributes(G, 'pos')
 
155
 
 
 
 
 
 
156
  # Get data for 3D visualization
157
  x_vals, y_vals, z_vals = zip(*pos.values())
158
 
159
  fig = plt.figure(figsize=(16, 12))
160
  ax = fig.add_subplot(111, projection='3d')
161
 
162
- # Assign colors to the nodes based on probability
163
  node_colors = []
164
  for node, (x, prob, z) in pos.items():
165
  if prob < 0.33:
@@ -178,18 +197,19 @@ def draw_global_tree_3d(G, filename='global_tree.png'):
178
  x_end, y_end, z_end = pos[edge[1]]
179
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
180
 
181
- # Add labels to the nodes
182
  for node, (x, y, z) in pos.items():
183
- ax.text(x, y, z, str(node), fontsize=12, color='black')
 
184
 
185
- # Adjust labels and title
186
- ax.set_xlabel('Time (weeks)')
187
- ax.set_ylabel('Event Probability')
188
  ax.set_zlabel('Event Number')
189
- ax.set_title('Event Tree in 3D')
190
 
191
- plt.savefig(filename, bbox_inches='tight') # Save to a file with adjusted margins
192
- plt.close() # Close the figure to free up resources
193
 
194
  def main(mode, input_file=None):
195
  G = nx.DiGraph()
@@ -197,12 +217,14 @@ def main(mode, input_file=None):
197
  if mode == 'random':
198
  starting_x = 0
199
  starting_y = 0
200
- max_depth = 5 # Maximum tree depth
201
  max_nodes = 3 # Maximum number of child nodes
202
- x_range = 10 # Maximum range for node x positions
203
 
204
- # Generate the tree and get the node count per depth
205
  generate_tree(starting_x, starting_y, 0, max_depth, max_nodes, x_range, G)
 
 
206
  elif mode == 'json' and input_file:
207
  with open(input_file, 'r') as file:
208
  json_data = file.read()
@@ -211,10 +233,14 @@ def main(mode, input_file=None):
211
  print("Invalid mode or input file not provided.")
212
  return
213
 
 
 
 
 
214
  # Find relevant paths
215
  best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path = find_paths(G)
216
 
217
- # Print the results
218
  if best_path:
219
  print(f"\nPath with the highest average probability:")
220
  print(" -> ".join(map(str, best_path)))
@@ -251,10 +277,14 @@ def main(mode, input_file=None):
251
  if shortest_duration_path:
252
  draw_path_3d(G, path=shortest_duration_path, filename='shortest_duration_path.png', highlight_color='purple')
253
 
 
 
254
  if __name__ == "__main__":
255
  if len(sys.argv) < 2:
256
- print("Usage: python script.py <mode> [json_file]")
257
  else:
258
  mode = sys.argv[1]
259
  input_file = sys.argv[2] if len(sys.argv) > 2 else None
260
  main(mode, input_file)
 
 
 
1
  import matplotlib.pyplot as plt
2
  from mpl_toolkits.mplot3d import Axes3D
3
  import networkx as nx
 
4
  import numpy as np
5
  import json
6
  import sys
7
+ import random
8
 
9
  def generate_tree(current_x, current_y, depth, max_depth, max_nodes, x_range, G, parent=None, node_count_per_depth=None):
10
+ """Generates a tree of nodes with positions adjusted on the x-axis, y-axis, and number of nodes on the z-axis."""
11
  if node_count_per_depth is None:
12
  node_count_per_depth = {}
13
+
14
  if depth not in node_count_per_depth:
15
  node_count_per_depth[depth] = 0
16
 
 
19
 
20
  num_children = random.randint(1, max_nodes)
21
  x_positions = [current_x + i * x_range / (num_children + 1) for i in range(num_children)]
22
+
23
  for x in x_positions:
24
  # Add node to the graph
25
  node_id = len(G.nodes)
26
  node_count_per_depth[depth] += 1
27
  prob = random.uniform(0, 1) # Assign random probability
28
+ G.add_node(node_id, pos=(x, prob, depth)) # Use `depth` for z position
29
  if parent is not None:
30
  G.add_edge(parent, node_id)
31
  # Recursively add child nodes
 
33
 
34
  return node_count_per_depth
35
 
36
+
37
+
38
  def build_graph_from_json(json_data, G):
39
  """Builds a graph from JSON data."""
40
+ def add_event(parent_id, event_data, depth):
41
+ """Recursively adds events and subevents to the graph."""
42
+ # Add the current event node
43
+ node_id = len(G.nodes)
44
+ prob = event_data['probability'] / 100.0 # Convert percentage to probability
45
+ pos = (depth, prob, event_data['event_number']) # Use event_number for z position
46
+ label = event_data['name'] # Use event name as label
47
+ G.add_node(node_id, pos=pos, label=label)
48
+ if parent_id is not None:
49
  G.add_edge(parent_id, node_id)
 
 
50
 
51
+ # Add child events
52
+ subevents = event_data.get('subevents', {}).get('event', [])
53
+ if not isinstance(subevents, list):
54
+ subevents = [subevents] # Ensure subevents is a list
55
+
56
+ for subevent in subevents:
57
+ add_event(node_id, subevent, depth + 1)
58
+
59
  data = json.loads(json_data)
60
+ root_id = len(G.nodes)
61
+ root_event = list(data.get('events', {}).values())[0]
62
+ G.add_node(root_id, pos=(0, root_event['probability'] / 100.0, root_event['event_number']), label=root_event['name'])
63
+ add_event(None, root_event, 0) # Start from the root
64
+
65
+
66
 
67
  def find_paths(G):
68
+ """Finds the paths with the highest and lowest average probability, and the longest and shortest durations in graph G."""
69
  best_path = None
70
  worst_path = None
71
  longest_duration_path = None
 
84
  if not all('pos' in G.nodes[node] for node in path):
85
  continue # Skip paths with nodes missing the 'pos' attribute
86
 
87
+ # Calculate the mean probability of the path
88
+ probabilities = [G.nodes[node]['pos'][1] for node in path] # Get node probabilities
89
  mean_prob = np.mean(probabilities)
90
 
91
+ # Evaluate path with the highest mean probability
92
  if mean_prob > best_mean_prob:
93
  best_mean_prob = mean_prob
94
  best_path = path
95
 
96
+ # Evaluate path with the lowest mean probability
97
  if mean_prob < worst_mean_prob:
98
  worst_mean_prob = mean_prob
99
  worst_path = path
100
 
101
+ # Calculate path duration
102
  x_positions = [G.nodes[node]['pos'][0] for node in path]
103
  duration = max(x_positions) - min(x_positions)
104
 
105
+ # Evaluate path with the longest duration
106
  if duration > max_duration:
107
  max_duration = duration
108
  longest_duration_path = path
109
 
110
+ # Evaluate path with the shortest duration
111
  if duration < min_duration:
112
  min_duration = duration
113
  shortest_duration_path = path
 
127
  fig = plt.figure(figsize=(16, 12))
128
  ax = fig.add_subplot(111, projection='3d')
129
 
130
+ # Assign colors to nodes based on probability
131
  node_colors = []
132
  for node in path:
133
  prob = G.nodes[node]['pos'][1]
 
147
  x_end, y_end, z_end = pos[edge[1]]
148
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color=highlight_color, lw=2)
149
 
150
+ # Add labels to nodes
151
  for node, (x, y, z) in pos.items():
152
  if node in path:
153
  ax.text(x, y, z, str(node), fontsize=12, color='black')
154
 
155
+ # Set labels and title
156
  ax.set_xlabel('Time (weeks)')
157
  ax.set_ylabel('Event Probability')
158
  ax.set_zlabel('Event Number')
159
+ ax.set_title('3D Event Tree - Path')
160
+
161
+ plt.savefig(filename, bbox_inches='tight') # Save to file with adjusted margins
162
+ plt.close() # Close the figure to free resources
163
 
 
 
164
 
165
  def draw_global_tree_3d(G, filename='global_tree.png'):
166
  """Draws the entire graph in 3D using networkx and matplotlib and saves the figure to a file."""
167
  pos = nx.get_node_attributes(G, 'pos')
168
+ labels = nx.get_node_attributes(G, 'label')
169
 
170
+ # Check if the graph is empty
171
+ if not pos:
172
+ print("Graph is empty. No nodes to visualize.")
173
+ return
174
+
175
  # Get data for 3D visualization
176
  x_vals, y_vals, z_vals = zip(*pos.values())
177
 
178
  fig = plt.figure(figsize=(16, 12))
179
  ax = fig.add_subplot(111, projection='3d')
180
 
181
+ # Assign colors to nodes based on probability
182
  node_colors = []
183
  for node, (x, prob, z) in pos.items():
184
  if prob < 0.33:
 
197
  x_end, y_end, z_end = pos[edge[1]]
198
  ax.plot([x_start, x_end], [y_start, y_end], [z_start, z_end], color='gray', lw=2)
199
 
200
+ # Add labels to nodes
201
  for node, (x, y, z) in pos.items():
202
+ label = labels.get(node, f"{node}")
203
+ ax.text(x, y, z, label, fontsize=12, color='black')
204
 
205
+ # Set labels and title
206
+ ax.set_xlabel('Time')
207
+ ax.set_ylabel('Probability')
208
  ax.set_zlabel('Event Number')
209
+ ax.set_title('3D Event Tree')
210
 
211
+ plt.savefig(filename, bbox_inches='tight') # Save to file with adjusted margins
212
+ plt.close() # Close the figure to free resources
213
 
214
  def main(mode, input_file=None):
215
  G = nx.DiGraph()
 
217
  if mode == 'random':
218
  starting_x = 0
219
  starting_y = 0
220
+ max_depth = 5 # Maximum depth of the tree
221
  max_nodes = 3 # Maximum number of child nodes
222
+ x_range = 10 # Maximum range for x position of nodes
223
 
224
+ # Generate the tree and get node count per depth
225
  generate_tree(starting_x, starting_y, 0, max_depth, max_nodes, x_range, G)
226
+
227
+
228
  elif mode == 'json' and input_file:
229
  with open(input_file, 'r') as file:
230
  json_data = file.read()
 
233
  print("Invalid mode or input file not provided.")
234
  return
235
 
236
+ # Save the global visualization
237
+ draw_global_tree_3d(G, filename='global_tree.png')
238
+
239
+
240
  # Find relevant paths
241
  best_path, best_mean_prob, worst_path, worst_mean_prob, longest_duration_path, shortest_duration_path = find_paths(G)
242
 
243
+ # Print results
244
  if best_path:
245
  print(f"\nPath with the highest average probability:")
246
  print(" -> ".join(map(str, best_path)))
 
277
  if shortest_duration_path:
278
  draw_path_3d(G, path=shortest_duration_path, filename='shortest_duration_path.png', highlight_color='purple')
279
 
280
+
281
+
282
  if __name__ == "__main__":
283
  if len(sys.argv) < 2:
284
+ print("Usage: python script.py <mode> [input_file]")
285
  else:
286
  mode = sys.argv[1]
287
  input_file = sys.argv[2] if len(sys.argv) > 2 else None
288
  main(mode, input_file)
289
+
290
+