Spaces:
Running
Running
import numpy as np | |
from dehazing_gen import CycleGenerator | |
from PIL import Image | |
import torch | |
from torchvision import transforms | |
import gradio as gr | |
# Build the generator and restore its trained weights. The checkpoint is
# mapped onto the CPU and keeps the weights under its "model" key.
gan = CycleGenerator(num_residuals=6)
gan.load_state_dict(torch.load("genC.pth.tar", map_location=torch.device('cpu'))["model"])
# The model is only used for inference here, so put it in evaluation mode.
# This matters for any dropout / batch-norm style layers the generator may
# contain and is harmless otherwise.
gan.eval()
def dehaze(img):
    """Run the generator on an input image and return its output as an array.

    Args:
        img: Image handed over by the Gradio input component (configured
            with ``type="pil"``), or ``None`` when the button is pressed
            before any image has been provided.

    Returns:
        ``None`` when ``img`` is ``None``. Otherwise a NumPy array holding
        the generator output with its first axis moved to the last position
        (presumably channels-last, i.e. height x width x channels -- confirm
        against the generator's output shape).
    """
    # Nothing to process: return early instead of failing inside the
    # transforms with an opaque error.
    if img is None:
        return None

    gan_transforms = transforms.Compose([
        transforms.Resize((800, 800)),
        transforms.ToTensor(),
    ])
    model_input = gan_transforms(img)

    # Inference only: skip autograd bookkeeping so no graph is built and
    # intermediate activations are not kept alive.
    with torch.no_grad():
        dehazed_output = gan(model_input)

    # NOTE(review): values are returned exactly as the generator produced
    # them; confirm their range suits the output image component.
    return dehazed_output.cpu().numpy().transpose(1, 2, 0)
# File names of the bundled example images offered as one-click inputs.
_SAMPLE_FILES = (
    "gradio_check1.png",
    "gradio_check10.png",
    "gradio_check8.png",
)

# (button label, image path) pairs; every sample button carries the same label.
sample_images = [("Haze", sample_file) for sample_file in _SAMPLE_FILES]
# Gradio UI: an input column (image + "Dehaze" button), an output column,
# and a row of buttons that load bundled sample images into the input.
# NOTE(review): the nesting below was reconstructed from source whose
# indentation had been lost -- confirm the layout against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# ClarityGAN")
    gr.Markdown("## Image Dehazing using CycleGANs")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            with gr.Row():
                dehaze_button = gr.Button("Dehaze")
        with gr.Column():
            output_image = gr.Image(label="Output Image", type="pil")

    gr.Markdown("### Choose from these sample images below:")
    for label, sample_path in sample_images:
        sample_button = gr.Button(label)
        # The default argument binds the current path now, so each button
        # opens its own file rather than the last one in the loop.
        sample_button.click(lambda fp=sample_path: Image.open(fp), outputs=input_image)

    # Run the model on whatever is currently in the input component.
    dehaze_button.click(dehaze, inputs=input_image, outputs=output_image)

demo.launch()