import os
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.utils import load_img, save_img, img_to_array
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D
from pymilvus import connections, Collection, utility
from requests import get
import streamlit as st
import zipfile
# unzip the vegetable images if they are not already extracted
def unzip_images():
    with zipfile.ZipFile("Vegetable Images.zip", 'r') as zip_ref:
        zip_ref.extractall('.')
    print('unzipped images')

if not os.path.exists('Vegetable Images/'):
    unzip_images()
class ImageVectorizer:
    '''
    Get the vector representation of an image using a VGG19 model fine-tuned on vegetable images for classification.
    '''

    def __init__(self):
        self.__model = self.get_model()

    def get_model(self):
        # load the saved VGG19 model fine-tuned on vegetable images for classification
        model = load_model('vegetable_classification_model_vgg.h5')
        # drop the classification head: pool the (7, 7, 512) block5_pool feature map
        # into a 512-d embedding
        top = model.get_layer('block5_pool').output
        top = GlobalAveragePooling2D()(top)
        model = Model(inputs=model.input, outputs=top)
        print('loaded model')
        return model

    def vectorize(self, img_path: str):
        model = self.__model
        test_image = load_img(img_path, color_mode="rgb", target_size=(224, 224))
        test_image = img_to_array(test_image)
        test_image = preprocess_input(test_image)
        test_image = np.array([test_image])  # add a batch dimension
        return model(test_image).numpy()[0]
def get_milvus_collection():
    # connect to the Milvus instance using credentials from environment variables
    uri = os.environ.get("URI")
    token = os.environ.get("TOKEN")
    connections.connect("default", uri=uri, token=token)
    print("Connected to DB")
    collection_name = os.environ.get("COLLECTION_NAME")
    collection = Collection(name=collection_name)
    collection.load()  # load the collection into memory so it can be searched
    return collection
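
# NOTE (assumption): this file does not define the collection schema; it only relies on the
# collection exposing an 'image_vector' FLOAT_VECTOR field (the 512-d pooled VGG19 embedding
# produced by ImageVectorizer) and an 'image_path' scalar field, as used in
# find_similar_images() below. The schema itself is presumably created by a separate
# indexing script that is not part of this file.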
def plot_images(input_image_path: str, similar_img_paths: list):
    # plot the similar images in a grid of subplots
    rows = 5  # rows in subplots
    cols = 3  # columns in subplots
    fig, ax = plt.subplots(rows, cols, figsize=(12, 20))
    r = 0
    c = 0
    for i in range(rows * cols):
        sim_image = load_img(similar_img_paths[i], color_mode="rgb", target_size=(224, 224))
        ax[r, c].axis("off")
        ax[r, c].imshow(sim_image)
        c += 1
        if c == cols:
            c = 0
            r += 1
    plt.subplots_adjust(wspace=0.01, hspace=0.01)

    # display the input image followed by the grid of similar images
    rcParams.update({'figure.autolayout': True})
    input_image = load_img(input_image_path, color_mode="rgb", target_size=(224, 224))
    with placeholder.container():  # 'placeholder' is the st.empty() slot created in the main block below
        st.markdown('<p style="font-size: 20px; font-weight: bold">Input image</p>', unsafe_allow_html=True)
        st.image(input_image)
        st.write(' \n')
        st.markdown('<p style="font-size: 20px; font-weight: bold">Similar images</p>', unsafe_allow_html=True)
        st.pyplot(fig)
def find_similar_images(img_path: str, top_n: int = 15):
    # vectorize the input image and run an ANN search in Milvus using L2 distance
    search_params = {"metric_type": "L2"}
    search_vec = vectorizer.vectorize(img_path)
    result = collection.search([search_vec],
                               anns_field='image_vector',  # vector field to search, as defined in the collection schema
                               param=search_params,
                               limit=top_n,
                               guarantee_timestamp=1,
                               output_fields=['image_path'])  # fields to return with each hit
    output_dict = {"input_image_path": img_path,
                   "similar_image_paths": [hit.entity.get('image_path') for hits in result for hit in hits]}
    plot_images(output_dict['input_image_path'], output_dict['similar_image_paths'])
def delete_file(path_: str):
    if os.path.exists(path_):
        os.remove(path_)
def get_upload_path():
    upload_dir = os.path.join('.', 'uploads')
    if not os.path.exists(upload_dir):
        os.makedirs(upload_dir)
    upload_filename = "input.jpg"
    return os.path.join(upload_dir, upload_filename)
def process_input_image(img_url):
    # download the image at the given URL and save it to the uploads directory
    upload_file_path = get_upload_path()
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    r = get(img_url, headers=headers)
    with open(upload_file_path, "wb") as file:
        file.write(r.content)
    return upload_file_path
vectorizer = ImageVectorizer()
collection = get_milvus_collection()
try:
    st.markdown("<h3>Find Similar Vegetable Images</h3>", unsafe_allow_html=True)
    desc = '''<p style="font-size: 15px;">Allowed Vegetables: Broad Bean, Bitter Gourd, Bottle Gourd,
    Green Round Brinjal, Broccoli, Cabbage, Capsicum, Carrot, Cauliflower, Cucumber,
    Raw Papaya, Potato, Green Pumpkin, Radish, Tomato.
    </p>
    <p style="font-size: 13px;">Image embeddings are extracted from a fine-tuned VGG model. The model is fine-tuned on <a href="https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset" target="_blank">images</a> taken with a mobile phone camera.
    Embeddings of 20,000 vegetable images are stored in a Milvus vector database. The embedding of the input image is computed and the 15 most similar images (based on L2 distance) are displayed.</p>
    '''
    st.markdown(desc, unsafe_allow_html=True)
    img_url = st.text_input("Paste the image URL of a vegetable and hit Enter:", "")
    placeholder = st.empty()
    if img_url:
        placeholder.empty()
        img_path = process_input_image(img_url)
        find_similar_images(img_path, 15)
        delete_file(img_path)
except Exception as e:
    st.error(f'An unexpected error occurred: \n{e}')