# po-fetch-detail / app.py
# Hugging Face Space by virendravaishnav
# Updated with OCR model and Gradio integration (commit 7fee682)
import gradio as gr
from transformers import AutoImageProcessor, AutoTokenizer, AutoModel
import torch
# Hub repository that provides the vision-language model used below.
repo_id = "OpenGVLab/InternVL2-1B"

# Load the image processor, tokenizer, and model directly from the Hub.
# trust_remote_code is required because InternVL2 ships custom model code.
image_processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.float16  # Use half-precision for efficiency
)

# Move model to the appropriate device (GPU if available, else CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def analyze_image(image):
    """Generate a text description for an uploaded image.

    Parameters
    ----------
    image : PIL.Image.Image
        Image supplied by the Gradio ``Image`` input (``type="pil"``).

    Returns
    -------
    str
        The model-generated description, or a human-readable error
        message if anything goes wrong (Gradio displays either).
    """
    try:
        img = image.convert("RGB")
        prompt = "describe this image"

        # Preprocess the image and move tensors to the model's device.
        image_inputs = image_processor(images=img, return_tensors="pt").to(device)
        # The model was loaded in float16 (see torch_dtype above), but the
        # processor emits pixel values at default precision — cast them so
        # the dtypes agree at inference time.
        pixel_values = image_inputs["pixel_values"].to(model.dtype)

        # Tokenize the prompt on the same device.
        text_inputs = tokenizer(prompt, return_tensors="pt").to(device)

        inputs = {
            "input_ids": text_inputs["input_ids"],
            "attention_mask": text_inputs["attention_mask"],
            "pixel_values": pixel_values,
        }

        # Inference only: disable autograd bookkeeping and bound the
        # output length instead of relying on the model's default.
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=256)

        # Decode the generated token ids back into text.
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure in the UI rather than crashing the app.
        return f"An error occurred: {str(e)}"
# Gradio UI: a single image upload in, the generated description out.
demo = gr.Interface(
    fn=analyze_image,
    inputs=gr.Image(type="pil"),  # deliver the upload as a PIL image
    outputs="text",
    title="Image Description using InternVL2-1B",
    description="Upload an image and get a description generated by the InternVL2-1B model."
)

# Launch the web app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()