# Source: Hugging Face Space file app.py, commit ee9ee13 ("Update app.py") by NemesisAlm.
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import gradio as gr
# Checkpoint id shared by the model weights and their matching processor.
_CHECKPOINT = "geolocal/StreetCLIP"

# Both objects are created once, when the module is first executed, and are
# reused by every call to classify().
model = CLIPModel.from_pretrained(_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(_CHECKPOINT)

# Candidate country names. classify() passes this list to the processor as the
# text side of the comparison and uses the same order to name the scores.
labels = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria", "Bangladesh",
    "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana", "Brazil",
    "Bulgaria", "Cambodia", "Canada", "Chile", "China", "Colombia",
    "Croatia", "Czech Republic", "Denmark", "Dominican Republic", "Ecuador",
    "Estonia", "Finland", "France", "Germany", "Ghana", "Greece",
    "Greenland", "Guam", "Guatemala", "Hungary", "Iceland", "India",
    "Indonesia", "Ireland", "Israel", "Italy", "Japan", "Jordan", "Kenya",
    "Kyrgyzstan", "Laos", "Latvia", "Lesotho", "Lithuania", "Luxembourg",
    "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico", "Monaco",
    "Mongolia", "Montenegro", "Netherlands", "New Zealand", "Nigeria",
    "Norway", "Pakistan", "Palestine", "Peru", "Philippines", "Poland",
    "Portugal", "Puerto Rico", "Romania", "Russia", "Rwanda", "Senegal",
    "Serbia", "Singapore", "Slovakia", "Slovenia", "South Africa",
    "South Korea", "Spain", "Sri Lanka", "Swaziland", "Sweden",
    "Switzerland", "Taiwan", "Thailand", "Tunisia", "Turkey", "Uganda",
    "Ukraine", "United Arab Emirates", "United Kingdom", "United States",
    "Uruguay",
]
def classify(image):
    """Return a probability for every country label for one image.

    The image and the full ``labels`` list are encoded together by
    ``processor``; the model's ``logits_per_image`` are turned into
    probabilities with a softmax over dimension 1, and the first row is
    paired with the label names.

    Args:
        image: The picture to score, forwarded unchanged to ``processor``
            as its ``images`` argument. Presumably whatever the Gradio
            image input delivers -- confirm against the UI wiring.

    Returns:
        A dict mapping each entry of ``labels`` (in list order) to its
        probability as a Python float.
    """
    batch = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        image_logits = model(**batch).logits_per_image

    # Row 0 is the only row used; tolist() yields plain Python floats.
    probabilities = image_logits.softmax(dim=1)[0].tolist()
    return dict(zip(labels, probabilities))
# HTML shown above the demo (passed to gr.Interface as `description`).
# The licence sentence at the end uses an HTML anchor, like the rest of the
# markup here, so the link does not depend on Markdown link parsing.
DESCRIPTION = """
<img src='file/logo.jpg' alt='logo' style='width:100px;float:left'/><h1 style='padding-top:30px'>&nbsp;StreetClip Model used for classification - 92 countries supported</h1>
<h2 style='margin-top:50px'> This Space demonstrates how the StreetClip model (https://huggingface.co/geolocal/StreetCLIP) can be used to geolocate images (classification per country).</h2>
<p>
🌎 <b>List of countries supported</b>: Albania, Andorra, Argentina, Australia, Austria, Bangladesh, Belgium, Bermuda, Bhutan, Bolivia, Botswana, Brazil, Bulgaria, Cambodia, Canada, Chile, China, Colombia, Croatia, Czech Republic, Denmark, Dominican Republic, Ecuador, Estonia, Finland, France, Germany, Ghana, Greece, Greenland, Guam, Guatemala, Hungary, Iceland, India, Indonesia, Ireland, Israel, Italy, Japan, Jordan, Kenya, Kyrgyzstan, Laos, Latvia, Lesotho, Lithuania, Luxembourg, Macedonia, Madagascar, Malaysia, Malta, Mexico, Monaco, Mongolia, Montenegro, Netherlands, New Zealand, Nigeria, Norway, Pakistan, Palestine, Peru, Philippines, Poland, Portugal, Puerto Rico, Romania, Russia, Rwanda, Senegal, Serbia, Singapore, Slovakia, Slovenia, South Africa, South Korea, Spain, Sri Lanka, Swaziland, Sweden, Switzerland, Taiwan, Thailand, Tunisia, Turkey, Uganda, Ukraine, United Arab Emirates, United Kingdom, United States, Uruguay
</p>
---<br>
As a derivative work of <a href='https://huggingface.co/geolocal/StreetCLIP'>StreetCLIP</a>, this demo is governed by the original Creative Commons Attribution Non Commercial 4.0 license.
"""
# UI wiring: one image input and a label output that displays the ten
# highest-scoring countries returned by classify().
image = gr.Image(label="Input image")
label = gr.Label(num_top_classes=10, label="Top 10 countries")

demo = gr.Interface(
    fn=classify,
    inputs=image,
    outputs=label,
    interpretation="default",
    title="StreetClip Classification",
    description=DESCRIPTION,
    # The trailing emoji had been saved as mis-decoded bytes; it is restored
    # here to the intended folded-hands emoji.
    article="Interpretation requires time to process, please be patient 🙏🏻",
    examples=["turkey.jpg", "croatia.jpg"],
    allow_flagging="never",
)

# Launched unconditionally when this file is executed.
demo.launch(favicon_path='favicon.ico')