"""ImageNet Quiz: a Gradio app where users guess the category of an image.

Each round shows an image from ``images/`` with four candidate categories
(the image's true class plus three distractors drawn from the other labels in
``image_labels.csv``). The user's first pick per image is scored, and the
image is also classified by google/vit-base-patch16-224 for comparison.
"""
import csv
import random
from collections import OrderedDict
from random import sample

import gradio as gr
from datasets import load_dataset  # NOTE(review): unused in this file.
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Synset id -> first human-readable name, in file order. The labels stored in
# image_labels.csv are integer indices into ``classes`` (see ``int(truth)``).
classdict = OrderedDict()
with open("LOC_synset_mapping.txt", "r") as synset_file:
    for line in synset_file:
        line = line.strip()
        if not line:
            # Skip blank lines so they cannot inject a '' entry into the
            # ordered mapping and shift the index -> name correspondence.
            continue
        wnid, _, names = line.partition(" ")
        classdict[wnid] = names.split(",")[0]
classes = list(classdict.values())

# Image file stem -> class index (kept as the string read from the CSV).
imagedict = {}
with open("image_labels.csv", "r", newline="") as csv_file:
    for row in csv.DictReader(csv_file):
        imagedict[row["image_name"]] = row["image_label"]
images = list(imagedict.keys())
labels = list(set(imagedict.values()))


def model_classify(radio, im):
    """Classify the displayed image with the ViT model once the user has picked.

    Args:
        radio: the user's current selection; ``None`` when nothing is selected.
        im: the displayed image, as delivered by the ``gr.Image`` component.

    Returns:
        ``(label, class_idx, True)`` with the model's top label (first
        comma-separated name) and its class index, or ``(None, None, False)``
        when no selection is made.
    """
    if radio is None:
        return None, None, False
    inputs = feature_extractor(images=im, return_tensors="pt")
    logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    modelclass = model.config.id2label[predicted_class_idx]
    return modelclass.split(",")[0], predicted_class_idx, True


def random_image():
    """Pick a random quiz image and build its four shuffled answer options.

    Returns:
        ``(image, label, radio_update, prediction)``: the PIL image, its class
        index as a string, an update that clears the radio selection and sets
        the options (true class + 3 distractors), and ``None`` to clear the
        model-prediction output.

    Raises:
        ValueError: from ``random.sample`` if fewer than three other labels exist.
    """
    imname = random.choice(images)
    im = Image.open("images/" + imname + ".jpg")
    label = str(imagedict[imname])
    # Build a filtered copy rather than removing from the module-level
    # ``labels``: mutating it would shrink the distractor pool every round and
    # raise ValueError as soon as a label is drawn a second time.
    distractors = [lab for lab in labels if lab != label]
    options = sample(distractors, 3)
    options.append(label)
    random.shuffle(options)
    options = [classes[int(i)] for i in options]
    return im, label, gr.Radio.update(value=None, choices=options), None


def check_score(pred, truth, current_score, total_score, has_guessed):
    """Update the user's score after a radio selection.

    Only the first real selection for an image counts: once ``has_guessed``
    is set, later changes leave the score untouched. A ``None`` selection
    (the radio being reset for a new image) is never counted, and in that
    case ``truth`` is not looked up at all.

    Args:
        pred: the selected category name, or ``None``.
        truth: the image's class index (string or int) into ``classes``.
        current_score: number of correct first guesses so far.
        total_score: number of scored guesses so far.
        has_guessed: whether this image has already been guessed.

    Returns:
        ``(new_score, message, new_total)``.
    """
    if pred is not None and not has_guessed:
        total_score += 1
        if pred == classes[int(truth)]:
            current_score += 1
    return current_score, f"Your score is {current_score} out of {total_score}!", total_score


def compare_score(userclass, truth):
    """Return the feedback message for the user's current selection."""
    if userclass is None:
        return "Try guessing a category!"
    answer = classes[int(truth)]
    if userclass == answer:
        return "Great! You guessed it right"
    return "The right answer was " + str(answer) + "! Try guessing the next image."


with gr.Blocks() as demo:
    user_score = gr.State(0)
    model_score = gr.State(0)  # NOTE(review): not wired to any event.
    image_label = gr.State()
    model_class = gr.State()
    total_score = gr.State(0)
    has_guessed = gr.State(False)
    gr.Markdown("# ImageNet Quiz")
    gr.Markdown("### ImageNet is one of the most popular datasets used for training and evaluating AI models.")
    gr.Markdown("### But many of its categories are hard to guess, even for humans.")
    gr.Markdown("#### Try your hand at guessing the category of each image displayed, from the options provided. Compare your answers to that of a neural network trained on the dataset, and see if you can do better!")
    with gr.Row():
        with gr.Column(min_width=900):
            image = gr.Image(shape=(600, 600))
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        with gr.Column():
            prediction = gr.Label(label="The AI model predicts:")
            score = gr.Label(label="Your Score")
            message = gr.Label(label="Did you guess it right?")
            btn = gr.Button("Next image")

    demo.load(random_image, None, [image, image_label, radio, prediction])
    # All three listeners fire on the same change; check_score reads the
    # has_guessed value from before model_classify sets it to True, so only
    # the first selection per image is scored.
    radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
    radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
    radio.change(compare_score, [radio, image_label], message)
    btn.click(random_image, None, [image, image_label, radio, prediction])
    btn.click(lambda: False, None, has_guessed)

demo.launch()