Spaces:
Sleeping
Sleeping
File size: 7,400 Bytes
746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 e1e5ef8 746d998 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
"""
A small Streamlit app that lets the user draw a digit on a canvas and see the digit predicted by a model trained on the MNIST dataset (an ONNX model or one of two Keras models), together with the model's load time and prediction time.
"""
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas
import os
import numpy as np
from keras import models
import keras.datasets.mnist as mnist
import matplotlib.pyplot as plt
import pandas as pd
import time
import onnx
import onnxruntime
from scipy.special import softmax
@st.cache_resource
def load_picture():
    """Show the sample image of the first 9 MNIST digits.

    The image path is resolved relative to this file's directory (the same
    way keras_prediction resolves model paths), so the picture is found even
    when the app is launched from a different working directory.
    """
    image_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "img/show.png"
    )
    st.image(image_path, width=250, caption="First 9 images from the MNIST dataset")
def keras_prediction(final, model_path):
    """Make a prediction using a Keras model.

    The model is loaded from disk on every call; the returned load time
    measures that step.

    Args:
        final: The input image; a leading batch axis is added before it is
            passed to the model.
        model_path: Path of the Keras model to load, resolved relative to
            this file's directory (an absolute path is used as-is).

    Returns:
        np.array: Predictions from the model. The probability of each digit.
        float: Time to make the prediction, in seconds.
        float: Time to load the model, in seconds.
    """
    resolved_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), model_path)
    )

    # perf_counter is monotonic and high-resolution, which makes it the
    # right clock for measuring short intervals (time.time() is wall-clock).
    load_start = time.perf_counter()
    model = models.load_model(resolved_path)
    load_time = time.perf_counter() - load_start

    pred_start = time.perf_counter()
    # verbose=0 suppresses the progress bar so its console output is not
    # counted as part of the prediction time.
    prediction = model.predict(final[None, ...], verbose=0)
    pred_time = time.perf_counter() - pred_start

    return prediction, pred_time, load_time
def onnx_prediction(final, model_path):
    """Make a prediction using an ONNX model.

    The inference session is created on every call; the returned load time
    measures that step.

    Args:
        final: The input image, expected to be a 2-D array; batch and channel
            axes are prepended and it is cast to float32 before inference.
        model_path: Path of the ONNX model to load, resolved relative to this
            file's directory (an absolute path is used as-is), the same way
            keras_prediction resolves its path.

    Returns:
        np.array: Softmax of the model's first output, i.e. the probability
            of each digit.
        float: Time to make the prediction, in seconds.
        float: Time to load the model, in seconds.
    """
    # (H, W) -> (1, 1, H, W) float32: batch and channel axes for the model input.
    im_np = np.asarray(final, dtype=np.float32)[np.newaxis, np.newaxis, ...]

    # Previously the raw path was handed to onnxruntime, which resolved it
    # against the current working directory instead of this file's directory.
    resolved_path = os.path.abspath(
        os.path.join(os.path.dirname(__file__), model_path)
    )

    load_start = time.perf_counter()
    session = onnxruntime.InferenceSession(resolved_path, None)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    load_time = time.perf_counter() - load_start

    pred_start = time.perf_counter()
    # Only one output is requested, so run() returns a one-element list.
    (raw_output,) = session.run([output_name], {input_name: im_np})
    prediction = softmax(np.asarray(raw_output).squeeze(), axis=0)
    pred_time = time.perf_counter() - pred_start

    return prediction, pred_time, load_time
# Selectbox option -> model file. A single lookup replaces the former
# substring if-chain, which could leave the path unbound for an unknown option.
_MODEL_PATHS = {
    "Onnx": "models/mnist_12.onnx",
    "Autokeras": "models/autokeras_model.keras",
    "Basic": "models/mnist_model.keras",
}

