File size: 2,696 Bytes
cc9f92c
 
 
 
 
 
 
 
ff6fbb2
cc9f92c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import time
from tensorflow.keras.preprocessing import image
# from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
import streamlit as st
# Trained Keras classifier, loaded when this module is executed.
# NOTE(review): the path is resolved against the process's current working
# directory, not this file's directory — confirm the app is always launched
# from the directory that contains the model file.
_MODEL_FILE = 'best_resnet152_model.h5'
model = tf.keras.models.load_model(_MODEL_FILE)

# Output index -> document-type label. The list position is the model's
# output index; presumably this matches the class order used in training
# (it is alphabetical, as directory-based loaders produce) — confirm.
class_names = dict(enumerate(['1099_Div', '1099_Int', 'Non_Form', 'w_2', 'w_3']))

@st.cache_data
def predict(pil_img):
    """Classify a single form image and return its label.

    The image is turned into an array, given a leading batch axis of size 1,
    and rescaled from [0, 255] to [0, 1] before being passed to the
    module-level ``model``. The arg-max output index is mapped to a label
    through ``class_names``.

    Results are memoized by Streamlit (``st.cache_data``, the cache intended
    for data-returning computations), so an identical image skips inference.

    Args:
        pil_img: The PIL image to classify.
            NOTE(review): no resizing or colour-mode conversion happens here,
            so the caller presumably must supply an image whose size and
            channel count match the model's input layer — confirm.

    Returns:
        str: The predicted class name, one of the values of ``class_names``.
    """
    # img_to_array returns a float array (Keras floatx) by default, so the
    # in-place division below does not truncate.
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
    batch /= 255.0  # presumably the same rescaling used in training — confirm

    # verbose=0 stops Keras from printing a progress bar to stdout on every
    # call, which would otherwise flood the server logs.
    predictions = model.predict(batch, verbose=0)

    # Single-image batch: take row 0 and cast NumPy's integer to a plain int
    # for the dictionary lookup.
    predicted_class_name = class_names[int(np.argmax(predictions[0]))]
    print("Predicted class:", predicted_class_name)
    return predicted_class_name
# import numpy as np
# import time
# from PIL import Image  # Import for PIL image handling
# from torchvision import transforms  # Import for image preprocessing

# import torch
# import torch.nn as nn  # Import for PyTorch neural networks
# import streamlit as st

# # Load the PyTorch model (assuming it's saved in PyTorch format)
# model = torch.load('./best_resnet152_model.pt')  # Replace with your model filename

# # Define class names dictionary
# class_names = {0: '1099_Div', 1: '1099_Int', 2: 'Non_Form', 3: 'w_2', 4: 'w_3'}


# # Define a function for prediction using PyTorch
# @st.cache_resource
# def predict(pil_img):
#     # Preprocess the image
#     preprocess = transforms.Compose([
#         transforms.ToTensor(),  # Convert to PyTorch tensor
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize based on ImageNet statistics
#     ])
#     img_tensor = preprocess(pil_img)
#     img_tensor.unsqueeze_(0)  # Add batch dimension

#     # Predict with PyTorch
#     start_time = time.time()
#     with torch.no_grad():  # Disable gradient calculation for prediction
#         predictions = model(img_tensor)
#     end_time = time.time()

#     # Get the predicted class
#     predicted_class_index = torch.argmax(predictions, dim=1).item()
#     predicted_class_name = class_names[predicted_class_index]

#     # Print results (optional for debugging)
#     print("Predicted class:", predicted_class_name)
#     print("Execution time: ", end_time - start_time)

#     return predicted_class_name