# ClarityGAN / app.py
# Last commit 9738334 by GRMenon: "Update app.py" (verified), 1.39 kB.
import numpy as np
from dehazing_gen import CycleGenerator
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr
# Build the dehazing generator once at import time so every request reuses it.
# The checkpoint is a dict whose "model" entry holds the generator's state dict.
# map_location="cpu" lets weights saved from a GPU run load on a CPU-only host.
# NOTE: torch.load unpickles the file; only point it at checkpoints you trust
# (here it is a local file path).
gan = CycleGenerator(num_residuals=6)
gan.load_state_dict(torch.load("genC.pth.tar", map_location=torch.device('cpu'))["model"])
# Put the model in inference mode. Any dropout / batch-norm layers the generator
# may contain would otherwise keep their training-time behaviour.
gan.eval()
# Preprocessing applied to every request, built once instead of on each call.
# Resize((800, 800)) forces a fixed square size and does not preserve the
# input's aspect ratio. ToTensor turns the (RGB, 8-bit) PIL image into a
# float (C, H, W) tensor scaled to [0, 1].
_GAN_TRANSFORMS = transforms.Compose([
    transforms.Resize((800, 800)),
    transforms.ToTensor(),
])


def dehaze(img):
    """Dehaze one image with the module-level generator ``gan``.

    Args:
        img: A PIL image (the Gradio input component uses ``type="pil"``),
            or ``None`` when "Dehaze" is clicked before an image is provided.

    Returns:
        The generator output as a channels-last ``(H, W, C)`` float NumPy
        array, or ``None`` when ``img`` is ``None`` (which clears the output).
    """
    if img is None:
        # The torchvision transforms would fail on None; return nothing instead.
        return None

    # Force 3-channel RGB so RGBA / palette / greyscale files (e.g. PNGs) do not
    # change the tensor's channel count. Assumes the generator was trained on
    # RGB input (TODO confirm).
    x = _GAN_TRANSFORMS(img.convert("RGB"))

    # Inference only: no_grad skips building the autograd graph, which saves
    # memory and time on every request.
    with torch.no_grad():
        dehazed_output = gan(x)

    # The output is 3-D (C, H, W), as the transpose below requires. Move the
    # channel axis last so the array can be displayed as an image.
    return dehazed_output.cpu().numpy().transpose(1, 2, 0)
# Bundled example inputs as (button label, image file path) pairs; the UI
# below creates one loader button per entry.
sample_images = [
    ("Haze", f"gradio_check{index}.png") for index in (1, 10, 13)
]
# Assemble the Gradio UI: title, input/output image panes, a "Dehaze" button,
# and one loader button per entry in ``sample_images``.
with gr.Blocks() as demo:
    gr.Markdown("# ClarityGAN")
    gr.Markdown("## Image Dehazing using CycleGANs")
    with gr.Row():
        # Input column: the image to dehaze plus the button that triggers it.
        with gr.Column():
            # type="pil" makes Gradio pass the callback a PIL image.
            input_image = gr.Image(label="Input Image", type="pil")
            with gr.Row():
                dehaze_button = gr.Button("Dehaze")
        # Output column: displays whatever ``dehaze`` returns.
        with gr.Column():
            output_image = gr.Image(label="Output Image", type="pil")
    gr.Markdown("### Choose from these sample images below:")
    # One button per sample. Clicking it opens the file and places it in the
    # input pane. No ``inputs`` are wired, so the lambda is called with no
    # arguments. ``fp=file_path`` binds the current path as a default argument,
    # so each lambda keeps its own path; a plain closure would only see the
    # last loop value.
    for name, file_path in sample_images:
        gr.Button(name).click(lambda fp=file_path:Image.open(fp), outputs=input_image)
    # Clicking "Dehaze" sends the input image through ``dehaze`` and shows the
    # result in the output pane.
    dehaze_button.click(dehaze, inputs=input_image, outputs=output_image)
# Starts the app. This is called at module level, so it runs whenever this
# module is executed or imported.
demo.launch()