"""ClarityGAN Gradio demo: dehaze images with a CycleGAN generator.

Loads generator weights from ``genC.pth.tar`` and serves a small Gradio UI
with an input image, a "Dehaze" button, an output image, and buttons that
load bundled sample images into the input.
"""
import numpy as np
from dehazing_gen import CycleGenerator
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr

# The checkpoint is a dict whose "model" entry holds the generator state dict.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True on torch >= 1.13).
gan = CycleGenerator(num_residuals=6)
gan.load_state_dict(
    torch.load("genC.pth.tar", map_location=torch.device('cpu'))["model"]
)
# Inference only: put layers with train/eval-dependent behavior in eval mode.
gan.eval()

# Built once at import time instead of on every request.
# NOTE: Resize((800, 800)) does not preserve the input's aspect ratio.
GAN_TRANSFORMS = transforms.Compose([
    transforms.Resize((800, 800)),
    transforms.ToTensor(),
])


def dehaze(img):
    """Run the dehazing generator on *img*.

    Args:
        img: PIL image from the input component, or ``None`` when the user
            clicks "Dehaze" before providing an image.

    Returns:
        The raw generator output as a float32 NumPy array laid out as
        (H, W, C), with no rescaling applied. Presumably this is
        (800, 800, 3) if the generator preserves spatial size (TODO confirm).
        Returns ``None`` if *img* is ``None``.
    """
    if img is None:
        return None
    # Uploads may be RGBA / palette / grayscale; force 3 channels so the
    # tensor matches what the generator presumably expects.
    rgb = img.convert("RGB")
    # ToTensor yields (C, H, W); add a batch dim for the conv/norm layers.
    batch = GAN_TRANSFORMS(rgb).unsqueeze(0)
    # No autograd graph is needed for inference.
    with torch.no_grad():
        dehazed_output = gan(batch)
    # Drop the batch dim and convert (C, H, W) -> (H, W, C) for display.
    return dehazed_output.squeeze(0).cpu().numpy().transpose(1, 2, 0)


def _load_sample(path):
    """Open a bundled sample image fully into memory and return it as RGB."""
    # The context manager closes the file handle once the pixels are loaded.
    with Image.open(path) as im:
        return im.convert("RGB")


sample_images = [
    ("Haze", "gradio_check1.png"),
    ("Haze", "gradio_check10.png"),
    ("Haze", "gradio_check8.png"),
]

with gr.Blocks() as demo:
    gr.Markdown("# ClarityGAN")
    gr.Markdown("## Image Dehazing using CycleGANs")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            with gr.Row():
                dehaze_button = gr.Button("Dehaze")
        with gr.Column():
            output_image = gr.Image(label="Output Image", type="pil")
    gr.Markdown("### Choose from these sample images below:")
    for name, file_path in sample_images:
        # Bind file_path as a default argument to avoid the late-binding
        # closure pitfall (otherwise every button would load the last file).
        gr.Button(name).click(
            lambda fp=file_path: _load_sample(fp), outputs=input_image
        )
    dehaze_button.click(dehaze, inputs=input_image, outputs=output_image)

demo.launch()