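# Scraped Hugging Face file-page header: last commit by johnbradley,
# "docker-app (#4)" (d505436); file size 2.31 kB.
# Page labels captured in the scrape: raw, history, blame; scan result: No virus.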
import os
import tempfile
import json
import numpy as np
import gradio as gr
import cv2
from drexel_metadata.gen_metadata import gen_metadata
from PIL import Image
import urllib.request
from huggingface_hub import hf_hub_download
# Fetch the pretrained model weights (model_final.pth) from the Hugging Face
# Hub into output/enhanced/. This executes at import time, before the UI is
# built; presumably gen_metadata loads the weights from that directory —
# verify against drexel_metadata if the path ever changes.
hf_hub_download(
    repo_id="imageomics/Drexel-metadata-generator",
    filename="model_final.pth",
    local_dir="output/enhanced",
)
# Sample fish images offered as one-click examples in the UI.
EXAMPLE_URLS = [
    'http://www.tubri.org/HDR/INHS/INHS_FISH_59422.jpg',
    'http://www.tubri.org/HDR/INHS/INHS_FISH_76560.jpg',
]
# gr.Interface expects examples as a nested list: one row per example, one
# value per input component. With a single image input each row has one path.
EXAMPLES = []
for example_url in EXAMPLE_URLS:
    # Images are stored in the current working directory under the URL's
    # final path component.
    file_name = os.path.basename(example_url)
    # Only download when the file is missing, so restarts do not re-fetch the
    # images and keep working when the remote host is unreachable.
    if not os.path.isfile(file_name):
        # Download to a scratch name and rename into place, so an interrupted
        # transfer never leaves a truncated image that later runs would reuse.
        partial_name = file_name + '.part'
        urllib.request.urlretrieve(example_url, partial_name)
        os.replace(partial_name, file_name)
    EXAMPLES.append([file_name])
def create_temp_file_path(prefix, suffix):
    """Create an empty temporary file that persists, and return its path.

    The file is created in the default temporary directory with a unique
    name of the form ``<prefix>XXXXXXXX<suffix>``. It is closed before
    returning and is NOT deleted automatically, so the caller owns it.

    Args:
        prefix: Leading part of the generated file name.
        suffix: Trailing part of the file name (e.g. ``".png"``).

    Returns:
        Absolute path of the newly created, empty file.
    """
    descriptor, path = tempfile.mkstemp(prefix=prefix, suffix=suffix)
    # Only the name is needed; release the OS-level handle right away.
    os.close(descriptor)
    return path
def run_inference(input_img):
    """Generate metadata, a visualization and a mask for one uploaded image.

    Args:
        input_img: NumPy array from the ``gr.Image()`` input component, laid
            out height x width x channels (the layout ``PIL.Image.fromarray``
            expects). Gradio passes ``None`` when the form is submitted
            without an image.

    Returns:
        Tuple ``(visfname, maskfname, json_metadata)``: the temporary PNG
        paths that ``gen_metadata`` was asked to write the visualization and
        mask to, and its result serialized with ``json.dumps``.

    Raises:
        gr.Error: If no image was supplied (Gradio shows the message to the
            user).
    """
    if input_img is None:
        # Without this, Image.fromarray(None) fails with an opaque traceback.
        raise gr.Error("Please upload an image before submitting.")

    # gen_metadata works on file paths, so stage the input as a temporary PNG
    # and reserve temporary paths for the two output images.
    tmpfile = create_temp_file_path(prefix="input_", suffix=".png")
    visfname = create_temp_file_path(prefix="vis_", suffix=".png")
    maskfname = create_temp_file_path(prefix="mask_", suffix=".png")
    try:
        Image.fromarray(input_img).save(tmpfile)
        result = gen_metadata(
            tmpfile, device='cpu', maskfname=maskfname, visfname=visfname
        )
        json_metadata = json.dumps(result)
    except Exception:
        # Output paths are only handed to Gradio on success; on failure
        # delete them (best effort) so failed requests don't pile up files.
        for leftover in (visfname, maskfname):
            try:
                os.remove(leftover)
            except OSError:
                pass
        raise
    finally:
        # The staged input is needed only while gen_metadata runs; remove it
        # whether inference succeeded or not.
        os.remove(tmpfile)
    return visfname, maskfname, json_metadata
def read_app_header_markdown(path='app_header.md'):
    """Return the Markdown text shown as the app's description header.

    Args:
        path: File to read. Defaults to ``app_header.md``, resolved against
            the current working directory as before.

    Returns:
        The full contents of the file as a string.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    # Decode explicitly as UTF-8 so non-ASCII Markdown does not depend on the
    # host locale's default encoding (e.g. a C/POSIX locale in a container).
    with open(path, encoding='utf-8') as infile:
        return infile.read()
# Assemble the Gradio UI: the Markdown header read from app_header.md
# explains the app, and the only input is a single image upload.
dm_app = gr.Interface(
    description=read_app_header_markdown(),
    fn=run_inference,
    inputs=[gr.Image()],
    # One output component per value returned by run_inference, in order:
    # visualization image, mask image, JSON metadata.
    outputs=[
        gr.Image(label='visualization'),
        gr.Image(label='mask'),
        gr.JSON(label="JSON metadata"),
    ],
    examples=EXAMPLES,
    # Never save users' results or prompt them to flag/save results.
    allow_flagging="never",
)
# Listen on all network interfaces (not just localhost) so the server is
# reachable from outside the container/host.
dm_app.launch(server_name="0.0.0.0")