|
from PIL import Image |
|
import torch |
|
from transformers import CLIPProcessor, CLIPModel |
|
import gradio as gr |
|
|
|
# Hugging Face Hub checkpoint providing both the StreetCLIP weights and the
# matching preprocessing configuration.
STREETCLIP_CHECKPOINT = "geolocal/StreetCLIP"

# Model and processor are loaded from the same checkpoint so that the text
# prompts and the images are encoded the way this model expects.
model = CLIPModel.from_pretrained(STREETCLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(STREETCLIP_CHECKPOINT)
|
|
|
# Candidate classes for zero-shot classification. Each entry is sent verbatim
# to the CLIP text side, and this order fixes the order of the scores produced
# in classify(). Kept alphabetical, one row per initial letter.
labels = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana", "Brazil", "Bulgaria",
    "Cambodia", "Canada", "Chile", "China", "Colombia", "Croatia", "Czech Republic",
    "Denmark", "Dominican Republic",
    "Ecuador", "Estonia",
    "Finland", "France",
    "Germany", "Ghana", "Greece", "Greenland", "Guam", "Guatemala",
    "Hungary",
    "Iceland", "India", "Indonesia", "Ireland", "Israel", "Italy",
    "Japan", "Jordan",
    "Kenya", "Kyrgyzstan",
    "Laos", "Latvia", "Lesotho", "Lithuania", "Luxembourg",
    "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico", "Monaco", "Mongolia", "Montenegro",
    "Netherlands", "New Zealand", "Nigeria", "Norway",
    "Pakistan", "Palestine", "Peru", "Philippines", "Poland", "Portugal", "Puerto Rico",
    "Romania", "Russia", "Rwanda",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia", "South Africa", "South Korea",
    "Spain", "Sri Lanka", "Swaziland", "Sweden", "Switzerland",
    "Taiwan", "Thailand", "Tunisia", "Turkey",
    "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay",
]
|
|
|
|
|
def classify(image):
    """Score an image against every country in ``labels`` using StreetCLIP.

    Each entry of the module-level ``labels`` list is used as a text prompt and
    compared with the image. The per-label logits are softmax-normalised, so
    the returned probabilities sum to 1 across all labels.

    Args:
        image: The picture from the Gradio ``Image`` input, in whatever form
            ``processor`` accepts (presumably a numpy array from ``gr.Image``'s
            default settings; confirm if the component type changes), or
            ``None`` when the user submits without an image.

    Returns:
        A ``{country: probability}`` dict with one entry per label, in
        ``labels`` order (suitable for a ``gr.Label`` output), or ``None``
        when no image was provided.
    """
    # Gradio hands over None for an empty input; there is nothing to classify,
    # and feeding None to the processor would only surface as an error.
    if image is None:
        return None

    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Pure inference: disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image is indexed [image, label]; only one image is sent, so
    # row 0 holds its score for every label. tolist() yields plain Python floats.
    probabilities = outputs.logits_per_image.softmax(dim=1)[0].tolist()

    return dict(zip(labels, probabilities))
|
|
|
# HTML/Markdown page description for the demo. The country count and the
# country list are generated from `labels` (the same list classify() scores),
# so the page cannot drift out of sync with the model's candidate classes.
# NOTE(review): the 'file/logo.jpg' URL form is a Gradio 3.x convention;
# confirm against the pinned gradio version.
DESCRIPTION = f"""
<img src='file/logo.jpg' alt='logo' style='width:100px;float:left'/><h1 style='padding-top:30px'> StreetClip Model used for classification - {len(labels)} countries supported</h1>

<h2 style='margin-top:50px'> This Space demonstrates how the StreetClip model (https://huggingface.co/geolocal/StreetCLIP) can be used to geolocate images (classification per country).</h2>

<p>

🌍 <b>List of countries supported</b>: {', '.join(labels)}

</p>

---<br>

As a derivative work of [StreetClip](https://huggingface.co/geolocal/StreetCLIP), this demo is governed by the original Creative Commons Attribution-NonCommercial 4.0 (CC BY-NC 4.0) license.

"""
|
|
|
# Input widget: the photo to geolocate; its value is passed to classify().
image = gr.Image(label="Input image")

# Output widget: renders the {country: probability} dict returned by classify(),
# displaying only the 10 highest-scoring entries.
label = gr.Label(label="Top 10 countries", num_top_classes=10)
|
|
|
# Wire classify() into a single-image -> label-distribution web UI.
# NOTE(review): `interpretation` and `allow_flagging` are Gradio 3.x parameters
# (interpretation was dropped and allow_flagging renamed in later releases);
# confirm against the pinned gradio version before upgrading.
demo = gr.Interface(
    fn=classify,
    inputs=image,
    outputs=label,
    interpretation="default",
    title="StreetClip Classification",
    description=DESCRIPTION,
    article="Interpretation requires time to process, please be patient 🙏",
    # Sample photos; presumably shipped alongside this script (verify).
    examples=["turkey.jpg", "croatia.jpg"],
    allow_flagging="never",
)

# Start the web server only when executed as a script, so that importing this
# module (e.g. from tests or another app) does not launch the UI.
if __name__ == "__main__":
    demo.launch(favicon_path="favicon.ico")
|
|