Traly's picture
fix-1
e9b996f
raw
history blame contribute delete
No virus
2.81 kB
import os
from collections import OrderedDict
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import transforms
from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize
from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam as trainer_ori
# Directory that contains this module; the demo resolves its checkpoint path relative to it.
ROOT_PATH = os.path.split(__file__)[0]
class sam_diffsr_demo:
    """Super-resolve a single PIL image with a pretrained SAM-DiffSR model.

    Construction calls ``set_hparams()``, builds the network through
    ``SRDiffDf2k_sam.build_model()`` and loads the weights stored in
    ``weight/model_ckpt_steps_400000.ckpt`` under ``ROOT_PATH``. The model is
    moved to, and run on, a CUDA device, so a GPU is required.
    """

    def __init__(self):
        # NOTE(review): set_hparams() presumably populates the module-level
        # ``hparams`` mapping that infer() passes on; confirm in utils_sr.hparams.
        set_hparams()
        ckpt_path = os.path.join(ROOT_PATH, 'weight/model_ckpt_steps_400000.ckpt')
        self.model_init(ckpt_path)

    @staticmethod
    def _lr_crop_size(h, w, sr_scale):
        """Return the (height, width) that the low-resolution input is cropped to.

        The HR size ``side * sr_scale`` is rounded down to a multiple of
        ``2 * sr_scale`` and mapped back to LR pixels. For a positive integer
        ``sr_scale`` this is the same as rounding each LR side down to an even
        number.

        Raises:
            ValueError: if either cropped side would be zero (a 1-pixel side),
                which would otherwise produce an empty array and fail later
                with an obscure error.
        """
        h_hr = h * sr_scale
        w_hr = w * sr_scale
        h_hr -= h_hr % (sr_scale * 2)
        w_hr -= w_hr % (sr_scale * 2)
        h_l, w_l = h_hr // sr_scale, w_hr // sr_scale
        if h_l == 0 or w_l == 0:
            raise ValueError(
                f'image of size {w}x{h} is too small: '
                f'both sides must be at least 2 pixels'
            )
        return h_l, w_l

    def get_img_data(self, img_PIL, hparams, sr_scale=4):
        """Turn a PIL image into the ``(img_lr, img_lr_up)`` model inputs.

        Args:
            img_PIL: input image; it is converted to RGB first.
            hparams: hyper-parameter mapping; ``hparams['sr_scale']`` is the
                factor used for the bicubic upsampled input ``img_lr_up``.
            sr_scale: scale used only by the crop computation. For a positive
                integer, the crop reduces to rounding each side down to an
                even number, so its value does not affect the result.

        Returns:
            ``(img_lr, img_lr_up)`` as float32 tensors with a leading batch
            dimension of 1 and 3 channels (CHW order, from ``ToTensor``), both
            normalized per channel with mean 0.5 and std 0.5.
            ``img_lr`` is the cropped image, scaled to [0, 1] by ``ToTensor``
            and then mapped to [-1, 1]. ``img_lr_up`` is its upsampling by
            ``imresize``; its exact spatial size is determined by
            ``imresize`` (presumably ``side * hparams['sr_scale']``).

        Raises:
            ValueError: if the image has a side shorter than 2 pixels.
        """
        img_lr = np.uint8(np.asarray(img_PIL.convert('RGB')))
        h_l, w_l = self._lr_crop_size(img_lr.shape[0], img_lr.shape[1], sr_scale)
        img_lr = img_lr[:h_l, :w_l]
        to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # The 1/256 scaling (not 1/255) is intentionally left unchanged; it
        # presumably mirrors the training data pipeline. Confirm before changing.
        # ToTensor does not rescale float arrays, so this input is not divided again.
        img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
        img_lr, img_lr_up = [to_tensor_norm(x).float() for x in (img_lr, img_lr_up)]
        return img_lr.unsqueeze(0), img_lr_up.unsqueeze(0)

    def load_checkpoint(self, ckpt_path):
        """Load model weights from ``ckpt_path`` and move the model to the GPU.

        The checkpoint is first loaded onto the CPU. Its weights are read from
        ``checkpoint['state_dict']['model']``. A leading ``module.`` prefix on
        a key is stripped so that the keys match the bare ``self.model.model``
        (the prefix presumably comes from DataParallel/DDP wrapping at
        training time). ``load_state_dict`` is called with its default strict
        key matching.
        """
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        print(f'loding check from: {ckpt_path}')
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for key, value in checkpoint['state_dict']['model'].items():
            if key.startswith(prefix):
                key = key[len(prefix):]  # remove the `module.` prefix
            new_state_dict[key] = value
        self.model.model.load_state_dict(new_state_dict)
        self.model.model.cuda()
        # Drop the reference to the loaded checkpoint and release cached GPU memory.
        del checkpoint
        torch.cuda.empty_cache()

    def model_init(self, ckpt_path):
        """Build the trainer and its model, then load the weights from ``ckpt_path``."""
        self.model = trainer_ori()
        self.model.build_model()
        self.load_checkpoint(ckpt_path)
        # Disables cuDNN autotuning. Presumably this is because every input
        # image may have a different size, which would trigger re-tuning; confirm.
        torch.backends.cudnn.benchmark = False

    def infer(self, img_PIL):
        """Super-resolve ``img_PIL`` and return the result as a PIL image.

        Steps:
        - Crop the input to even height and width (see ``get_img_data``).
        - Run the model's ``sample`` on CUDA in eval mode, with gradient
          tracking disabled.
        - Clamp the output to [-1, 1].
        - Convert the first result with the trainer's ``tensor2img``.

        Raises:
            ValueError: if the image has a side shorter than 2 pixels.
        """
        with torch.no_grad():
            self.model.model.eval()
            img_lr, img_lr_up = self.get_img_data(img_PIL, hparams, sr_scale=4)
            img_lr = img_lr.to('cuda')
            img_lr_up = img_lr_up.to('cuda')
            # sample() returns a pair; only its first element (the SR batch) is used.
            # The output shape is taken from the bicubic-upsampled input.
            img_sr, _ = self.model.model.sample(img_lr, img_lr_up, img_lr_up.shape)
            img_sr = img_sr.clamp(-1, 1)
            # NOTE(review): tensor2img presumably returns one uint8 HWC array
            # per batch item; the first (only) one is taken.
            img_sr = self.model.tensor2img(img_sr)[0]
            img_sr = Image.fromarray(img_sr)
            return img_sr