"""Gradio demo: 4x image super-resolution with a SAM-DiffSR checkpoint."""

import importlib
import os
from collections import OrderedDict
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize


def _inference_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_img_data(img_PIL, hparams, sr_scale=4):
    """Turn a PIL image into the (low-res, up-sampled low-res) tensor pair.

    Args:
        img_PIL: Input PIL image; it is converted to RGB first.
        hparams: Mapping that provides ``'sr_scale'``, the factor passed to
            ``imresize`` when building the up-sampled image.
        sr_scale: Scale factor used for the crop arithmetic below.

    Returns:
        Tuple ``(img_lr, img_lr_up)`` of float tensors, each with a leading
        batch dimension of size 1 and normalised with mean 0.5 / std 0.5.

    Raises:
        ValueError: If the image has no pixels left after cropping.
    """
    pixels = np.array(img_PIL.convert('RGB'), dtype=np.uint8)
    height, width = pixels.shape[:2]

    # Trim the up-scaled size down to a multiple of ``2 * sr_scale`` and map
    # it back to low-res coordinates; in effect this crops the low-res image
    # to even height and width.
    block = sr_scale * 2
    target_h = height * sr_scale
    target_w = width * sr_scale
    target_h -= target_h % block
    target_w -= target_w % block
    h_l = target_h // sr_scale
    w_l = target_w // sr_scale
    if h_l == 0 or w_l == 0:
        raise ValueError(
            f'input image is too small to super-resolve: {width}x{height}'
        )
    img_lr = pixels[:h_l, :w_l]

    to_tensor_norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # NOTE(review): the divisor is 256, not 255; left untouched because it
    # presumably mirrors the preprocessing used at training time - confirm.
    img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]

    img_lr = torch.unsqueeze(to_tensor_norm(img_lr).float(), dim=0)
    img_lr_up = torch.unsqueeze(to_tensor_norm(img_lr_up).float(), dim=0)
    return img_lr, img_lr_up


def load_checkpoint(model, ckpt_path):
    """Load weights from ``ckpt_path`` into ``model`` and move it to the device.

    The checkpoint is read on the CPU, any leading ``module.`` in the
    parameter names is stripped, and the model is then moved to CUDA when
    available (CPU otherwise, instead of failing on CUDA-less hosts).

    Args:
        model: Module that receives the state dict.
        ckpt_path: Path to a checkpoint whose weights live under
            ``['state_dict']['model']``.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    print(f'loding check from: {ckpt_path}')

    prefix = 'module.'
    stat_dict = checkpoint['state_dict']['model']
    # Strip the `module.` prefix from parameter names.
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in stat_dict.items()
    )

    model.load_state_dict(new_state_dict)
    model.to(_inference_device())

    # Drop the CPU copy of the checkpoint and release cached GPU blocks.
    del checkpoint
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def model_init(ckpt_path):
    """Build the SAM-DiffSR trainer and load the checkpoint into its model.

    Args:
        ckpt_path: Path of the checkpoint handed to ``load_checkpoint``.

    Returns:
        The trainer instance, with ``trainer.model`` built and loaded.
    """
    set_hparams()

    # Imported only after set_hparams() has run, as in the original ordering.
    from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam

    trainer = SRDiffDf2k_sam()
    trainer.build_model()
    load_checkpoint(trainer.model, ckpt_path)

    torch.backends.cudnn.benchmark = False
    return trainer


def image_infer(img_PIL):
    """Super-resolve one PIL image with the module-level ``trainer``.

    Args:
        img_PIL: Low-resolution input image.

    Returns:
        The super-resolved result as a PIL image.
    """
    device = _inference_device()
    with torch.no_grad():
        trainer.model.eval()

        img_lr, img_lr_up = get_img_data(img_PIL, hparams, sr_scale=4)
        img_lr = img_lr.to(device)
        img_lr_up = img_lr_up.to(device)

        img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)
        img_sr = img_sr.clamp(-1, 1)
        img_sr = trainer.tensor2img(img_sr)[0]
        img_sr = Image.fromarray(img_sr)

    return img_sr


root_path = os.path.dirname(__file__)
cheetah = os.path.join(root_path, "images/0801x4.png")
print(cheetah)

ckpt_path = os.path.join(root_path, 'sam_diffsr/weight/model_ckpt_steps_400000.ckpt')
trainer = model_init(ckpt_path)

demo = gr.Interface(
    image_infer,
    gr.Image(type="pil", value=cheetah),
    "image",
    # flagging_options=["blurry", "incorrect", "other"],
    examples=[
        os.path.join(root_path, "images/0801x4.png"),
        os.path.join(root_path, "images/0804x4.png"),
        os.path.join(root_path, "images/0809x4.png"),
    ],
)

if __name__ == "__main__":
    demo.launch()