import gradio as gr
import torch
from transformers import FuyuForCausalLM, AutoTokenizer
from transformers.models.fuyu.processing_fuyu import FuyuProcessor
from transformers.models.fuyu.image_processing_fuyu import FuyuImageProcessor
from PIL import Image
# Checkpoint to serve and the floating-point precision used for inference.
model_id = "adept/fuyu-8b"
dtype = torch.bfloat16

# Target device for the per-request input tensors; the model weights themselves
# are placed by `device_map="auto"` below.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = FuyuForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=dtype,
)
# The processor pairs Fuyu's image patcher with the text tokenizer loaded above.
processor = FuyuProcessor(
    image_processor=FuyuImageProcessor(),
    tokenizer=tokenizer,
)
# Fixed prompt for the captioning tab. It must end with a real newline
# character, not the two characters backslash + "n", to match the prompt
# format published for Fuyu-8B.
caption_prompt = "Generate a coco-style caption.\n"
def resize_to_max(image, max_width=1080, max_height=1080):
    """Downscale *image* so it fits inside ``max_width`` x ``max_height``.

    The aspect ratio is preserved. An image that already fits is returned
    unchanged (the very same object, no copy).

    Args:
        image: PIL image (anything exposing ``.size`` and ``.resize``).
        max_width: Maximum allowed width in pixels.
        max_height: Maximum allowed height in pixels.

    Returns:
        The original image, or a LANCZOS-resampled copy that fits the bounds.
    """
    width, height = image.size
    if width <= max_width and height <= max_height:
        return image
    scale = min(max_width / width, max_height / height)
    # int() truncates, so the result never exceeds the bounds. Clamp to 1 so an
    # extreme aspect ratio (e.g. 5000x1) cannot yield a zero-sized target,
    # which PIL's resize rejects.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))
    return image.resize((new_width, new_height), Image.LANCZOS)
def _to_model_device(value):
    """Move one processor output onto ``device``, casting floating tensors to ``dtype``.

    Non-floating tensors (token ids, indices, masks) keep their dtype.
    Lists/tuples are converted element-wise, because some transformers versions
    return ``image_patches`` as a list of per-image tensors rather than a single
    tensor (NOTE(review): depends on the installed transformers version).
    Any other value is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        target_dtype = dtype if torch.is_floating_point(value) else value.dtype
        return value.to(dtype=target_dtype, device=device)
    if isinstance(value, (list, tuple)):
        return type(value)(_to_model_device(item) for item in value)
    return value


def predict(image, prompt):
    """Run Fuyu-8B on an image/prompt pair and return the generated text.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` component, or
            ``None`` when the user has not uploaded anything.
        prompt: Text prompt (a user question for VQA, or ``caption_prompt``).

    Returns:
        The decoded continuation only (tokens after the prompt), with special
        tokens removed; at most 40 new tokens are generated.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Normalise palette/alpha/greyscale images to plain 3-channel RGB before
    # they are patched by the image processor.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = resize_to_max(image)

    model_inputs = processor(text=prompt, images=[image])
    model_inputs = {key: _to_model_device(value) for key, value in model_inputs.items()}

    generation_output = model.generate(**model_inputs, max_new_tokens=40)
    # generate() returns prompt + continuation; slice off the prompt tokens.
    prompt_len = model_inputs["input_ids"].shape[-1]
    return tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
def caption(image):
    """Caption *image* by running the shared `predict` pipeline with the fixed caption prompt."""
    return predict(prompt=caption_prompt, image=image)
def set_example_image(example: list) -> dict:
    """Return a Gradio update that loads ``example[0]`` into an Image component.

    Uses the generic ``gr.update`` rather than ``gr.Image.update``: the
    component-specific ``.update`` class methods were removed in Gradio 4,
    while ``gr.update(value=...)`` works on both 3.x and 4.x.

    Args:
        example: An examples-table row whose first element is the image value.
    """
    return gr.update(value=example[0])
# Custom stylesheet passed to gr.Blocks(css=...). It styles the element with
# id "mkd": fixed 500px height, scrollable overflow, light grey border.
# NOTE(review): no component in this file sets elem_id="mkd", so this rule
# appears to have no visible effect — confirm before relying on it.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# Intro banner rendered at the top of the page.
_intro_html = """
Fuyu Multimodal Demo
Fuyu-8B is a multimodal model that supports a variety of tasks combining text and image prompts.
For example, you can use it for captioning by asking it to describe an image. You can also ask it questions about an image, a task known as Visual Question Answering, or VQA. This demo lets you explore captioning and VQA, with more tasks coming soon :)
Learn more about the model in our blog post.
Note: This is a raw model release. We have not added further instruction-tuning, postprocessing or sampling strategies to control for undesirable outputs. The model may hallucinate, and you should expect to have to fine-tune the model for your use-case!
Play with Fuyu-8B in this demo! 💬
"""

# Example rows per tab: VQA rows are [image_path, question],
# captioning rows are [image_path].
_vqa_examples = [
    ["assets/vqa_example_1.png", "How is this made?"],
    ["assets/vqa_example_2.png", "What is this flower and where is it's origin?"],
]
_captioning_examples = [
    ["assets/captioning_example_1.png"],
    ["assets/captioning_example_2.png"],
]

with gr.Blocks(css=css) as demo:
    gr.HTML(_intro_html)

    # Tab 1: ask a free-form question about an uploaded image.
    with gr.Tab("Visual Question Answering"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload your Image", type="pil")
                text_input = gr.Textbox(label="Ask a Question")
            vqa_output = gr.Textbox(label="Output")
        vqa_btn = gr.Button("Answer Visual Question")
        # cache_examples=True asks Gradio to precompute these outputs with
        # `fn`, so the model must already be loaded when this runs.
        gr.Examples(
            _vqa_examples,
            inputs=[image_input, text_input],
            outputs=[vqa_output],
            fn=predict,
            cache_examples=True,
            label='Click on any Examples below to get VQA results quickly 👇',
        )

    # Tab 2: caption an uploaded image with the fixed caption prompt.
    with gr.Tab("Image Captioning"):
        with gr.Row():
            captioning_input = gr.Image(label="Upload your Image", type="pil")
            captioning_output = gr.Textbox(label="Output")
        captioning_btn = gr.Button("Generate Caption")
        gr.Examples(
            _captioning_examples,
            inputs=[captioning_input],
            outputs=[captioning_output],
            fn=caption,
            cache_examples=True,
            label='Click on any Examples below to get captioning results quickly 👇',
        )

    # Button handlers; listeners must be attached inside the Blocks context.
    captioning_btn.click(fn=caption, inputs=captioning_input, outputs=captioning_output)
    vqa_btn.click(fn=predict, inputs=[image_input, text_input], outputs=vqa_output)

# Bind on all network interfaces (0.0.0.0) rather than only localhost.
demo.launch(server_name="0.0.0.0")