import os
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import transforms

from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize
from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam as trainer_ori

ROOT_PATH = os.path.dirname(__file__)


class sam_diffsr_demo:
    def __init__(self):
        set_hparams()
        ckpt_path = os.path.join(ROOT_PATH, 'weight/model_ckpt_steps_400000.ckpt')
        self.model_init(ckpt_path)

    def get_img_data(self, img_PIL, hparams, sr_scale=4):
        img_lr = img_PIL.convert('RGB')
        img_lr = np.uint8(np.asarray(img_lr))

        # Crop the LR image so the target SR size is divisible by (sr_scale * 2).
        h, w, c = img_lr.shape
        h, w = h * sr_scale, w * sr_scale
        h = h - h % (sr_scale * 2)
        w = w - w % (sr_scale * 2)
        h_l = h // sr_scale
        w_l = w // sr_scale
        img_lr = img_lr[:h_l, :w_l]

        to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Upsampled LR image used as the conditioning input for the diffusion model.
        img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
        img_lr, img_lr_up = [to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]

        img_lr = torch.unsqueeze(img_lr, dim=0)
        img_lr_up = torch.unsqueeze(img_lr_up, dim=0)

        return img_lr, img_lr_up

    def load_checkpoint(self, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        print(f'loading checkpoint from: {ckpt_path}')
        stat_dict = checkpoint['state_dict']['model']

        new_state_dict = OrderedDict()
        for k, v in stat_dict.items():
            if k[:7] == 'module.':
                k = k[7:]  # strip the `module.` prefix added by DataParallel
            new_state_dict[k] = v

        self.model.model.load_state_dict(new_state_dict)
        self.model.model.cuda()
        del checkpoint
        torch.cuda.empty_cache()

    def model_init(self, ckpt_path):
        self.model = trainer_ori()
        self.model.build_model()
        self.load_checkpoint(ckpt_path)
        torch.backends.cudnn.benchmark = False

    def infer(self, img_PIL):
        with torch.no_grad():
            self.model.model.eval()

            img_lr, img_lr_up = self.get_img_data(img_PIL, hparams, sr_scale=4)

            img_lr = img_lr.to('cuda')
            img_lr_up = img_lr_up.to('cuda')

            img_sr, _ = self.model.model.sample(img_lr, img_lr_up, img_lr_up.shape)
            img_sr = img_sr.clamp(-1, 1)
            img_sr = self.model.tensor2img(img_sr)[0]
            img_sr = Image.fromarray(img_sr)

        return img_sr
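

# Minimal usage sketch (not part of the original file): it assumes a CUDA-capable
# GPU, the checkpoint under weight/, and a local low-resolution image; the input
# and output file names below are hypothetical placeholders.
if __name__ == '__main__':
    demo = sam_diffsr_demo()
    lr_image = Image.open('demo_lr.png')  # hypothetical low-resolution input
    sr_image = demo.infer(lr_image)       # returns a PIL.Image upscaled 4x
    sr_image.save('demo_sr.png')          # hypothetical output path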