jsulz HF staff commited on
Commit
e1e5ef8
1 Parent(s): e867093

cleaning up and themeing

Browse files
Files changed (2) hide show
  1. .streamlit/config.toml +6 -0
  2. app.py +49 -25
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor = "F36295"
3
+ backgroundColor = "#FFF"
4
+ secondaryBackgroundColor = "#3183D1"
5
+ textColor = "#000"
6
+ font = "sans-serif"
app.py CHANGED
@@ -20,49 +20,66 @@ from scipy.special import softmax
20
  @st.cache_resource
21
  def load_picture():
22
  """
23
- Loads the first 9 images from the mnist dataset and add them to a plot
24
- to be displayed in streamlit.
25
  """
26
- # load the mnist dataset
27
- (x_train, y_train), (x_test, y_test) = mnist.load_data()
28
- # plot the first 9 images
29
- for i in range(9):
30
- plt.subplot(330 + 1 + i)
31
- image = x_train[i] / 255.0
32
- plt.imshow(image, cmap=plt.get_cmap("gray"))
33
-
34
- # Save the plot as a png file and show it in streamlit
35
- # This is commented out for not because the plot was created and saved in the img directory during the initial run of the app locally
36
- # plt.savefig("img/show.png")
37
  st.image("img/show.png", width=250, caption="First 9 images from the MNIST dataset")
38
 
39
 
40
  def keras_prediction(final, model_path):
 
 
 
 
 
 
 
 
 
 
 
 
41
  load_time = time.time()
42
  model = models.load_model(
43
  os.path.abspath(os.path.join(os.path.dirname(__file__), model_path))
44
  )
45
  after_load_curr = time.time()
 
 
46
  curr_time = time.time()
47
  prediction = model.predict(final[None, ...])
48
  after_time = time.time()
 
49
  return prediction, after_time - curr_time, after_load_curr - load_time
50
 
51
 
52
  def onnx_prediction(final, model_path):
53
- im_np = np.expand_dims(final, axis=0) # Add batch dimension
54
- im_np = np.expand_dims(im_np, axis=0) # Add channel dimension
 
 
 
 
 
 
 
 
 
 
55
  im_np = im_np.astype("float32")
 
 
56
  load_curr = time.time()
57
  session = onnxruntime.InferenceSession(model_path, None)
58
  input_name = session.get_inputs()[0].name
59
  output_name = session.get_outputs()[0].name
60
  after_load_curr = time.time()
61
 
 
62
  curr_time = time.time()
63
  result = session.run([output_name], {input_name: im_np})
64
  prediction = softmax(np.array(result).squeeze(), axis=0)
65
  after_time = time.time()
 
66
  return prediction, after_time - curr_time, after_load_curr - load_curr
67
 
68
 
@@ -70,21 +87,24 @@ def main():
70
  """
71
  The main function/primary entry point of the app
72
  """
73
- # write the title of the page as MNIST Digit Recognizer
 
74
  st.title("MNIST Digit Recognizer")
75
 
76
  col1, col2 = st.columns([0.8, 0.2], gap="small")
77
  with col1:
78
  st.markdown(
79
  """
80
- This Streamlit app loads a Keras neural network trained on the MNIST dataset to predict handwritten digits. Draw a digit in the canvas below and see the model's prediction. You can:
81
  - Change the stroke width of the digit using the slider
82
  - Choose what model you use for predictions
83
  - Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
84
  - Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
85
- - Basic: A simple two layer nueral net where each layer has 300 nodes
 
 
86
 
87
- Like any machine learning model, this model is a function of the data it was fed during training. As you can see in the picture, the numbers in the images have a specific shape, location, and size. By playing around with the stroke width and where you draw the digit, you can see how the model's prediction changes.""",
88
  unsafe_allow_html=True,
89
  )
90
  with col2:
@@ -118,8 +138,8 @@ def main():
118
  background_color="#000",
119
  background_image=None,
120
  update_streamlit=True,
121
- height=200,
122
- width=200,
123
  drawing_mode="freedraw",
124
  point_display_radius=0,
125
  key="canvas",
@@ -144,10 +164,14 @@ def main():
144
  prediction, pred_time, load_time = onnx_prediction(final, model_path)
145
 
146
  # print the prediction
147
- st.header(f"Using model: {model_choice}")
148
- st.write(f"Prediction: {np.argmax(prediction)}")
149
- st.write(f"Load time (in ms): {(load_time) * 1000:.2f}")
150
- st.write(f"Prediction time (in ms): {(pred_time) * 1000:.2f}")
 
 
 
 
151
 
152
  # Create a 2 column dataframe with one column as the digits and the other as the probability