_INTRO_MARKDOWN = """
This Streamlit app demonstrates the performance of multiple different neural networks (and associated frameworks) trained on the <a href="https://yann.lecun.com/exdb/mnist/">MNIST dataset</a> to predict handwritten digits. Draw a digit in the canvas below and see the model's prediction. You can:
- Change the stroke width of the digit using the slider
- Choose what model you use for predictions
- Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
- Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
- Basic: A simple <a href="https://keras.io/">Keras</a> model with two layers where each layer has 300 nodes. The model was trained on the MNIST dataset for 35 epochs.
Like any machine learning model, this model is a function of the data it was fed during training. As you can see in the picture, the numbers in the images have a specific shape, location, and size. By playing around with the stroke width and where you draw the digit, you can see how the model's prediction changes.
If you change your selected model after drawing the digit, that same drawing will be used with the newly selected model. To clear your "hand" drawn digit, click the trashcan icon under the drawing canvas."""


def _canvas_to_mnist(image_data):
    """Convert canvas pixel data to a 28x28 grayscale float32 array in [0, 1]."""
    im = Image.fromarray(image_data.astype("uint8")).convert("L")
    # 28x28 is the size of the MNIST dataset images.
    im = im.resize((28, 28))
    return np.array(im, dtype=np.float32) / 255.0


def _show_results(model_choice, prediction, pred_time, load_time):
    """Render the summary table, probability bar chart and probability list."""
    st.header("Results")
    # Every column holds a one-element list so the table has exactly one row.
    st.table(
        {
            "Model": [model_choice],
            "Prediction": [np.argmax(prediction)],
            "Load time (ms)": [f"{load_time * 1000:.2f}"],
            "Prediction time (ms)": [f"{pred_time * 1000:.2f}"],
        }
    )

    probabilities = np.ravel(prediction)
    data = pd.DataFrame(
        {"Digit": list(range(probabilities.size)), "Probability": probabilities}
    )
    chart_col, table_col = st.columns([0.8, 0.2], gap="small")
    with chart_col:
        st.bar_chart(data, x="Digit", y="Probability", height=500)
    with table_col:
        # Format a copy so the numeric frame used for the chart is untouched.
        formatted = data.assign(
            Probability=data["Probability"].apply(lambda x: f"{x:.2%}")
        )
        st.dataframe(formatted, hide_index=True)


def main():
    """The main function/primary entry point of the app.

    Lays out the page (description, sample picture, controls, drawing
    canvas). Once something has been drawn, it runs the selected model on the
    drawing and shows the prediction with load/prediction timings.
    """
    # Setup
    st.set_page_config(layout="wide")
    st.title("MNIST Digit Recognizer")
    intro_col, picture_col = st.columns([0.8, 0.2], gap="small")
    with intro_col:
        st.markdown(_INTRO_MARKDOWN, unsafe_allow_html=True)
    with picture_col:
        # Show the picture of the first 9 images from the MNIST dataset
        load_picture()

    canvas_col, controls_col = st.columns(2, gap="small")
    # The controls are read first because the canvas needs the stroke width.
    with controls_col:
        # Starts at 10 because that's reasonably close to the width of the MNIST digits
        stroke_width = st.slider("Stroke width: ", 1, 25, 10)
        model_choice = st.selectbox(
            "Choose what model to use for predictions:", tuple(_MODEL_PATHS)
        )
    model_path = _MODEL_PATHS[model_choice]

    with canvas_col:
        canvas_result = st_canvas(
            stroke_width=stroke_width,
            stroke_color="#FFF",
            fill_color="#000",
            background_color="#000",
            background_image=None,
            update_streamlit=True,
            height=300,
            width=300,
            drawing_mode="freedraw",
            point_display_radius=0,
            key="canvas",
        )

    if canvas_result is None or canvas_result.image_data is None:
        return
    final = _canvas_to_mnist(canvas_result.image_data)
    # An all-zero image means nothing has been drawn, so there is nothing to predict.
    if np.all(final == 0):
        return

    predict = onnx_prediction if model_choice == "Onnx" else keras_prediction
    prediction, pred_time, load_time = predict(final, model_path)
    _show_results(model_choice, prediction, pred_time, load_time)
# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()
|