Spaces:
Running
Running
import streamlit as st | |
from text2image import get_model, get_tokenizer, get_image_transform | |
from utils import text_encoder, image_encoder | |
from PIL import Image | |
from jax import numpy as jnp | |
import pandas as pd | |
import requests | |
import jax | |
import gc | |
def _columns(spec):
    """Create Streamlit layout columns on both old and new Streamlit releases.

    The page was written against ``st.beta_columns``. Newer Streamlit exposes
    the same API as ``st.columns`` and no longer ships the beta alias. Prefer
    ``st.columns`` and fall back to ``st.beta_columns`` only when the former
    is missing.
    """
    make_columns = getattr(st, "columns", None) or st.beta_columns
    return make_columns(spec)


def _load_image(url, timeout=10.0):
    """Download ``url`` and return its content as an RGB ``PIL.Image``.

    Args:
        url: HTTP(S) address of the image.
        timeout: seconds ``requests`` may wait on the connection/read before
            giving up.

    Raises:
        requests.RequestException: on network failures or a non-2xx status.
        OSError: if Pillow cannot decode the payload as an image.
    """
    # The timeout stops an unresponsive host from hanging the whole page.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Reading .raw bypasses requests' automatic decompression, so ask
        # urllib3 to undo any Content-Encoding (e.g. gzip) itself.
        response.raw.decode_content = True
        # convert() forces a full decode while the connection is still open.
        return Image.open(response.raw).convert("RGB")


def _try_load_image(url):
    """Return the image at ``url``, or ``None`` after showing an error."""
    try:
        return _load_image(url)
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load an image from {url!r}: {err}")
        return None


def app():
    """Render the "Image to Text" zero-shot classification page.

    The user enters an image URL and up to ``MAX_CAP`` free-text labels. Without
    a click on "CLASSIFY", the page only previews the image at the URL. On
    click, the image and every non-blank label are embedded with the model
    from ``get_model()``. The image-label scores are passed through a softmax,
    and the result is shown as a bar chart next to the image.
    """
    st.markdown("<h1 style='text-align: center; color: #CD212A;'> Zero Shot Image Classification </h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center; color: #008C45; font-weight:bold;'> Image to Text </h2>", unsafe_allow_html=True)
    st.markdown(
        """
        π Ciao! Here you can find the captions or the labels that are most related to a given image.
        Try typing "gatto" (cat) in the space for label1 and "cane" (dog) in the space for label2 and click
        "classify"!
        """
    )
    # Strip stray whitespace that often comes along with a pasted URL.
    image_url = st.text_input(
        "YOU CAN INPUT THE URL OF AN IMAGE : ",
        value="https://www.petdetective.it/wp-content/uploads/2016/04/gatto-toilette.jpg",
    ).strip()

    MAX_CAP = 4
    col1, col2 = _columns([3, 1])
    with col2:
        captions_count = st.selectbox(
            "NUMBER OF LABELS", options=range(1, MAX_CAP + 1), index=1
        )
        compute = st.button("CLASSIFY")
    with col1:
        # The selectbox options already cap captions_count at MAX_CAP.
        captions = [
            st.text_input(f"INSERT LABEL {idx + 1}") for idx in range(captions_count)
        ]

    if compute:
        # Blank or whitespace-only labels carry no meaning; drop them.
        captions = [c for c in captions if c.strip()]
        if not captions or not image_url:
            st.error("Please choose one image and at least one label")
            return
        # Fetch the image first so a bad URL fails before any model work.
        image = _try_load_image(image_url)
        if image is None:
            return
        with st.spinner("Computing..."):
            model = get_model()
            tokenizer = get_tokenizer()
            # extend() splices each encoder output into the list. This
            # presumably yields one row per caption, assuming text_encoder
            # returns a batch of one — TODO confirm against utils.text_encoder.
            text_embeds = []
            for caption in captions:
                text_embeds.extend(text_encoder(caption, model, tokenizer))
            text_embeds = jnp.array(text_embeds)

            transform = get_image_transform(model.config.vision_config.image_size)
            image_embed = image_encoder(transform(image), model)
            # Softmax (over the last axis) turns the image-label scores into
            # probabilities across the supplied labels.
            probabilities = jax.nn.softmax(jnp.matmul(image_embed, text_embeds.T))
            chart_data = pd.Series(probabilities[0], index=captions)

            chart_col, image_col = _columns(2)
            with chart_col:
                st.bar_chart(chart_data)
            with image_col:
                st.image(image, use_column_width=True)
        gc.collect()
    elif image_url:
        image = _try_load_image(image_url)
        if image is not None:
            st.image(image)