153
  data = pd.DataFrame(
 
20
  @st.cache_resource
21
  def load_picture():
22
  """
23
+ Shows the MNIST dataset image
 
24
  """
 
 
 
 
 
 
 
 
 
 
 
25
  st.image("img/show.png", width=250, caption="First 9 images from the MNIST dataset")
26
 
27
 
28
  def keras_prediction(final, model_path):
29
+ """Make a predition using a Keras model
30
+
31
+ Args:
32
+ final: The input image
33
+ model_path: The path of the Keras model to load
34
+
35
+ Returns:
36
+ np.array: Predictions from the model. The probability of each digit.
37
+ float: Time to make the prediction
38
+ float: Time to load the model
39
+ """
40
+ # load the model
41
  load_time = time.time()
42
  model = models.load_model(
43
  os.path.abspath(os.path.join(os.path.dirname(__file__), model_path))
44
  )
45
  after_load_curr = time.time()
46
+
47
+ # Make the prediction
48
  curr_time = time.time()
49
  prediction = model.predict(final[None, ...])
50
  after_time = time.time()
51
+
52
  return prediction, after_time - curr_time, after_load_curr - load_time
53
 
54
 
55
  def onnx_prediction(final, model_path):
56
+ """Make a predition using an Onnx model
57
+ Args:
58
+ final: The input image
59
+ model_path: The path of the Onnx model to load
60
+
61
+ Returns:
62
+ np.array: Predictions from the model. The probability of each digit.
63
+ float: Time to make the prediction
64
+ float: Time to load the model
65
+ """
66
+ im_np = np.expand_dims(final, axis=0)
67
+ im_np = np.expand_dims(im_np, axis=0)
68
  im_np = im_np.astype("float32")
69
+
70
+ # Load the model
71
  load_curr = time.time()
72
  session = onnxruntime.InferenceSession(model_path, None)
73
  input_name = session.get_inputs()[0].name
74
  output_name = session.get_outputs()[0].name
75
  after_load_curr = time.time()
76
 
77
+ # Make the prediction
78
  curr_time = time.time()
79
  result = session.run([output_name], {input_name: im_np})
80
  prediction = softmax(np.array(result).squeeze(), axis=0)
81
  after_time = time.time()
82
+
83
  return prediction, after_time - curr_time, after_load_curr - load_curr
84
 
85
 
 
87
  """
88
  The main function/primary entry point of the app
89
  """
90
+ # Setup
91
+ st.set_page_config(layout="wide")
92
  st.title("MNIST Digit Recognizer")
93
 
94
  col1, col2 = st.columns([0.8, 0.2], gap="small")
95
  with col1:
96
  st.markdown(
97
  """
98
+ This Streamlit app demonstrates the performance of multiple different neural networks (and associated frameworks) trained on the <a href="https://yann.lecun.com/exdb/mnist/">MNIST dataset</a> to predict handwritten digits. Draw a digit in the canvas below and see the model's prediction. You can:
99
  - Change the stroke width of the digit using the slider
100
  - Choose what model you use for predictions
101
  - Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
102
  - Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
103
+ - Basic: A simple <a href="https://keras.io/">Keras</a> model with two layers where each layer has 300 nodes. The model was trained on the MNIST dataset for 35 epochs.
104
+
105
+ Like any machine learning model, this model is a function of the data it was fed during training. As you can see in the picture, the numbers in the images have a specific shape, location, and size. By playing around with the stroke width and where you draw the digit, you can see how the model's prediction changes.
106
 
107
+ If you change your selected model after drawing the digit, that same drawing will be used with the newly selected model. To clear your "hand" drawn digit, click the trashcan icon under the drawing canvas.""",
108
  unsafe_allow_html=True,
109
  )
110
  with col2:
 
138
  background_color="#000",
139
  background_image=None,
140
  update_streamlit=True,
141
+ height=300,
142
+ width=300,
143
  drawing_mode="freedraw",
144
  point_display_radius=0,
145
  key="canvas",
 
164
  prediction, pred_time, load_time = onnx_prediction(final, model_path)
165
 
166
  # print the prediction
167
+ st.header(f"Results")
168
+ table_data = {
169
+ "Model": [model_choice],
170
+ "Prediction": [np.argmax(prediction)],
171
+ "Load time (ms)": f"{load_time * 1000:.2f}",
172
+ "Prediction time (ms)": f"{pred_time * 1000:.2f}",
173
+ }
174
+ st.table(table_data)
175
 
176
  # Create a 2 column dataframe with one column as the digits and the other as the probability
177
  data = pd.DataFrame(