import gradio as gr from lavis.models import load_model_and_preprocess import torch device = torch.device("cuda") if torch.cuda.is_available() else "cpu" model_name = "blip2_t5_instruct" model_type = "flant5xl" model, vis_processors, _ = load_model_and_preprocess( name=model_name, model_type=model_type, is_eval=True, device=device ) def infer(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method): use_nucleus_sampling = decoding_method == "Nucleus sampling" image = vis_processors["eval"](image).unsqueeze(0).to(device) samples = { "image": image, "prompt": prompt, } output = model.generate( samples, length_penalty=float(len_penalty), repetition_penalty=float(repetition_penalty), num_beams=beam_size, max_length=max_len, min_length=min_len, top_p=top_p, use_nucleus_sampling=use_nucleus_sampling ) return output[0] theme = gr.themes.Monochrome( primary_hue="indigo", secondary_hue="blue", neutral_hue="slate", radius_size=gr.themes.sizes.radius_sm, font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"], ) css = ".generating {visibility: hidden}" examples = [ ["banff.jpg", "Can you tell me about this image in detail", 1, 200, 5, 1, 3, 0.9, "Nucleus sampling"] ] with gr.Blocks(theme=theme, analytics_enabled=False,css=css) as demo: gr.Markdown("## InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning") gr.Markdown( """ Unofficial demo for InstructBLIP. InstructBLIP is a new vision-language instruction-tuning framework by Salesforce that uses BLIP-2 models, achieving state-of-the-art zero-shot generalization performance on a wide range of vision-language tasks. The demo is based on the official Github implementation """ ) with gr.Row(): with gr.Column(scale=3): image_input = gr.Image(type="pil") prompt_textbox = gr.Textbox(label="Prompt:", placeholder="prompt", lines=2) output = gr.Textbox(label="Output") submit = gr.Button("Run", variant="primary") with gr.Column(scale=1): min_len = gr.Slider( minimum=1, maximum=50, value=1, step=1, interactive=True, label="Min Length", ) max_len = gr.Slider( minimum=10, maximum=500, value=250, step=5, interactive=True, label="Max Length", ) sampling = gr.Radio( choices=["Beam search", "Nucleus sampling"], value="Beam search", label="Text Decoding Method", interactive=True, ) top_p = gr.Slider( minimum=0.5, maximum=1.0, value=0.9, step=0.1, interactive=True, label="Top p", ) beam_size = gr.Slider( minimum=1, maximum=10, value=5, step=1, interactive=True, label="Beam Size", ) len_penalty = gr.Slider( minimum=-1, maximum=2, value=1, step=0.2, interactive=True, label="Length Penalty", ) repetition_penalty = gr.Slider( minimum=-1, maximum=3, value=1, step=0.2, interactive=True, label="Repetition Penalty", ) gr.Examples( examples=examples, inputs=[image_input, prompt_textbox, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling], cache_examples=False, fn=infer, outputs=[output], ) submit.click(infer, inputs=[image_input, prompt_textbox, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling], outputs=[output]) demo.queue(concurrency_count=16).launch(debug=True)