Traly committed on
Commit
193c713
1 Parent(s): 191a4b9
Files changed (50)
  1. app.py +113 -0
  2. images/0801x4.png +0 -0
  3. images/0804x4.png +0 -0
  4. images/0809x4.png +0 -0
  5. images/lion.jpg +0 -0
  6. images/logo.png +0 -0
  7. requirements.txt +27 -0
  8. sam_diffsr/configs/base/config_base.yaml +41 -0
  9. sam_diffsr/configs/base/diffsr_base.yaml +41 -0
  10. sam_diffsr/configs/base/sr_base.yaml +11 -0
  11. sam_diffsr/configs/data/df2k4x.yaml +11 -0
  12. sam_diffsr/configs/data/df2k4x_sam.yaml +11 -0
  13. sam_diffsr/configs/diffsr_df2k4x.yaml +18 -0
  14. sam_diffsr/configs/rrdb/df2k4x_pretrain.yaml +14 -0
  15. sam_diffsr/configs/sam/sam_diffsr_df2k4x.yaml +26 -0
  16. sam_diffsr/models_sr/__init__.py +0 -0
  17. sam_diffsr/models_sr/commons.py +317 -0
  18. sam_diffsr/models_sr/diffsr_modules.py +177 -0
  19. sam_diffsr/models_sr/diffusion.py +291 -0
  20. sam_diffsr/models_sr/diffusion_sam.py +90 -0
  21. sam_diffsr/models_sr/module_util.py +58 -0
  22. sam_diffsr/tasks/__init__.py +0 -0
  23. sam_diffsr/tasks/infer.py +81 -0
  24. sam_diffsr/tasks/rrdb.py +68 -0
  25. sam_diffsr/tasks/rrdb_sam.py +49 -0
  26. sam_diffsr/tasks/srdiff.py +76 -0
  27. sam_diffsr/tasks/srdiff_df2k.py +119 -0
  28. sam_diffsr/tasks/srdiff_df2k_sam.py +211 -0
  29. sam_diffsr/tasks/trainer.py +346 -0
  30. sam_diffsr/tb_logs/events.out.tfevents.1709283169.wangchengchengdeMacBook-Pro.local.99018.0 +3 -0
  31. sam_diffsr/tb_logs/events.out.tfevents.1709284054.wangchengchengdeMacBook-Pro.local.99188.0 +3 -0
  32. sam_diffsr/tb_logs/events.out.tfevents.1709284076.wangchengchengdeMacBook-Pro.local.99198.0 +3 -0
  33. sam_diffsr/tb_logs/events.out.tfevents.1709284101.wangchengchengdeMacBook-Pro.local.99211.0 +3 -0
  34. sam_diffsr/tb_logs/events.out.tfevents.1709284193.wangchengchengdeMacBook-Pro.local.99233.0 +3 -0
  35. sam_diffsr/tb_logs/events.out.tfevents.1709284415.wangchengchengdeMacBook-Pro.local.99289.0 +3 -0
  36. sam_diffsr/tb_logs/events.out.tfevents.1709284460.wangchengchengdeMacBook-Pro.local.99308.0 +3 -0
  37. sam_diffsr/tb_logs/events.out.tfevents.1709284491.wangchengchengdeMacBook-Pro.local.99315.0 +3 -0
  38. sam_diffsr/tb_logs/events.out.tfevents.1709285127.wangchengchengdeMacBook-Pro.local.785.0 +3 -0
  39. sam_diffsr/tb_logs/events.out.tfevents.1709285146.wangchengchengdeMacBook-Pro.local.901.0 +3 -0
  40. sam_diffsr/tools/caculate_iqa.py +136 -0
  41. sam_diffsr/tools/visualize_sam_mask.py +20 -0
  42. sam_diffsr/utils_sr/__init__.py +0 -0
  43. sam_diffsr/utils_sr/dataset.py +50 -0
  44. sam_diffsr/utils_sr/hparams.py +157 -0
  45. sam_diffsr/utils_sr/indexed_datasets.py +72 -0
  46. sam_diffsr/utils_sr/matlab_resize.py +181 -0
  47. sam_diffsr/utils_sr/plt_img.py +109 -0
  48. sam_diffsr/utils_sr/sr_utils.py +171 -0
  49. sam_diffsr/utils_sr/utils.py +269 -0
  50. sam_diffsr/weight/model_ckpt_steps_400000.ckpt +3 -0
app.py ADDED
@@ -0,0 +1,113 @@
+ import importlib
+ from collections import OrderedDict
+ from pathlib import Path
+
+ import gradio as gr
+ import os
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+ from sam_diffsr.utils_sr.hparams import set_hparams, hparams
+ from sam_diffsr.utils_sr.matlab_resize import imresize
+
+
+ def get_img_data(img_PIL, hparams, sr_scale=4):
+     img_lr = img_PIL.convert('RGB')
+     img_lr = np.uint8(np.asarray(img_lr))
+
+     h, w, c = img_lr.shape
+     h, w = h * sr_scale, w * sr_scale
+     h = h - h % (sr_scale * 2)
+     w = w - w % (sr_scale * 2)
+     h_l = h // sr_scale
+     w_l = w // sr_scale
+
+     img_lr = img_lr[:h_l, :w_l]
+
+     to_tensor_norm = transforms.Compose([
+         transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+     ])
+
+     img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
+     img_lr, img_lr_up = [to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]
+
+     img_lr = torch.unsqueeze(img_lr, dim=0)
+     img_lr_up = torch.unsqueeze(img_lr_up, dim=0)
+
+     return img_lr, img_lr_up
+
+
+ def load_checkpoint(model, ckpt_path):
+     checkpoint = torch.load(ckpt_path, map_location='cpu')
+     print(f'loading checkpoint from: {ckpt_path}')
+     stat_dict = checkpoint['state_dict']['model']
+
+     new_state_dict = OrderedDict()
+     for k, v in stat_dict.items():
+         if k[:7] == 'module.':
+             k = k[7:]  # strip the `module.` prefix
+         new_state_dict[k] = v
+
+     model.load_state_dict(new_state_dict)
+     model.cuda()
+     del checkpoint
+     torch.cuda.empty_cache()
+
+
+ def model_init(ckpt_path):
+     set_hparams()
+
+     from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam as trainer
+
+     trainer = trainer()
+
+     trainer.build_model()
+     load_checkpoint(trainer.model, ckpt_path)
+
+     torch.backends.cudnn.benchmark = False
+
+     return trainer
+
+
+ def image_infer(img_PIL):
+     with torch.no_grad():
+         trainer.model.eval()
+         img_lr, img_lr_up = get_img_data(img_PIL, hparams, sr_scale=4)
+
+         img_lr = img_lr.to('cuda')
+         img_lr_up = img_lr_up.to('cuda')
+
+         img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)
+
+         img_sr = img_sr.clamp(-1, 1)
+         img_sr = trainer.tensor2img(img_sr)[0]
+         img_sr = Image.fromarray(img_sr)
+
+     return img_sr
+
+
+ # cheetah = os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg")
+
+ root_path = os.path.dirname(__file__)
+
+ cheetah = os.path.join(root_path, "images/lion.jpg")
+ print(cheetah)
+
+ demo = gr.Interface(image_infer, gr.Image(type="pil", value=cheetah), "image",
+                     # flagging_options=["blurry", "incorrect", "other"],
+                     examples=[
+                         os.path.join(root_path, "images/0801x4.png"),
+                         os.path.join(root_path, "images/0809x4.png"),
+                         os.path.join(root_path, "images/0809x4.png"),
+                     ]
+                     )
+
+ if __name__ == "__main__":
+     parent_path = Path(__file__).absolute().parent
+     fill_root = os.path.abspath(parent_path)
+     ckpt_path = os.path.join(fill_root, 'sam_diffsr/weight/model_ckpt_steps_400000.ckpt')
+     trainer = model_init(ckpt_path)
+     demo.launch()
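Note on the crop arithmetic in `get_img_data` above: it trims the low-resolution input so that the eventual high-resolution size is divisible by `sr_scale * 2`, presumably to keep the diffusion U-Net's downsampling shapes consistent. A small worked example of the same arithmetic (a sketch, not part of the commit):

sr_scale = 4
h, w = 123, 250                     # low-resolution input size
H, W = h * sr_scale, w * sr_scale   # 492, 1000 before rounding
H -= H % (sr_scale * 2)             # 488  -> now divisible by 8
W -= W % (sr_scale * 2)             # 1000 -> already divisible by 8
h_l, w_l = H // sr_scale, W // sr_scale   # 122, 250: the LR crop actually fed to the model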
images/0801x4.png ADDED
images/0804x4.png ADDED
images/0809x4.png ADDED
images/lion.jpg ADDED
images/logo.png ADDED
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ torch
+ torchvision
+ Cython
+ matplotlib
+ tqdm
+ numpy
+ scipy
+ PyYAML
+ tensorboardX
+ tensorboard
+ scikit-learn
+ scikit-image
+ seaborn
+ pillow
+ opencv-contrib-python
+ einops
+ lpips
+ natsort
+ timm
+ openpyxl
+ kornia
+ xlwt==1.3.0
+ xlrd==1.2.0
+ pyiqa
+ rotary_embedding_torch
+ opencv-python>=4.8.0.76
+ opencv-python-headless>=4.5.5.64
sam_diffsr/configs/base/config_base.yaml ADDED
@@ -0,0 +1,41 @@
+ # task
+ binary_data_dir: ''
+ work_dir: '' # experiment directory.
+ infer: false # infer
+ seed: 1234
+ debug: false
+ save_codes:
+   - configs
+   - models_sr
+   - tasks
+   - utils_sr
+
+ #############
+ # dataset
+ #############
+ ds_workers: 1
+ endless: false
+
+ #########
+ # train and eval
+ #########
+ print_nan_grads: false
+ load_ckpt: ''
+ save_best: true
+ num_ckpt_keep: 100
+ clip_grad_norm: 0
+ accumulate_grad_batches: 1
+ tb_log_interval: 100
+ num_sanity_val_steps: 5 # steps of validation at the beginning
+ check_val_every_n_epoch: 10
+ val_check_interval: 4000
+ valid_monitor_key: 'val_loss'
+ valid_monitor_mode: 'min'
+ max_epochs: 1000
+ max_updates: 600000
+ amp: false
+ batch_size: 32
+ eval_batch_size: 32
+ num_workers: 8
+ test_input_dir: ''
+ resume_from_checkpoint: 0
sam_diffsr/configs/base/diffsr_base.yaml ADDED
@@ -0,0 +1,41 @@
+ base_config:
+   - ./config_base.yaml
+   - ./sr_base.yaml
+ # model
+ beta_schedule: cosine
+ beta_s: 0.008
+ beta_end: 0.02
+ hidden_size: 64
+ timesteps: 100
+ res: true
+ res_rescale: 2.0
+ up_input: false
+ use_wn: false
+ gn_groups: 0
+ use_rrdb: true
+ #rrdb_num_block: 8
+ #rrdb_num_feat: 32
+ rrdb_num_block: 17
+ rrdb_num_feat: 64
+ rrdb_ckpt: ''
+ unet_dim_mults: 1|2|2|4
+ clip_input: true
+ denoise_fn: unet
+ use_attn: false
+ aux_l1_loss: true
+ aux_ssim_loss: false
+ aux_percep_loss: false
+ loss_type: l1
+ pred_noise: true
+ clip_grad_norm: 10
+ weight_init: false
+ fix_rrdb: true
+
+ # train and eval
+ lr: 0.0002
+ decay_steps: 100000
+ accumulate_grad_batches: 1
+ style_interp: false
+ save_intermediate: false
+ show_training_process: false
+ print_arch: false
sam_diffsr/configs/base/sr_base.yaml ADDED
@@ -0,0 +1,11 @@
+ base_config: ./config_base.yaml
+ data_interp: bicubic # bilinear | bicubic
+ data_augmentation: false
+ max_updates: 300000
+ batch_size: 16
+ eval_batch_size: 1
+ test_batch_size: 1
+ valid_steps: 3
+ num_sanity_val_steps: 3
+ test_save_png: false
+ gen_dir_name: ''
sam_diffsr/configs/data/df2k4x.yaml ADDED
@@ -0,0 +1,11 @@
+ binary_data_dir: data/train/df2k4x
+ patch_size: 160
+ crop_size: 320
+ thresh_size: 160
+ test_crop_size: [ 2040, 2040 ]
+ test_thresh_size: 0
+ valid_steps: 4
+ num_sanity_val_steps: 4
+ eval_batch_size: 1
+ test_batch_size: 1
+ sr_scale: 4
sam_diffsr/configs/data/df2k4x_sam.yaml ADDED
@@ -0,0 +1,11 @@
+ binary_data_dir: data/train/df2k4x_sam
+ patch_size: 160
+ crop_size: 320
+ thresh_size: 160
+ test_crop_size: [ 2040, 2040 ]
+ test_thresh_size: 0
+ valid_steps: 4
+ num_sanity_val_steps: 4
+ eval_batch_size: 1
+ test_batch_size: 1
+ sr_scale: 4
sam_diffsr/configs/diffsr_df2k4x.yaml ADDED
@@ -0,0 +1,18 @@
+ base_config:
+   - ./base/diffsr_base.yaml
+   - ./data/df2k4x.yaml
+ trainer_cls: tasks.srdiff_df2k.SRDiffDf2k
+
+ # model
+ unet_dim_mults: 1|2|3|4
+ decay_steps: 200000
+
+ # train and test
+ batch_size: 64
+ max_updates: 400000
+
+ sam_config:
+   cond_sam: False
+   p_losses_sam: False
+   p_sample_sam: False
+   q_sample_sam: False
sam_diffsr/configs/rrdb/df2k4x_pretrain.yaml ADDED
@@ -0,0 +1,14 @@
+ base_config:
+   - ../sr_base.yaml
+   - ../df2k4x.yaml
+ trainer_cls: tasks.rrdb.RRDBDf2kTask
+ # model
+ hidden_size: 64
+ lr: 0.0002
+ num_block: 17
+
+ # train and eval
+ max_updates: 100000
+ batch_size: 64
+ eval_batch_size: 1
+ valid_steps: 3
sam_diffsr/configs/sam/sam_diffsr_df2k4x.yaml ADDED
@@ -0,0 +1,26 @@
+ base_config:
+   - ../base/diffsr_base.yaml
+   - ../data/df2k4x_sam.yaml
+ trainer_cls: tasks.srdiff_df2k_sam.SRDiffDf2k_sam
+
+ # model
+ unet_dim_mults: 1|2|3|4
+ decay_steps: 200000
+
+ # train and test
+ batch_size: 64
+ max_updates: 400000
+
+ rrdb_num_feat: 64
+
+ sam_config:
+   cond_sam: False
+   p_losses_sam: True
+   mask_coefficient: True
+
+ sam_data_config:
+   all_same_mask_to_zero: False
+   normalize_01: False
+   normalize_11: False
+
+ num_sanity_val_steps: 2
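Each config names its trainer through `trainer_cls`, and the runner resolves that string into a class at run time (the same pattern appears in `tasks/infer.py` further down). A minimal sketch of the resolution, assuming `hparams` has already been populated by `set_hparams()`:

import importlib
from sam_diffsr.utils_sr.hparams import hparams  # assumed to be filled in by set_hparams()

trainer_cls = hparams["trainer_cls"]              # e.g. 'tasks.srdiff_df2k_sam.SRDiffDf2k_sam'
pkg, cls_name = trainer_cls.rsplit(".", 1)
trainer = getattr(importlib.import_module(pkg), cls_name)()  # import module, instantiate class
trainer.build_model()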
sam_diffsr/models_sr/__init__.py ADDED
File without changes
sam_diffsr/models_sr/commons.py ADDED
@@ -0,0 +1,317 @@
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Residual(nn.Module):
10
+ def __init__(self, fn):
11
+ super().__init__()
12
+ self.fn = fn
13
+
14
+ def forward(self, x, *args, **kwargs):
15
+ return self.fn(x, *args, **kwargs) + x
16
+
17
+
18
+ class SinusoidalPosEmb(nn.Module):
19
+ def __init__(self, dim):
20
+ super().__init__()
21
+ self.dim = dim
22
+
23
+ def forward(self, x):
24
+ device = x.device
25
+ half_dim = self.dim // 2
26
+ emb = math.log(10000) / (half_dim - 1)
27
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
28
+ emb = x[:, None] * emb[None, :]
29
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
30
+ return emb
31
+
32
+
33
+ class Mish(nn.Module):
34
+ def forward(self, x):
35
+ return x * torch.tanh(F.softplus(x))
36
+
37
+
38
+ class Rezero(nn.Module):
39
+ def __init__(self, fn):
40
+ super().__init__()
41
+ self.fn = fn
42
+ self.g = nn.Parameter(torch.zeros(1))
43
+
44
+ def forward(self, x):
45
+ return self.fn(x) * self.g
46
+
47
+
48
+ # building block modules
49
+
50
+ class Block(nn.Module):
51
+ def __init__(self, dim, dim_out, groups=8):
52
+ super().__init__()
53
+ if groups == 0:
54
+ self.block = nn.Sequential(
55
+ nn.ReflectionPad2d(1),
56
+ nn.Conv2d(dim, dim_out, 3),
57
+ Mish()
58
+ )
59
+ else:
60
+ self.block = nn.Sequential(
61
+ nn.ReflectionPad2d(1),
62
+ nn.Conv2d(dim, dim_out, 3),
63
+ nn.GroupNorm(groups, dim_out),
64
+ Mish()
65
+ )
66
+
67
+ def forward(self, x):
68
+ return self.block(x)
69
+
70
+
71
+ class ResnetBlock(nn.Module):
72
+ def __init__(self, dim, dim_out, *, time_emb_dim=0, groups=8):
73
+ super().__init__()
74
+ if time_emb_dim > 0:
75
+ self.mlp = nn.Sequential(
76
+ Mish(),
77
+ nn.Linear(time_emb_dim, dim_out)
78
+ )
79
+
80
+ self.block1 = Block(dim, dim_out, groups=groups)
81
+ self.block2 = Block(dim_out, dim_out, groups=groups)
82
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
83
+
84
+ def forward(self, x, time_emb=None, cond=None):
85
+ h = self.block1(x)
86
+ if time_emb is not None:
87
+ h += self.mlp(time_emb)[:, :, None, None]
88
+ if cond is not None:
89
+ h += cond
90
+ h = self.block2(h)
91
+ return h + self.res_conv(x)
92
+
93
+
94
+ class Upsample(nn.Module):
95
+ def __init__(self, dim):
96
+ super().__init__()
97
+ self.conv = nn.Sequential(
98
+ nn.ConvTranspose2d(dim, dim, 4, 2, 1),
99
+ )
100
+
101
+ def forward(self, x):
102
+ return self.conv(x)
103
+
104
+
105
+ class Downsample(nn.Module):
106
+ def __init__(self, dim):
107
+ super().__init__()
108
+ self.conv = nn.Sequential(
109
+ nn.ReflectionPad2d(1),
110
+ nn.Conv2d(dim, dim, 3, 2),
111
+ )
112
+
113
+ def forward(self, x):
114
+ return self.conv(x)
115
+
116
+
117
+ class LinearAttention(nn.Module):
118
+ def __init__(self, dim, heads=4, dim_head=32):
119
+ super().__init__()
120
+ self.heads = heads
121
+ hidden_dim = dim_head * heads
122
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
123
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
124
+
125
+ def forward(self, x):
126
+ b, c, h, w = x.shape
127
+ qkv = self.to_qkv(x)
128
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
129
+ k = k.softmax(dim=-1)
130
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
131
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
132
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
133
+ return self.to_out(out)
134
+
135
+
136
+ class MultiheadAttention(nn.Module):
137
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
138
+ add_bias_kv=False, add_zero_attn=False):
139
+ super().__init__()
140
+ self.embed_dim = embed_dim
141
+ self.kdim = kdim if kdim is not None else embed_dim
142
+ self.vdim = vdim if vdim is not None else embed_dim
143
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
144
+
145
+ self.num_heads = num_heads
146
+ self.dropout = dropout
147
+ self.head_dim = embed_dim // num_heads
148
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
149
+ self.scaling = self.head_dim ** -0.5
150
+ if self.qkv_same_dim:
151
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
152
+ else:
153
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
154
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
155
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
156
+
157
+ if bias:
158
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
159
+ else:
160
+ self.register_parameter('in_proj_bias', None)
161
+
162
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
163
+
164
+ if add_bias_kv:
165
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
166
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
167
+ else:
168
+ self.bias_k = self.bias_v = None
169
+
170
+ self.add_zero_attn = add_zero_attn
171
+
172
+ self.reset_parameters()
173
+
174
+ self.enable_torch_version = False
175
+ if hasattr(F, "multi_head_attention_forward"):
176
+ self.enable_torch_version = True
177
+ else:
178
+ self.enable_torch_version = False
179
+ self.last_attn_probs = None
180
+
181
+ def reset_parameters(self):
182
+ if self.qkv_same_dim:
183
+ nn.init.xavier_uniform_(self.in_proj_weight)
184
+ else:
185
+ nn.init.xavier_uniform_(self.k_proj_weight)
186
+ nn.init.xavier_uniform_(self.v_proj_weight)
187
+ nn.init.xavier_uniform_(self.q_proj_weight)
188
+
189
+ nn.init.xavier_uniform_(self.out_proj.weight)
190
+ if self.in_proj_bias is not None:
191
+ nn.init.constant_(self.in_proj_bias, 0.)
192
+ nn.init.constant_(self.out_proj.bias, 0.)
193
+ if self.bias_k is not None:
194
+ nn.init.xavier_normal_(self.bias_k)
195
+ if self.bias_v is not None:
196
+ nn.init.xavier_normal_(self.bias_v)
197
+
198
+ def forward(
199
+ self,
200
+ query, key, value,
201
+ key_padding_mask=None,
202
+ need_weights=True,
203
+ attn_mask=None,
204
+ before_softmax=False,
205
+ need_head_weights=False,
206
+ ):
207
+ """Input shape: [B, T, C]
208
+
209
+ Args:
210
+ key_padding_mask (ByteTensor, optional): mask to exclude
211
+ keys that are pads, of shape `(batch, src_len)`, where
212
+ padding elements are indicated by 1s.
213
+ need_weights (bool, optional): return the attention weights,
214
+ averaged over heads (default: False).
215
+ attn_mask (ByteTensor, optional): typically used to
216
+ implement causal attention, where the mask prevents the
217
+ attention from looking forward in time (default: None).
218
+ before_softmax (bool, optional): return the raw attention
219
+ weights and values before the attention softmax.
220
+ need_head_weights (bool, optional): return the attention
221
+ weights for each head. Implies *need_weights*. Default:
222
+ return the average attention weights over all heads.
223
+ """
224
+ if need_head_weights:
225
+ need_weights = True
226
+ query = query.transpose(0, 1)
227
+ key = key.transpose(0, 1)
228
+ value = value.transpose(0, 1)
229
+ tgt_len, bsz, embed_dim = query.size()
230
+ assert embed_dim == self.embed_dim
231
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
232
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
233
+ query, key, value, self.embed_dim, self.num_heads,
234
+ self.in_proj_weight, self.in_proj_bias, self.bias_k, self.bias_v,
235
+ self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias,
236
+ self.training, key_padding_mask, need_weights, attn_mask)
237
+ attn_output = attn_output.transpose(0, 1)
238
+ return attn_output, attn_output_weights
239
+
240
+ def in_proj_qkv(self, query):
241
+ return self._in_proj(query).chunk(3, dim=-1)
242
+
243
+ def in_proj_q(self, query):
244
+ if self.qkv_same_dim:
245
+ return self._in_proj(query, end=self.embed_dim)
246
+ else:
247
+ bias = self.in_proj_bias
248
+ if bias is not None:
249
+ bias = bias[:self.embed_dim]
250
+ return F.linear(query, self.q_proj_weight, bias)
251
+
252
+ def in_proj_k(self, key):
253
+ if self.qkv_same_dim:
254
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
255
+ else:
256
+ weight = self.k_proj_weight
257
+ bias = self.in_proj_bias
258
+ if bias is not None:
259
+ bias = bias[self.embed_dim:2 * self.embed_dim]
260
+ return F.linear(key, weight, bias)
261
+
262
+ def in_proj_v(self, value):
263
+ if self.qkv_same_dim:
264
+ return self._in_proj(value, start=2 * self.embed_dim)
265
+ else:
266
+ weight = self.v_proj_weight
267
+ bias = self.in_proj_bias
268
+ if bias is not None:
269
+ bias = bias[2 * self.embed_dim:]
270
+ return F.linear(value, weight, bias)
271
+
272
+ def _in_proj(self, input, start=0, end=None):
273
+ weight = self.in_proj_weight
274
+ bias = self.in_proj_bias
275
+ weight = weight[start:end, :]
276
+ if bias is not None:
277
+ bias = bias[start:end]
278
+ return F.linear(input, weight, bias)
279
+
280
+
281
+ class ResidualDenseBlock_5C(nn.Module):
282
+ def __init__(self, nf=64, gc=32, bias=True):
283
+ super(ResidualDenseBlock_5C, self).__init__()
284
+ # gc: growth channel, i.e. intermediate channels
285
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
286
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
287
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
288
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
289
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
290
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
291
+
292
+ # initialization
293
+ # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
294
+
295
+ def forward(self, x):
296
+ x1 = self.lrelu(self.conv1(x))
297
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
298
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
299
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
300
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
301
+ return x5 * 0.2 + x
302
+
303
+
304
+ class RRDB(nn.Module):
305
+ '''Residual in Residual Dense Block'''
306
+
307
+ def __init__(self, nf, gc=32):
308
+ super(RRDB, self).__init__()
309
+ self.RDB1 = ResidualDenseBlock_5C(nf, gc)
310
+ self.RDB2 = ResidualDenseBlock_5C(nf, gc)
311
+ self.RDB3 = ResidualDenseBlock_5C(nf, gc)
312
+
313
+ def forward(self, x):
314
+ out = self.RDB1(x)
315
+ out = self.RDB2(out)
316
+ out = self.RDB3(out)
317
+ return out * 0.2 + x
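`ResidualDenseBlock_5C` and `RRDB` keep the channel count fixed and scale every residual branch by 0.2 before adding it back, so an RRDB stack can be chained without changing tensor shape. A quick shape check (a sketch, not part of the commit):

import torch
from sam_diffsr.models_sr.commons import RRDB

block = RRDB(nf=64, gc=32)
x = torch.randn(1, 64, 40, 40)
y = block(x)                  # three chained dense blocks, each merged back as 0.2 * residual
assert y.shape == x.shape     # (1, 64, 40, 40): spatial size and channels preserved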
sam_diffsr/models_sr/diffsr_modules.py ADDED
@@ -0,0 +1,177 @@
1
+ import functools
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ from sam_diffsr.utils_sr.hparams import hparams
8
+ from .commons import Mish, SinusoidalPosEmb, RRDB, Residual, Rezero, LinearAttention
9
+ from .commons import ResnetBlock, Upsample, Block, Downsample
10
+ from .module_util import make_layer, initialize_weights
11
+
12
+
13
+ class RRDBNet(nn.Module):
14
+ def __init__(self, in_nc, out_nc, nf, nb, gc=32):
15
+ super(RRDBNet, self).__init__()
16
+ RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
17
+
18
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
19
+ self.RRDB_trunk = make_layer(RRDB_block_f, nb)
20
+ self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
21
+ #### upsampling
22
+ self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
23
+ self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
24
+ if hparams['sr_scale'] == 8:
25
+ self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
26
+ self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
27
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
28
+
29
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2)
30
+
31
+ def forward(self, x, get_fea=False):
32
+ feas = []
33
+ x = (x + 1) / 2
34
+ fea_first = fea = self.conv_first(x)
35
+ for l in self.RRDB_trunk:
36
+ fea = l(fea)
37
+ feas.append(fea)
38
+ trunk = self.trunk_conv(fea)
39
+ fea = fea_first + trunk
40
+ feas.append(fea)
41
+
42
+ fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
43
+ fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
44
+ if hparams['sr_scale'] == 8:
45
+ fea = self.lrelu(self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')))
46
+ fea_hr = self.HRconv(fea)
47
+ out = self.conv_last(self.lrelu(fea_hr))
48
+ out = out.clamp(0, 1)
49
+ out = out * 2 - 1
50
+ if get_fea:
51
+ return out, feas
52
+ else:
53
+ return out
54
+
55
+
56
+ class Unet(nn.Module):
57
+ def __init__(self, dim, out_dim=None, dim_mults=(1, 2, 4, 8), cond_dim=32):
58
+ super().__init__()
59
+ dims = [3, *map(lambda m: dim * m, dim_mults)]
60
+ in_out = list(zip(dims[:-1], dims[1:]))
61
+ groups = 0
62
+
63
+ self.sam_config = hparams['sam_config']
64
+
65
+ cond_proj_in = cond_dim * ((hparams['rrdb_num_block'] + 1) // 3)
66
+ if self.sam_config['cond_sam']:
67
+ # cond_proj_in += 1
68
+ self.sam_conv = nn.Sequential(
69
+ nn.Conv2d(dim + 1, dim, 1, 1, 0, bias=True),
70
+ nn.Conv2d(dim, dim, 1, 1, 0, bias=True),
71
+ nn.Conv2d(dim, dim, 1, 1, 0, bias=True)
72
+ )
73
+ else:
74
+ self.sam_conv = None
75
+
76
+ self.cond_proj = nn.ConvTranspose2d(cond_proj_in, dim, hparams['sr_scale'] * 2, hparams['sr_scale'],
77
+ hparams['sr_scale'] // 2)
78
+
79
+ self.time_pos_emb = SinusoidalPosEmb(dim)
80
+ self.mlp = nn.Sequential(
81
+ nn.Linear(dim, dim * 4),
82
+ Mish(),
83
+ nn.Linear(dim * 4, dim)
84
+ )
85
+
86
+ self.downs = nn.ModuleList([])
87
+ self.ups = nn.ModuleList([])
88
+ num_resolutions = len(in_out)
89
+
90
+ for ind, (dim_in, dim_out) in enumerate(in_out):
91
+ is_last = ind >= (num_resolutions - 1)
92
+
93
+ self.downs.append(nn.ModuleList([
94
+ ResnetBlock(dim_in, dim_out, time_emb_dim=dim, groups=groups),
95
+ ResnetBlock(dim_out, dim_out, time_emb_dim=dim, groups=groups),
96
+ Downsample(dim_out) if not is_last else nn.Identity()
97
+ ]))
98
+
99
+ mid_dim = dims[-1]
100
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)
101
+ if hparams['use_attn']:
102
+ self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
103
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim, groups=groups)
104
+
105
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
106
+ is_last = ind >= (num_resolutions - 1)
107
+
108
+ self.ups.append(nn.ModuleList([
109
+ ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim, groups=groups),
110
+ ResnetBlock(dim_in, dim_in, time_emb_dim=dim, groups=groups),
111
+ Upsample(dim_in) if not is_last else nn.Identity()
112
+ ]))
113
+
114
+ self.final_conv = nn.Sequential(
115
+ Block(dim, dim, groups=groups),
116
+ nn.Conv2d(dim, out_dim, 1)
117
+ )
118
+
119
+ if hparams['res'] and hparams['up_input']:
120
+ self.up_proj = nn.Sequential(
121
+ nn.ReflectionPad2d(1), nn.Conv2d(3, dim, 3),
122
+ )
123
+ if hparams['use_wn']:
124
+ self.apply_weight_norm()
125
+ if hparams['weight_init']:
126
+ self.apply(initialize_weights)
127
+
128
+ def apply_weight_norm(self):
129
+ def _apply_weight_norm(m):
130
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
131
+ torch.nn.utils.weight_norm(m)
132
+ # print(f"| Weight norm is applied to {m}.")
133
+
134
+ self.apply(_apply_weight_norm)
135
+
136
+ def forward(self, x, time, cond, img_lr_up, sam_mask=None):
137
+ t = self.time_pos_emb(time)
138
+ t = self.mlp(t)
139
+ h = []
140
+
141
+ cond = self.cond_proj(torch.cat(cond[2::3], 1))
142
+
143
+ if self.sam_config['cond_sam']:
144
+ cond = torch.cat([cond, sam_mask], 1)
145
+ cond = self.sam_conv(cond)
146
+
147
+ for i, (resnet, resnet2, downsample) in enumerate(self.downs):
148
+ x = resnet(x, t)
149
+ x = resnet2(x, t)
150
+ if i == 0:
151
+ x = x + cond
152
+ if hparams['res'] and hparams['up_input']:
153
+ x = x + self.up_proj(img_lr_up)
154
+ h.append(x)
155
+ x = downsample(x)
156
+
157
+ x = self.mid_block1(x, t)
158
+ if hparams['use_attn']:
159
+ x = self.mid_attn(x)
160
+ x = self.mid_block2(x, t)
161
+
162
+ for resnet, resnet2, upsample in self.ups:
163
+ x = torch.cat((x, h.pop()), dim=1)
164
+ x = resnet(x, t)
165
+ x = resnet2(x, t)
166
+ x = upsample(x)
167
+
168
+ return self.final_conv(x)
169
+
170
+ def make_generation_fast_(self):
171
+ def remove_weight_norm(m):
172
+ try:
173
+ nn.utils.remove_weight_norm(m)
174
+ except ValueError: # this module didn't have weight norm
175
+ return
176
+
177
+ self.apply(remove_weight_norm)
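The `Unet` above conditions on every third RRDB feature map (`cond[2::3]`), so `cond_proj` must accept `cond_dim * ((rrdb_num_block + 1) // 3)` input channels. A quick sanity check of that arithmetic with the shipped config values (sketch only):

rrdb_num_block = 17             # from diffsr_base.yaml
cond_dim = 64                   # rrdb_num_feat
n_feats = rrdb_num_block + 1    # RRDBNet returns one feature per block plus the fused trunk
picked = list(range(n_feats))[2::3]        # indices 2, 5, 8, 11, 14, 17 -> 6 feature maps
assert len(picked) == (rrdb_num_block + 1) // 3 == 6
cond_proj_in = cond_dim * len(picked)      # 384 input channels for the ConvTranspose2d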
sam_diffsr/models_sr/diffusion.py ADDED
@@ -0,0 +1,291 @@
1
+ from functools import partial
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from tqdm import tqdm
7
+
8
+ from sam_diffsr.utils_sr.plt_img import plt_tensor_img
9
+ from .module_util import default
10
+ from sam_diffsr.utils_sr.sr_utils import SSIM, PerceptualLoss
11
+ from sam_diffsr.utils_sr.hparams import hparams
12
+
13
+
14
+ # gaussian diffusion trainer class
15
+ def extract(a, t, x_shape):
16
+ b, *_ = t.shape
17
+ out = a.gather(-1, t)
18
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
19
+
20
+
21
+ def noise_like(shape, device, repeat=False):
22
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
23
+ noise = lambda: torch.randn(shape, device=device)
24
+ return repeat_noise() if repeat else noise()
25
+
26
+
27
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
28
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
29
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
30
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
31
+ return betas
32
+
33
+
34
+ def get_beta_schedule(num_diffusion_timesteps, beta_schedule='linear', beta_start=0.0001, beta_end=0.02):
35
+ if beta_schedule == 'quad':
36
+ betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2
37
+ elif beta_schedule == 'linear':
38
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
39
+ elif beta_schedule == 'warmup10':
40
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
41
+ elif beta_schedule == 'warmup50':
42
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
43
+ elif beta_schedule == 'const':
44
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
45
+ elif beta_schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
46
+ betas = 1. / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
47
+ else:
48
+ raise NotImplementedError(beta_schedule)
49
+ assert betas.shape == (num_diffusion_timesteps,)
50
+ return betas
51
+
52
+
53
+ def cosine_beta_schedule(timesteps, s=0.008):
54
+ """
55
+ cosine schedule
56
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
57
+ """
58
+ steps = timesteps + 1
59
+ x = np.linspace(0, steps, steps)
60
+ alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
61
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
62
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
63
+ return np.clip(betas, a_min=0, a_max=0.999)
64
+
65
+
66
+ class GaussianDiffusion(nn.Module):
67
+ def __init__(self, denoise_fn, rrdb_net, timesteps=1000, loss_type='l1'):
68
+ super().__init__()
69
+ self.denoise_fn = denoise_fn
70
+ # condition net
71
+ self.rrdb = rrdb_net
72
+ self.ssim_loss = SSIM(window_size=11)
73
+
74
+
75
+ if hparams['beta_schedule'] == 'cosine':
76
+ betas = cosine_beta_schedule(timesteps, s=hparams['beta_s'])
77
+ if hparams['beta_schedule'] == 'linear':
78
+ betas = get_beta_schedule(timesteps, beta_end=hparams['beta_end'])
79
+ if hparams['res']:
80
+ betas[-1] = 0.999
81
+
82
+ alphas = 1. - betas
83
+ alphas_cumprod = np.cumprod(alphas, axis=0)
84
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
85
+
86
+ timesteps, = betas.shape
87
+ self.num_timesteps = int(timesteps)
88
+ self.loss_type = loss_type
89
+
90
+ to_torch = partial(torch.tensor, dtype=torch.float32)
91
+
92
+ self.register_buffer('betas', to_torch(betas))
93
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
94
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
95
+
96
+ # calculations for diffusion q(x_t | x_{t-1}) and others
97
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
98
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
99
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
100
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
101
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
102
+
103
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
104
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
105
+
106
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
107
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
108
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
109
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
110
+ self.register_buffer('posterior_mean_coef1', to_torch(
111
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
112
+ self.register_buffer('posterior_mean_coef2', to_torch(
113
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
114
+ self.sample_tqdm = True
115
+
116
+ self.mask_coefficient = to_torch(np.sqrt(1. - alphas_cumprod) * betas)
117
+
118
+ def q_mean_variance(self, x_start, t):
119
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
120
+ variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
121
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
122
+ return mean, variance, log_variance
123
+
124
+ def predict_start_from_noise(self, x_t, t, noise):
125
+ return (
126
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
127
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
128
+ )
129
+
130
+ def q_posterior(self, x_start, x_t, t):
131
+ posterior_mean = (
132
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
133
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
134
+ )
135
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
136
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
137
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
138
+
139
+ def p_mean_variance(self, x, t, noise_pred, clip_denoised: bool):
140
+ x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
141
+
142
+ if clip_denoised:
143
+ x_recon.clamp_(-1.0, 1.0)
144
+
145
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
146
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
147
+
148
+ def forward(self, img_hr, img_lr, img_lr_up, t=None, *args, **kwargs):
149
+ x = img_hr
150
+ b, *_, device = *x.shape, x.device
151
+ t = torch.randint(0, self.num_timesteps, (b,), device=device).long() \
152
+ if t is None else torch.LongTensor([t]).repeat(b).to(device)
153
+ if hparams['use_rrdb']:
154
+ if hparams['fix_rrdb']:
155
+ self.rrdb.eval()
156
+ with torch.no_grad():
157
+ rrdb_out, cond = self.rrdb(img_lr, True)
158
+ else:
159
+ rrdb_out, cond = self.rrdb(img_lr, True)
160
+ else:
161
+ rrdb_out = img_lr_up
162
+ cond = img_lr
163
+ x = self.img2res(x, img_lr_up)
164
+ p_losses, x_tp1, noise_pred, x_t, x_t_gt, x_0 = self.p_losses(x, t, cond, img_lr_up, *args, **kwargs)
165
+ ret = {'q': p_losses}
166
+ if not hparams['fix_rrdb']:
167
+ if hparams['aux_l1_loss']:
168
+ ret['aux_l1'] = F.l1_loss(rrdb_out, img_hr)
169
+ if hparams['aux_ssim_loss']:
170
+ ret['aux_ssim'] = 1 - self.ssim_loss(rrdb_out, img_hr)
171
+ if hparams['aux_percep_loss']:
172
+ ret['aux_percep'] = self.percep_loss_fn[0](img_hr, rrdb_out)
173
+
174
+
175
+ x_tp1 = self.res2img(x_tp1, img_lr_up)
176
+ x_t = self.res2img(x_t, img_lr_up)
177
+ x_t_gt = self.res2img(x_t_gt, img_lr_up)
178
+ return ret, (x_tp1, x_t_gt, x_t), t
179
+
180
+ def p_losses(self, x_start, t, cond, img_lr_up, noise=None):
181
+ noise = default(noise, lambda: torch.randn_like(x_start))
182
+ x_tp1_gt = self.q_sample(x_start=x_start, t=t, noise=noise)
183
+ x_t_gt = self.q_sample(x_start=x_start, t=t - 1, noise=noise)
184
+ noise_pred = self.denoise_fn(x_tp1_gt, t, cond, img_lr_up)
185
+ x_t_pred, x0_pred = self.p_sample(x_tp1_gt, t, cond, img_lr_up, noise_pred=noise_pred)
186
+
187
+ if self.loss_type == 'l1':
188
+ loss = (noise - noise_pred).abs().mean()
189
+ elif self.loss_type == 'l2':
190
+ loss = F.mse_loss(noise, noise_pred)
191
+ elif self.loss_type == 'ssim':
192
+ loss = (noise - noise_pred).abs().mean()
193
+ loss = loss + (1 - self.ssim_loss(noise, noise_pred))
194
+ else:
195
+ raise NotImplementedError()
196
+ return loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred
197
+
198
+ def q_sample(self, x_start, t, noise=None):
199
+ noise = default(noise, lambda: torch.randn_like(x_start))
200
+ t_cond = (t[:, None, None, None] >= 0).float()
201
+ t = t.clamp_min(0)
202
+ return (
203
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
204
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
205
+ ) * t_cond + x_start * (1 - t_cond)
206
+
207
+ @torch.no_grad()
208
+ def p_sample(self, x, t, cond, img_lr_up, noise_pred=None, clip_denoised=True, repeat_noise=False):
209
+ if noise_pred is None:
210
+ noise_pred = self.denoise_fn(x, t, cond=cond, img_lr_up=img_lr_up)
211
+ b, *_, device = *x.shape, x.device
212
+ model_mean, _, model_log_variance, x0_pred = self.p_mean_variance(
213
+ x=x, t=t, noise_pred=noise_pred, clip_denoised=clip_denoised)
214
+ noise = noise_like(x.shape, device, repeat_noise)
215
+ # no noise when t == 0
216
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
217
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0_pred
218
+
219
+ @torch.no_grad()
220
+ def sample(self, img_lr, img_lr_up, shape, save_intermediate=False):
221
+ device = self.betas.device
222
+ b = shape[0]
223
+ if not hparams['res']:
224
+ t = torch.full((b,), self.num_timesteps - 1, device=device, dtype=torch.long)
225
+ img = self.q_sample(img_lr_up, t)
226
+ else:
227
+ img = torch.randn(shape, device=device)
228
+ if hparams['use_rrdb']:
229
+ rrdb_out, cond = self.rrdb(img_lr, True)
230
+ else:
231
+ rrdb_out = img_lr_up
232
+ cond = img_lr
233
+ it = reversed(range(0, self.num_timesteps))
234
+ if self.sample_tqdm:
235
+ it = tqdm(it, desc='sampling loop time step', total=self.num_timesteps)
236
+ images = []
237
+ for i in it:
238
+ img, x_recon = self.p_sample(
239
+ img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up)
240
+ if save_intermediate:
241
+ img_ = self.res2img(img, img_lr_up)
242
+ x_recon_ = self.res2img(x_recon, img_lr_up)
243
+ images.append((img_.cpu(), x_recon_.cpu()))
244
+ img = self.res2img(img, img_lr_up)
245
+ if save_intermediate:
246
+ return img, rrdb_out, images
247
+ else:
248
+ return img, rrdb_out
249
+
250
+ @torch.no_grad()
251
+ def interpolate(self, x1, x2, img_lr, img_lr_up, t=None, lam=0.5):
252
+ b, *_, device = *x1.shape, x1.device
253
+ t = default(t, self.num_timesteps - 1)
254
+ if hparams['use_rrdb']:
255
+ rrdb_out, cond = self.rrdb(img_lr, True)
256
+ else:
257
+ cond = img_lr
258
+
259
+ assert x1.shape == x2.shape
260
+
261
+ x1 = self.img2res(x1, img_lr_up)
262
+ x2 = self.img2res(x2, img_lr_up)
263
+
264
+ t_batched = torch.stack([torch.tensor(t, device=device)] * b)
265
+ xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
266
+
267
+ img = (1 - lam) * xt1 + lam * xt2
268
+ for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
269
+ img, x_recon = self.p_sample(
270
+ img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up)
271
+
272
+ img = self.res2img(img, img_lr_up)
273
+ return img
274
+
275
+ def res2img(self, img_, img_lr_up, clip_input=None):
276
+ if clip_input is None:
277
+ clip_input = hparams['clip_input']
278
+ if hparams['res']:
279
+ if clip_input:
280
+ img_ = img_.clamp(-1, 1)
281
+ img_ = img_ / hparams['res_rescale'] + img_lr_up
282
+ return img_
283
+
284
+ def img2res(self, x, img_lr_up, clip_input=None):
285
+ if clip_input is None:
286
+ clip_input = hparams['clip_input']
287
+ if hparams['res']:
288
+ x = (x - img_lr_up) * hparams['res_rescale']
289
+ if clip_input:
290
+ x = x.clamp(-1, 1)
291
+ return x
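`img2res` and `res2img` implement the residual parameterization used throughout: the diffusion model operates on `(HR - upsampled LR) * res_rescale`, and the sampled residual is mapped back by the inverse transform. A small sketch of the round trip, assuming the default config (`res: true`, `clip_input: true`, `res_rescale: 2.0`):

import torch

res_rescale = 2.0
img_hr = torch.rand(1, 3, 8, 8) * 2 - 1       # pretend HR image in [-1, 1]
img_lr_up = torch.rand(1, 3, 8, 8) * 2 - 1    # pretend bicubic-upsampled LR image

res = ((img_hr - img_lr_up) * res_rescale).clamp(-1, 1)   # img2res
recon = res / res_rescale + img_lr_up                     # res2img
# recon equals img_hr wherever the scaled residual was not clipped to [-1, 1]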
sam_diffsr/models_sr/diffusion_sam.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ import torch.nn.functional as F
+ from tqdm import tqdm
+
+ from sam_diffsr.utils_sr.hparams import hparams
+ from .diffusion import GaussianDiffusion, noise_like, extract
+ from .module_util import default
+
+
+ class GaussianDiffusion_sam(GaussianDiffusion):
+     def __init__(self, denoise_fn, rrdb_net, timesteps=1000, loss_type='l1', sam_config=None):
+         super().__init__(denoise_fn, rrdb_net, timesteps, loss_type)
+         self.sam_config = sam_config
+
+     def p_losses(self, x_start, t, cond, img_lr_up, noise=None, sam_mask=None):
+         noise = default(noise, lambda: torch.randn_like(x_start))
+
+         if self.sam_config['p_losses_sam']:
+             _sam_mask = F.interpolate(sam_mask, noise.shape[2:], mode='bilinear')
+             if self.sam_config.get('mask_coefficient', False):
+                 _sam_mask *= extract(self.mask_coefficient.to(_sam_mask.device), t, x_start.shape)
+             noise += _sam_mask
+
+         x_tp1_gt = self.q_sample(x_start=x_start, t=t, noise=noise)
+         x_t_gt = self.q_sample(x_start=x_start, t=t - 1, noise=noise)
+         noise_pred = self.denoise_fn(x_tp1_gt, t, cond, img_lr_up, sam_mask=sam_mask)
+         x_t_pred, x0_pred = self.p_sample(x_tp1_gt, t, cond, img_lr_up, noise_pred=noise_pred, sam_mask=sam_mask)
+
+         if self.loss_type == 'l1':
+             loss = (noise - noise_pred).abs().mean()
+         elif self.loss_type == 'l2':
+             loss = F.mse_loss(noise, noise_pred)
+         elif self.loss_type == 'ssim':
+             loss = (noise - noise_pred).abs().mean()
+             loss = loss + (1 - self.ssim_loss(noise, noise_pred))
+         else:
+             raise NotImplementedError()
+         return loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred
+
+     @torch.no_grad()
+     def p_sample(self, x, t, cond, img_lr_up, noise_pred=None, clip_denoised=True, repeat_noise=False, sam_mask=None):
+         if noise_pred is None:
+             noise_pred = self.denoise_fn(x, t, cond=cond, img_lr_up=img_lr_up, sam_mask=sam_mask)
+         b, *_, device = *x.shape, x.device
+         model_mean, _, model_log_variance, x0_pred = self.p_mean_variance(
+             x=x, t=t, noise_pred=noise_pred, clip_denoised=clip_denoised)
+
+         noise = noise_like(x.shape, device, repeat_noise)
+
+         # no noise when t == 0
+         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0_pred
+
+     @torch.no_grad()
+     def sample(self, img_lr, img_lr_up, shape, sam_mask=None, save_intermediate=False):
+         device = self.betas.device
+         b = shape[0]
+
+         if not hparams['res']:
+             t = torch.full((b,), self.num_timesteps - 1, device=device, dtype=torch.long)
+             noise = None
+             img = self.q_sample(img_lr_up, t, noise=noise)
+         else:
+             img = torch.randn(shape, device=device)
+
+         if hparams['use_rrdb']:
+             rrdb_out, cond = self.rrdb(img_lr, True)
+         else:
+             rrdb_out = img_lr_up
+             cond = img_lr
+
+         it = reversed(range(0, self.num_timesteps))
+
+         if self.sample_tqdm:
+             it = tqdm(it, desc='sampling loop time step', total=self.num_timesteps)
+
+         images = []
+         for i in it:
+             img, x_recon = self.p_sample(
+                 img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up, sam_mask=sam_mask)
+             if save_intermediate:
+                 img_ = self.res2img(img, img_lr_up)
+                 x_recon_ = self.res2img(x_recon, img_lr_up)
+                 images.append((img_.cpu(), x_recon_.cpu()))
+         img = self.res2img(img, img_lr_up)
+
+         if save_intermediate:
+             return img, rrdb_out, images
+         else:
+             return img, rrdb_out
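The only change relative to the parent class is in `p_losses`: when `p_losses_sam` is enabled, the SAM mask is resized to the noise resolution, optionally scaled by the per-step coefficient `sqrt(1 - alpha_bar_t) * beta_t` registered in `GaussianDiffusion.__init__`, and added to the Gaussian noise before `q_sample`. A stand-alone sketch of that modulation with hypothetical tensors (same operations as above):

import torch
import torch.nn.functional as F

betas = torch.linspace(1e-4, 2e-2, 100)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
mask_coefficient = torch.sqrt(1.0 - alphas_cumprod) * betas   # as in GaussianDiffusion.__init__

x_start = torch.randn(4, 3, 160, 160)          # batch of HR residuals
sam_mask = torch.rand(4, 1, 40, 40)            # SAM mask stored at LR resolution
t = torch.randint(0, 100, (4,))

noise = torch.randn_like(x_start)
m = F.interpolate(sam_mask, x_start.shape[2:], mode='bilinear')   # upsample mask to HR size
m = m * mask_coefficient[t].reshape(-1, 1, 1, 1)                  # scale by the schedule at step t
noise = noise + m                                                  # structure-aware noise fed to q_sample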
sam_diffsr/models_sr/module_util.py ADDED
@@ -0,0 +1,58 @@
+ from inspect import isfunction
+ from torch import nn
+ from torch.nn import init
+
+
+ def exists(x):
+     return x is not None
+
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if isfunction(d) else d
+
+
+ def cycle(dl):
+     while True:
+         for data in dl:
+             yield data
+
+
+ def num_to_groups(num, divisor):
+     groups = num // divisor
+     remainder = num % divisor
+     arr = [divisor] * groups
+     if remainder > 0:
+         arr.append(remainder)
+     return arr
+
+
+ def initialize_weights(net_l, scale=0.1):
+     if not isinstance(net_l, list):
+         net_l = [net_l]
+     for net in net_l:
+         for m in net.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                 m.weight.data *= scale  # for residual block
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             elif isinstance(m, nn.Linear):
+                 init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+             elif isinstance(m, nn.BatchNorm2d):
+                 init.constant_(m.weight, 1)
+                 init.constant_(m.bias.data, 0.0)
+
+
+ def make_layer(block, n_layers, seq=False):
+     layers = []
+     for _ in range(n_layers):
+         layers.append(block())
+     if seq:
+         return nn.Sequential(*layers)
+     else:
+         return nn.ModuleList(layers)
sam_diffsr/tasks/__init__.py ADDED
File without changes
sam_diffsr/tasks/infer.py ADDED
@@ -0,0 +1,81 @@
1
+ import importlib
2
+ import os
3
+ import sys
4
+ from collections import OrderedDict
5
+ from pathlib import Path
6
+
7
+ from tasks.srdiff_df2k import InferDataSet
8
+
9
+ parent_path = Path(__file__).absolute().parent.parent
10
+ sys.path.append(os.path.abspath(parent_path))
11
+ os.chdir(parent_path)
12
+ print(f'>-------------> parent path {parent_path}')
13
+ print(f'>-------------> current work dir {os.getcwd()}')
14
+
15
+ cache_path = os.path.join(parent_path, 'cache')
16
+ os.environ["HF_DATASETS_CACHE"] = cache_path
17
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
18
+ os.environ["torch_HOME"] = cache_path
19
+
20
+ import torch
21
+ from PIL import Image
22
+ from tqdm import tqdm
23
+ from torch.utils.tensorboard import SummaryWriter
24
+ from utils_sr.hparams import hparams, set_hparams
25
+
26
+
27
+ def load_ckpt(ckpt_path, model):
28
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
29
+ stat_dict = checkpoint['state_dict']['model']
30
+
31
+ new_state_dict = OrderedDict()
32
+ for k, v in stat_dict.items():
33
+ if k[:7] == 'module.':
34
+ k = k[7:] # strip the `module.` prefix
35
+ new_state_dict[k] = v
36
+
37
+ model.load_state_dict(new_state_dict)
38
+ model.cuda()
39
+
40
+
41
+ def infer(trainer, ckpt_path, img_dir, save_dir):
42
+ trainer.build_model()
43
+ load_ckpt(ckpt_path, trainer.model)
44
+
45
+ dataset = InferDataSet(img_dir)
46
+ test_dataloader = torch.utils.data.DataLoader(
47
+ dataset, batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
48
+
49
+ torch.backends.cudnn.benchmark = False
50
+
51
+ with torch.no_grad():
52
+ trainer.model.eval()
53
+ pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
54
+ for batch_idx, batch in pbar:
55
+ img_lr, img_lr_up, img_name = batch
56
+
57
+ img_lr = img_lr.to('cuda')
58
+ img_lr_up = img_lr_up.to('cuda')
59
+
60
+ img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)
61
+
62
+ img_sr = img_sr.clamp(-1, 1)
63
+ img_sr = trainer.tensor2img(img_sr)[0]
64
+ img_sr = Image.fromarray(img_sr)
65
+ img_sr.save(os.path.join(save_dir, img_name[0]))
66
+
67
+
68
+ if __name__ == '__main__':
69
+ set_hparams()
70
+
71
+ img_dir = hparams['img_dir']
72
+ save_dir = hparams['save_dir']
73
+ ckpt_path = hparams['ckpt_path']
74
+
75
+ pkg = ".".join(hparams["trainer_cls"].split(".")[:-1])
76
+ cls_name = hparams["trainer_cls"].split(".")[-1]
77
+ trainer = getattr(importlib.import_module(pkg), cls_name)()
78
+
79
+ os.makedirs(save_dir, exist_ok=True)
80
+
81
+ infer(trainer, ckpt_path, img_dir, save_dir)
sam_diffsr/tasks/rrdb.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from models_sr.diffsr_modules import RRDBNet
5
+ from tasks.srdiff_df2k import Df2kDataSet
6
+ from tasks.trainer import Trainer
7
+ from utils_sr.hparams import hparams
8
+ from utils_sr.sr_utils import PerceptualLoss
9
+
10
+
11
+ class RRDBTask(Trainer):
12
+ def __init__(self):
13
+ super().__init__()
14
+ if 'rrdb_loss' in hparams and hparams['rrdb_loss']['percep_loss']:
15
+ self.percep_loss_fn = PerceptualLoss()
16
+ self.percep_loss_weight = hparams['rrdb_loss']['percep_loss_weight']
17
+ else:
18
+ self.percep_loss_fn = None
19
+ self.percep_loss_weight = 0
20
+
21
+ def build_model(self):
22
+ hidden_size = hparams['hidden_size']
23
+ self.model = RRDBNet(3, 3, hidden_size, hparams['num_block'], hidden_size // 2)
24
+ return self.model
25
+
26
+ def build_optimizer(self, model):
27
+ return torch.optim.Adam(model.parameters(), lr=hparams['lr'])
28
+
29
+ def build_scheduler(self, optimizer):
30
+ return torch.optim.lr_scheduler.StepLR(optimizer, 200000, 0.5)
31
+
32
+ def training_step(self, sample):
33
+ img_hr = sample['img_hr']
34
+ img_lr = sample['img_lr']
35
+ p = self.model(img_lr)
36
+ total_loss = 0
37
+ loss = F.l1_loss(p, img_hr, reduction='mean')
38
+ total_loss += loss
39
+
40
+ if self.percep_loss_fn:
41
+ loss_percep = self.percep_loss_fn(img_hr, p) * self.percep_loss_weight
42
+ total_loss += loss_percep
43
+ return {'l': loss, 'loss_percep': loss_percep, 'total_loss': total_loss,
44
+ 'lr': self.scheduler.get_last_lr()[0]}, total_loss
45
+ else:
46
+ return {'l': loss, 'lr': self.scheduler.get_last_lr()[0]}, total_loss
47
+
48
+ def sample_and_test(self, sample):
49
+ ret = {k: 0 for k in self.metric_keys}
50
+ ret['n_samples'] = 0
51
+ img_hr = sample['img_hr']
52
+ img_lr = sample['img_lr']
53
+ img_sr = self.model(img_lr)
54
+ img_sr = img_sr.clamp(-1, 1)
55
+ for b in range(img_sr.shape[0]):
56
+ s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
57
+ ret['psnr'] += s['psnr']
58
+ ret['ssim'] += s['ssim']
59
+ ret['lpips'] += s['lpips']
60
+ ret['lr_psnr'] += s['lr_psnr']
61
+ ret['n_samples'] += 1
62
+ return img_sr, img_sr, ret
63
+
64
+
65
+ class RRDBDf2kTask(RRDBTask):
66
+ def __init__(self):
67
+ super().__init__()
68
+ self.dataset_cls = Df2kDataSet
sam_diffsr/tasks/rrdb_sam.py ADDED
@@ -0,0 +1,49 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from models_sr.diffsr_modules import RRDBNet
5
+ from tasks.srdiff_df2k_sam import Df2kDataSet_sam
6
+ from tasks.trainer import Trainer
7
+ from utils_sr.hparams import hparams
8
+
9
+
10
+ class RRDBTask_sam(Trainer):
11
+ def build_model(self):
12
+ hidden_size = hparams['hidden_size']
13
+ self.model = RRDBNet(3, 3, hidden_size, hparams['num_block'], hidden_size // 2)
14
+ return self.model
15
+
16
+ def build_optimizer(self, model):
17
+ return torch.optim.Adam(model.parameters(), lr=hparams['lr'])
18
+
19
+ def build_scheduler(self, optimizer):
20
+ return torch.optim.lr_scheduler.StepLR(optimizer, 200000, 0.5)
21
+
22
+ def training_step(self, sample):
23
+ img_hr = sample['img_hr']
24
+ img_lr = sample['img_lr']
25
+ p = self.model(img_lr)
26
+ loss = F.l1_loss(p, img_hr, reduction='mean')
27
+ return {'l': loss, 'lr': self.scheduler.get_last_lr()[0]}, loss
28
+
29
+ def sample_and_test(self, sample):
30
+ ret = {k: 0 for k in self.metric_keys}
31
+ ret['n_samples'] = 0
32
+ img_hr = sample['img_hr']
33
+ img_lr = sample['img_lr']
34
+ img_sr = self.model(img_lr)
35
+ img_sr = img_sr.clamp(-1, 1)
36
+ for b in range(img_sr.shape[0]):
37
+ s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
38
+ ret['psnr'] += s['psnr']
39
+ ret['ssim'] += s['ssim']
40
+ ret['lpips'] += s['lpips']
41
+ ret['lr_psnr'] += s['lr_psnr']
42
+ ret['n_samples'] += 1
43
+ return img_sr, img_sr, ret
44
+
45
+
46
+ class RRDBDf2kTask_sam(RRDBTask_sam):
47
+ def __init__(self):
48
+ super().__init__()
49
+ self.dataset_cls = Df2kDataSet_sam
sam_diffsr/tasks/srdiff.py ADDED
@@ -0,0 +1,76 @@
1
+ import os.path
2
+
3
+ import torch
4
+
5
+ from sam_diffsr.models_sr.diffsr_modules import Unet, RRDBNet
6
+ from sam_diffsr.models_sr.diffusion import GaussianDiffusion
7
+ from sam_diffsr.tasks.trainer import Trainer
8
+ from sam_diffsr.utils_sr.hparams import hparams
9
+ from sam_diffsr.utils_sr.utils import load_ckpt
10
+
11
+
12
+ class SRDiffTrainer(Trainer):
13
+ def build_model(self):
14
+ hidden_size = hparams['hidden_size']
15
+ dim_mults = hparams['unet_dim_mults']
16
+ dim_mults = [int(x) for x in dim_mults.split('|')]
17
+ denoise_fn = Unet(
18
+ hidden_size, out_dim=3, cond_dim=hparams['rrdb_num_feat'], dim_mults=dim_mults)
19
+ if hparams['use_rrdb']:
20
+ rrdb = RRDBNet(3, 3, hparams['rrdb_num_feat'], hparams['rrdb_num_block'],
21
+ hparams['rrdb_num_feat'] // 2)
22
+ if hparams['rrdb_ckpt'] != '' and os.path.exists(hparams['rrdb_ckpt']):
23
+ load_ckpt(rrdb, hparams['rrdb_ckpt'])
24
+ else:
25
+ rrdb = None
26
+ self.model = GaussianDiffusion(
27
+ denoise_fn=denoise_fn,
28
+ rrdb_net=rrdb,
29
+ timesteps=hparams['timesteps'],
30
+ loss_type=hparams['loss_type']
31
+ )
32
+ self.global_step = 0
33
+ return self.model
34
+
35
+ def sample_and_test(self, sample):
36
+ ret = {k: 0 for k in self.metric_keys}
37
+ ret['n_samples'] = 0
38
+ img_hr = sample['img_hr']
39
+ img_lr = sample['img_lr']
40
+ img_lr_up = sample['img_lr_up']
41
+ img_sr, rrdb_out = self.model.sample(img_lr, img_lr_up, img_hr.shape)
42
+ for b in range(img_sr.shape[0]):
43
+ s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
44
+ ret['psnr'] += s['psnr']
45
+ ret['ssim'] += s['ssim']
46
+ ret['lpips'] += s['lpips']
47
+ ret['lr_psnr'] += s['lr_psnr']
48
+ ret['n_samples'] += 1
49
+ return img_sr, rrdb_out, ret
50
+
51
+ def build_optimizer(self, model):
52
+ params = list(model.named_parameters())
53
+ if not hparams['fix_rrdb']:
54
+ params = [p for p in params if 'rrdb' not in p[0]]
55
+ params = [p[1] for p in params]
56
+ return torch.optim.Adam(params, lr=hparams['lr'])
57
+
58
+ def build_scheduler(self, optimizer):
59
+ if 'scheduler' in hparams:
60
+ scheduler_config = hparams['scheduler']
61
+ if scheduler_config['type'] == 'cosine':
62
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, hparams['max_updates'],
63
+ eta_min=scheduler_config['eta_min'])
64
+
65
+ else:
66
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
67
+
68
+ return lr_scheduler
69
+
70
+ def training_step(self, batch):
71
+ img_hr = batch['img_hr']
72
+ img_lr = batch['img_lr']
73
+ img_lr_up = batch['img_lr_up']
74
+ losses, _, _ = self.model(img_hr, img_lr, img_lr_up)
75
+ total_loss = sum(losses.values())
76
+ return losses, total_loss
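
build_scheduler above picks cosine annealing when a `scheduler` block is present in the config and falls back to step decay otherwise. A minimal sketch of the two branches on throwaway optimizers (the step counts and eta_min here are assumptions, not the repository defaults):

import torch

def toy_optimizer():
    return torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=2e-4)

# Cosine branch: anneal the learning rate down to eta_min over max_updates steps.
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(toy_optimizer(), T_max=400000, eta_min=1e-6)

# Default branch: halve the learning rate every decay_steps updates.
step = torch.optim.lr_scheduler.StepLR(toy_optimizer(), step_size=100000, gamma=0.5)
print(cosine.get_last_lr(), step.get_last_lr())
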
sam_diffsr/tasks/srdiff_df2k.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset
7
+ from torchvision import transforms
8
+
9
+ from sam_diffsr.tasks.srdiff import SRDiffTrainer
10
+ from sam_diffsr.utils_sr.dataset import SRDataSet
11
+ from sam_diffsr.utils_sr.hparams import hparams
12
+ from sam_diffsr.utils_sr.matlab_resize import imresize
13
+
14
+
15
+ class InferDataSet(Dataset):
16
+ def __init__(self, img_dir):
17
+ super().__init__()
18
+
19
+ self.img_path_list = [os.path.join(img_dir, img_name) for img_name in os.listdir(img_dir)]
20
+ self.to_tensor_norm = transforms.Compose([
21
+ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
22
+ ])
23
+
24
+ def __getitem__(self, index):
25
+ sr_scale = hparams['sr_scale']
26
+
27
+ img_path = self.img_path_list[index]
28
+ img_name = os.path.basename(img_path)
29
+
30
+ img_lr = Image.open(img_path).convert('RGB')
31
+ img_lr = np.uint8(np.asarray(img_lr))
32
+
33
+ h, w, c = img_lr.shape
34
+ h, w = h * sr_scale, w * sr_scale
35
+ h = h - h % (sr_scale * 2)
36
+ w = w - w % (sr_scale * 2)
37
+ h_l = h // sr_scale
38
+ w_l = w // sr_scale
39
+
40
+ img_lr = img_lr[:h_l, :w_l]
41
+
42
+ img_lr_up = imresize(img_lr / 256, hparams['sr_scale']) # np.float [H, W, C]
43
+ img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]
44
+
45
+ return img_lr, img_lr_up, img_name
46
+
47
+ def __len__(self):
48
+ return len(self.img_path_list)
49
+
50
+
51
+ class Df2kDataSet(SRDataSet):
52
+ def __init__(self, prefix='train'):
53
+ if prefix == 'valid':
54
+ _prefix = 'test'
55
+ else:
56
+ _prefix = prefix
57
+
58
+ super().__init__(_prefix)
59
+ self.patch_size = hparams['patch_size']
60
+ self.patch_size_lr = hparams['patch_size'] // hparams['sr_scale']
61
+ if prefix == 'valid':
62
+ self.len = hparams['eval_batch_size'] * hparams['valid_steps']
63
+
64
+ self.data_aug_transforms = transforms.Compose([
65
+ transforms.RandomHorizontalFlip(),
66
+ transforms.RandomRotation(20, resample=Image.BICUBIC),
67
+ transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
68
+ ])
69
+
70
+ def __getitem__(self, index):
71
+ item = self._get_item(index)
72
+ hparams = self.hparams
73
+ sr_scale = hparams['sr_scale']
74
+
75
+ img_hr = np.uint8(item['img'])
76
+ img_lr = np.uint8(item['img_lr'])
77
+
78
+ # TODO: clip for SRFlow
79
+ h, w, c = img_hr.shape
80
+ h = h - h % (sr_scale * 2)
81
+ w = w - w % (sr_scale * 2)
82
+ h_l = h // sr_scale
83
+ w_l = w // sr_scale
84
+ img_hr = img_hr[:h, :w]
85
+ img_lr = img_lr[:h_l, :w_l]
86
+ # random crop
87
+ if self.prefix == 'train':
88
+ if self.data_augmentation and random.random() < 0.5:
89
+ img_hr, img_lr = self.data_augment(img_hr, img_lr)
90
+ i = random.randint(0, h - self.patch_size) // sr_scale * sr_scale
91
+ i_lr = i // sr_scale
92
+ j = random.randint(0, w - self.patch_size) // sr_scale * sr_scale
93
+ j_lr = j // sr_scale
94
+ img_hr = img_hr[i:i + self.patch_size, j:j + self.patch_size]
95
+ img_lr = img_lr[i_lr:i_lr + self.patch_size_lr, j_lr:j_lr + self.patch_size_lr]
96
+ img_lr_up = imresize(img_lr / 256, hparams['sr_scale']) # np.float [H, W, C]
97
+ img_hr, img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_hr, img_lr, img_lr_up]]
98
+ return {
99
+ 'img_hr': img_hr, 'img_lr': img_lr,
100
+ 'img_lr_up': img_lr_up, 'item_name': item['item_name'],
101
+ 'loc': np.array(item['loc']), 'loc_bdr': np.array(item['loc_bdr'])
102
+ }
103
+
104
+ def __len__(self):
105
+ return self.len
106
+
107
+ def data_augment(self, img_hr, img_lr):
108
+ sr_scale = self.hparams['sr_scale']
109
+ img_hr = Image.fromarray(img_hr)
110
+ img_hr = self.data_aug_transforms(img_hr)
111
+ img_hr = np.asarray(img_hr) # np.uint8 [H, W, C]
112
+ img_lr = imresize(img_hr, 1 / sr_scale)
113
+ return img_hr, img_lr
114
+
115
+
116
+ class SRDiffDf2k(SRDiffTrainer):
117
+ def __init__(self):
118
+ super().__init__()
119
+ self.dataset_cls = Df2kDataSet
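
The random crop in Df2kDataSet.__getitem__ snaps the HR offset to a multiple of sr_scale so that the LR patch covers exactly the same region. A small standalone sketch of that alignment (the 4x scale and 160-pixel patch size are assumptions):

import random

sr_scale, patch_size = 4, 160
h = w = 480  # already trimmed to a multiple of sr_scale * 2

i = random.randint(0, h - patch_size) // sr_scale * sr_scale  # HR row offset, multiple of sr_scale
j = random.randint(0, w - patch_size) // sr_scale * sr_scale  # HR col offset, multiple of sr_scale
i_lr, j_lr = i // sr_scale, j // sr_scale                     # matching LR offsets

# HR[i:i+160, j:j+160] and LR[i_lr:i_lr+40, j_lr:j_lr+40] describe the same image region.
assert i == i_lr * sr_scale and j == j_lr * sr_scale
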
sam_diffsr/tasks/srdiff_df2k_sam.py ADDED
@@ -0,0 +1,211 @@
1
+ import os
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from rotary_embedding_torch import RotaryEmbedding
9
+ from torchvision import transforms
10
+
11
+ from sam_diffsr.models_sr.diffsr_modules import RRDBNet, Unet
12
+ from sam_diffsr.models_sr.diffusion_sam import GaussianDiffusion_sam
13
+ from sam_diffsr.tasks.srdiff import SRDiffTrainer
14
+ from sam_diffsr.utils_sr.dataset import SRDataSet
15
+ from sam_diffsr.utils_sr.hparams import hparams
16
+ from sam_diffsr.utils_sr.indexed_datasets import IndexedDataset
17
+ from sam_diffsr.utils_sr.matlab_resize import imresize
18
+ from sam_diffsr.utils_sr.utils import load_ckpt
19
+
20
+
21
+ def normalize_01(data):
22
+ mu = np.mean(data)
23
+ sigma = np.std(data)
24
+
25
+ if sigma == 0.:
26
+ return data - mu
27
+ else:
28
+ return (data - mu) / sigma
29
+
30
+
31
+ def normalize_11(data):
32
+ mu = np.mean(data)
33
+ sigma = np.std(data)
34
+
35
+ if sigma == 0.:
36
+ return data - mu
37
+ else:
38
+ return (data - mu) / sigma - 1
39
+
40
+
41
+ class Df2kDataSet_sam(SRDataSet):
42
+ def __init__(self, prefix='train'):
43
+
44
+ if prefix == 'valid':
45
+ _prefix = 'test'
46
+ else:
47
+ _prefix = prefix
48
+
49
+ super().__init__(_prefix)
50
+
51
+ self.patch_size = hparams['patch_size']
52
+ self.patch_size_lr = hparams['patch_size'] // hparams['sr_scale']
53
+ if prefix == 'valid':
54
+ self.len = hparams['eval_batch_size'] * hparams['valid_steps']
55
+
56
+ self.data_position_aug_transforms = transforms.Compose([
57
+ transforms.RandomHorizontalFlip(),
58
+ transforms.RandomRotation(20, interpolation=Image.BICUBIC),
59
+ ])
60
+
61
+ self.data_color_aug_transforms = transforms.Compose([
62
+ transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
63
+ ])
64
+
65
+ self.sam_config = hparams.get('sam_config', False)
66
+
67
+ if self.sam_config.get('mask_RoPE', False):
68
+ h, w = map(int, self.sam_config['mask_RoPE_shape'].split('-'))
69
+ rotary_emb = RotaryEmbedding(dim=h)
70
+ sam_mask = rotary_emb.rotate_queries_or_keys(torch.ones(1, 1, w, h))
71
+ self.RoPE_mask = sam_mask.cpu().numpy()[0, 0, ...]
72
+
73
+ def _get_item(self, index):
74
+ if self.indexed_ds is None:
75
+ self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
76
+ return self.indexed_ds[index]
77
+
78
+ def __getitem__(self, index):
79
+ item = self._get_item(index)
80
+ hparams = self.hparams
81
+ sr_scale = hparams['sr_scale']
82
+
83
+ img_hr = np.uint8(item['img'])
84
+ img_lr = np.uint8(item['img_lr'])
85
+
86
+ if self.sam_config.get('mask_RoPE', False):
87
+ sam_mask = self.RoPE_mask
88
+ else:
89
+ if 'sam_mask' in item:
90
+ sam_mask = item['sam_mask']
91
+ if sam_mask.shape != img_hr.shape[:2]:
92
+ sam_mask = cv2.resize(sam_mask, dsize=img_hr.shape[:2][::-1])
93
+ else:
94
+ sam_mask = np.zeros_like(img_lr)
95
+
96
+ # TODO: clip for SRFlow
97
+ h, w, c = img_hr.shape
98
+ h = h - h % (sr_scale * 2)
99
+ w = w - w % (sr_scale * 2)
100
+ h_l = h // sr_scale
101
+ w_l = w // sr_scale
102
+ img_hr = img_hr[:h, :w]
103
+ sam_mask = sam_mask[:h, :w]
104
+ img_lr = img_lr[:h_l, :w_l]
105
+
106
+ # random crop
107
+ if self.prefix == 'train':
108
+ if self.data_augmentation and random.random() < 0.5:
109
+ img_hr, img_lr, sam_mask = self.data_augment(img_hr, img_lr, sam_mask)
110
+ i = random.randint(0, h - self.patch_size) // sr_scale * sr_scale
111
+ i_lr = i // sr_scale
112
+ j = random.randint(0, w - self.patch_size) // sr_scale * sr_scale
113
+ j_lr = j // sr_scale
114
+ img_hr = img_hr[i:i + self.patch_size, j:j + self.patch_size]
115
+ sam_mask = sam_mask[i:i + self.patch_size, j:j + self.patch_size]
116
+ img_lr = img_lr[i_lr:i_lr + self.patch_size_lr, j_lr:j_lr + self.patch_size_lr]
117
+
118
+ img_lr_up = imresize(img_lr / 256, hparams['sr_scale']) # np.float [H, W, C]
119
+ img_hr, img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_hr, img_lr, img_lr_up]]
120
+
121
+ if hparams['sam_data_config']['all_same_mask_to_zero']:
122
+ if len(np.unique(sam_mask)) == 1:
123
+ sam_mask = np.zeros_like(sam_mask)
124
+
125
+ if hparams['sam_data_config']['normalize_01']:
126
+ if len(np.unique(sam_mask)) != 1:
127
+ sam_mask = normalize_01(sam_mask)
128
+
129
+ if hparams['sam_data_config']['normalize_11']:
130
+ if len(np.unique(sam_mask)) != 1:
131
+ sam_mask = normalize_11(sam_mask)
132
+
133
+ sam_mask = torch.FloatTensor(sam_mask).unsqueeze(dim=0)
134
+
135
+ return {
136
+ 'img_hr': img_hr, 'img_lr': img_lr,
137
+ 'img_lr_up': img_lr_up, 'item_name': item['item_name'],
138
+ 'loc': np.array(item['loc']), 'loc_bdr': np.array(item['loc_bdr']),
139
+ 'sam_mask': sam_mask
140
+ }
141
+
142
+ def __len__(self):
143
+ return self.len
144
+
145
+ def data_augment(self, img_hr, img_lr, sam_mask):
146
+ sr_scale = self.hparams['sr_scale']
147
+ img_hr = Image.fromarray(img_hr)
148
+ img_hr, sam_mask = self.data_position_aug_transforms([img_hr, sam_mask])
149
+ img_hr = self.data_color_aug_transforms(img_hr)
150
+ img_hr = np.asarray(img_hr) # np.uint8 [H, W, C]
151
+ img_lr = imresize(img_hr, 1 / sr_scale)
152
+ return img_hr, img_lr, sam_mask
153
+
154
+
155
+ class SRDiffDf2k_sam(SRDiffTrainer):
156
+ def __init__(self):
157
+ super().__init__()
158
+ self.dataset_cls = Df2kDataSet_sam
159
+ self.sam_config = hparams['sam_config']
160
+
161
+ def build_model(self):
162
+ hidden_size = hparams['hidden_size']
163
+ dim_mults = hparams['unet_dim_mults']
164
+ dim_mults = [int(x) for x in dim_mults.split('|')]
165
+
166
+ denoise_fn = Unet(
167
+ hidden_size, out_dim=3, cond_dim=hparams['rrdb_num_feat'], dim_mults=dim_mults)
168
+ if hparams['use_rrdb']:
169
+ rrdb = RRDBNet(3, 3, hparams['rrdb_num_feat'], hparams['rrdb_num_block'],
170
+ hparams['rrdb_num_feat'] // 2)
171
+ if hparams['rrdb_ckpt'] != '' and os.path.exists(hparams['rrdb_ckpt']):
172
+ load_ckpt(rrdb, hparams['rrdb_ckpt'])
173
+ else:
174
+ rrdb = None
175
+ self.model = GaussianDiffusion_sam(
176
+ denoise_fn=denoise_fn,
177
+ rrdb_net=rrdb,
178
+ timesteps=hparams['timesteps'],
179
+ loss_type=hparams['loss_type'],
180
+ sam_config=hparams['sam_config']
181
+ )
182
+ self.global_step = 0
183
+ return self.model
184
+
185
+ # def sample_and_test(self, sample):
186
+ # ret = {k: 0 for k in self.metric_keys}
187
+ # ret['n_samples'] = 0
188
+ # img_hr = sample['img_hr']
189
+ # img_lr = sample['img_lr']
190
+ # img_lr_up = sample['img_lr_up']
191
+ # sam_mask = sample['sam_mask']
192
+ #
193
+ # img_sr, rrdb_out = self.model.sample(img_lr, img_lr_up, img_hr.shape, sam_mask=sam_mask)
194
+ #
195
+ # for b in range(img_sr.shape[0]):
196
+ # s = self.measure.measure(img_sr[b], img_hr[b], img_lr[b], hparams['sr_scale'])
197
+ # ret['psnr'] += s['psnr']
198
+ # ret['ssim'] += s['ssim']
199
+ # ret['lpips'] += s['lpips']
200
+ # ret['lr_psnr'] += s['lr_psnr']
201
+ # ret['n_samples'] += 1
202
+ # return img_sr, rrdb_out, ret
203
+
204
+ def training_step(self, batch):
205
+ img_hr = batch['img_hr']
206
+ img_lr = batch['img_lr']
207
+ img_lr_up = batch['img_lr_up']
208
+ sam_mask = batch['sam_mask']
209
+ losses, _, _ = self.model(img_hr, img_lr, img_lr_up, sam_mask=sam_mask)
210
+ total_loss = sum(losses.values())
211
+ return losses, total_loss
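
The SAM mask handed to the diffusion model is standardized by normalize_01 (plain z-scoring) unless it is constant. A quick numerical check of that behaviour:

import numpy as np

mask = np.array([[0., 1., 2.], [3., 4., 5.]])
z = (mask - mask.mean()) / mask.std()  # what normalize_01 computes for a non-constant mask
print(z.mean(), z.std())               # approximately 0.0 and exactly 1.0

flat = np.full((2, 3), 7.0)
print(flat - flat.mean())              # constant masks are only mean-shifted (all zeros)
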
sam_diffsr/tasks/trainer.py ADDED
@@ -0,0 +1,346 @@
1
+ import importlib
2
+ import json
3
+ import os
4
+ import subprocess
5
+ import sys
6
+ from collections import OrderedDict
7
+ from pathlib import Path
8
+
9
+ parent_path = Path(__file__).absolute().parent.parent
10
+ sys.path.append(os.path.abspath(parent_path))
11
+ os.chdir(parent_path)
12
+ print(f'>-------------> parent path {parent_path}')
13
+ print(f'>-------------> current work dir {os.getcwd()}')
14
+
15
+ cache_path = os.path.join(parent_path, 'cache')
16
+ os.environ["HF_DATASETS_CACHE"] = cache_path
17
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
18
+ os.environ["torch_HOME"] = cache_path
19
+
20
+ import torch
21
+ from PIL import Image
22
+ from tqdm import tqdm
23
+ import numpy as np
24
+ from torch.utils.tensorboard import SummaryWriter
25
+ from sam_diffsr.utils_sr.hparams import hparams, set_hparams
26
+ from sam_diffsr.utils_sr.utils import plot_img, move_to_cuda, load_checkpoint, save_checkpoint, tensors_to_scalars, Measure, \
27
+ get_all_ckpts
28
+
29
+
30
+
31
+ class Trainer:
32
+ def __init__(self):
33
+ self.logger = self.build_tensorboard(save_dir=hparams['work_dir'], name='tb_logs')
34
+ self.measure = Measure()
35
+ self.dataset_cls = None
36
+ self.metric_keys = ['psnr', 'ssim', 'lpips', 'lr_psnr']
37
+ self.metric_2_keys = ['psnr-Y', 'ssim', 'fid']
38
+ self.work_dir = hparams['work_dir']
39
+ self.first_val = True
40
+
41
+ self.val_steps = hparams['val_steps']
42
+
43
+ def build_tensorboard(self, save_dir, name, **kwargs):
44
+ log_dir = os.path.join(save_dir, name)
45
+ os.makedirs(log_dir, exist_ok=True)
46
+ return SummaryWriter(log_dir=log_dir, **kwargs)
47
+
48
+ def build_train_dataloader(self):
49
+ dataset = self.dataset_cls('train')
50
+ return torch.utils.data.DataLoader(
51
+ dataset, batch_size=hparams['batch_size'], shuffle=True,
52
+ pin_memory=False, num_workers=hparams['num_workers'])
53
+
54
+ def build_val_dataloader(self):
55
+ return torch.utils.data.DataLoader(
56
+ self.dataset_cls('valid'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
57
+
58
+ def build_test_dataloader(self):
59
+ return torch.utils.data.DataLoader(
60
+ self.dataset_cls('test'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
61
+
62
+ def build_model(self):
63
+ raise NotImplementedError
64
+
65
+ def sample_and_test(self, sample):
66
+ raise NotImplementedError
67
+
68
+ def build_optimizer(self, model):
69
+ raise NotImplementedError
70
+
71
+ def build_scheduler(self, optimizer):
72
+ raise NotImplementedError
73
+
74
+ def training_step(self, batch):
75
+ raise NotImplementedError
76
+
77
+ def train(self):
78
+ model = self.build_model()
79
+ optimizer = self.build_optimizer(model)
80
+ self.global_step = training_step = load_checkpoint(model, optimizer, hparams['work_dir'], steps=self.val_steps)
81
+ self.scheduler = scheduler = self.build_scheduler(optimizer)
82
+ scheduler.step(training_step)
83
+ dataloader = self.build_train_dataloader()
84
+
85
+ train_pbar = tqdm(dataloader, initial=training_step, total=float('inf'),
86
+ dynamic_ncols=True, unit='step')
87
+ while self.global_step < hparams['max_updates']:
88
+ for batch in train_pbar:
89
+ if training_step % hparams['val_check_interval'] == 0:
90
+ with torch.no_grad():
91
+ model.eval()
92
+ self.validate(training_step)
93
+ save_checkpoint(model, optimizer, self.work_dir, training_step, hparams['num_ckpt_keep'])
94
+ model.train()
95
+ batch = move_to_cuda(batch)
96
+ losses, total_loss = self.training_step(batch)
97
+ optimizer.zero_grad()
98
+
99
+ total_loss.backward()
100
+ optimizer.step()
101
+ training_step += 1
102
+ scheduler.step(training_step)
103
+ self.global_step = training_step
104
+ if training_step % 100 == 0:
105
+ self.log_metrics({f'tr/{k}': v for k, v in losses.items()}, training_step)
106
+ train_pbar.set_postfix(**tensors_to_scalars(losses))
107
+
108
+ def validate(self, training_step):
109
+ val_dataloader = self.build_val_dataloader()
110
+ pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader))
111
+ metrics = {}
112
+ for batch_idx, batch in pbar:
113
+ # 每次运行的第一次validation只跑一小部分数据,来验证代码能否跑通
114
+ if self.first_val and batch_idx > hparams['num_sanity_val_steps'] - 1:
115
+ break
116
+ batch = move_to_cuda(batch)
117
+ img, rrdb_out, ret = self.sample_and_test(batch)
118
+ img_hr = batch['img_hr']
119
+ img_lr = batch['img_lr']
120
+ img_lr_up = batch['img_lr_up']
121
+ if img is not None:
122
+ self.logger.add_image(f'Pred_{batch_idx}', plot_img(img[0]), self.global_step)
123
+ if hparams.get('aux_l1_loss'):
124
+ self.logger.add_image(f'rrdb_out_{batch_idx}', plot_img(rrdb_out[0]), self.global_step)
125
+ if self.global_step <= hparams['val_check_interval']:
126
+ self.logger.add_image(f'HR_{batch_idx}', plot_img(img_hr[0]), self.global_step)
127
+ self.logger.add_image(f'LR_{batch_idx}', plot_img(img_lr[0]), self.global_step)
128
+ self.logger.add_image(f'BL_{batch_idx}', plot_img(img_lr_up[0]), self.global_step)
129
+ metrics = {}
130
+ metrics.update({k: np.mean(ret[k]) for k in self.metric_keys})
131
+ pbar.set_postfix(**tensors_to_scalars(metrics))
132
+ if hparams['infer']:
133
+ print('Val results:', metrics)
134
+ else:
135
+ if not self.first_val:
136
+ self.log_metrics({f'val/{k}': v for k, v in metrics.items()}, training_step)
137
+ print('Val results:', metrics)
138
+ else:
139
+ print('Sanity val results:', metrics)
140
+ self.first_val = False
141
+
142
+ def build_test_my_dataloader(self, data_name):
143
+ return torch.utils.data.DataLoader(
144
+ self.dataset_cls(data_name), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
145
+
146
+ def benchmark(self, benchmark_name_list, metric_list):
147
+ from sam_diffsr.tools.caculate_iqa import eval_img_IQA
148
+
149
+ model = self.build_model()
150
+ optimizer = self.build_optimizer(model)
151
+ training_step = load_checkpoint(model, optimizer, hparams['work_dir'], hparams['val_steps'])
152
+ self.global_step = training_step
153
+
154
+ optimizer = None
155
+
156
+ for data_name in benchmark_name_list:
157
+ test_dataloader = self.build_test_my_dataloader(data_name)
158
+
159
+ self.results = {k: 0 for k in self.metric_keys}
160
+ self.n_samples = 0
161
+ self.gen_dir = f"{hparams['work_dir']}/results_{self.global_step}_{hparams['gen_dir_name']}/benchmark/{data_name}"
162
+ if hparams['test_save_png']:
163
+ subprocess.check_call(f'rm -rf {self.gen_dir}', shell=True)
164
+ os.makedirs(f'{self.gen_dir}/outputs', exist_ok=True)
165
+ os.makedirs(f'{self.gen_dir}/SR', exist_ok=True)
166
+
167
+ self.model.sample_tqdm = False
168
+ torch.backends.cudnn.benchmark = False
169
+ if hparams['test_save_png']:
170
+ if hasattr(self.model.denoise_fn, 'make_generation_fast_'):
171
+ self.model.denoise_fn.make_generation_fast_()
172
+ os.makedirs(f'{self.gen_dir}/HR', exist_ok=True)
173
+
174
+ result_dict = {}
175
+
176
+ with torch.no_grad():
177
+ model.eval()
178
+ pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
179
+ for batch_idx, batch in pbar:
180
+ move_to_cuda(batch)
181
+ gen_dir = self.gen_dir
182
+ item_names = batch['item_name']
183
+ img_hr = batch['img_hr']
184
+ img_lr = batch['img_lr']
185
+ img_lr_up = batch['img_lr_up']
186
+
187
+ res = self.sample_and_test(batch)
188
+ if len(res) == 3:
189
+ img_sr, rrdb_out, ret = res
190
+ else:
191
+ img_sr, ret = res
192
+ rrdb_out = img_sr
193
+
194
+ img_lr_up = batch.get('img_lr_up', img_lr_up)
195
+ if img_sr is not None:
196
+ metrics = list(self.metric_keys)
197
+ result_dict[batch['item_name'][0]] = {}
198
+ for k in metrics:
199
+ self.results[k] += ret[k]
200
+ result_dict[batch['item_name'][0]][k] = ret[k]
201
+ self.n_samples += ret['n_samples']
202
+
203
+ print({k: round(self.results[k] / self.n_samples, 3) for k in self.results}, 'total:',
204
+ self.n_samples)
205
+
206
+ if hparams['test_save_png'] and img_sr is not None:
207
+ img_sr = self.tensor2img(img_sr)
208
+ img_hr = self.tensor2img(img_hr)
209
+ img_lr = self.tensor2img(img_lr)
210
+ img_lr_up = self.tensor2img(img_lr_up)
211
+ rrdb_out = self.tensor2img(rrdb_out)
212
+ for item_name, hr_p, hr_g, lr, lr_up, rrdb_o in zip(
213
+ item_names, img_sr, img_hr, img_lr, img_lr_up, rrdb_out):
214
+ item_name = os.path.splitext(item_name)[0]
215
+ hr_p = Image.fromarray(hr_p)
216
+ hr_g = Image.fromarray(hr_g)
217
+ hr_p.save(f"{gen_dir}/SR/{item_name}.png")
218
+ hr_g.save(f"{gen_dir}/HR/{item_name}.png")
219
+
220
+ exp_name = hparams['work_dir'].split('/')[-1]
221
+ sr_img_dir = f"{gen_dir}/SR/"
222
+ gt_img_dir = f"{gen_dir}/HR/"
223
+ excel_path = f"{hparams['work_dir']}/IQA-val-benchmark-{exp_name}.xlsx"
224
+ epoch = training_step
225
+ eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)
226
+
227
+ os.makedirs(f'{self.gen_dir}', exist_ok=True)
228
+ eval_json_path = os.path.join(self.gen_dir, 'eval.json')
229
+ avg_result = {k: round(self.results[k] / self.n_samples, 4) for k in self.results}
230
+ with open(eval_json_path, 'w+') as file:
231
+ json.dump(avg_result, file, sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False)
232
+ json.dump(result_dict, file, sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False)
233
+
234
+ def benchmark_loop(self, benchmark_name_list, metric_list, gt_path):
235
+ # infer and evaluation all save checkpoint
236
+ from sam_diffsr.tools.caculate_iqa import eval_img_IQA
237
+
238
+ model = self.build_model()
239
+
240
+ def get_checkpoint(model, checkpoint):
241
+ stat_dict = checkpoint['state_dict']['model']
242
+
243
+ new_state_dict = OrderedDict()
244
+ for k, v in stat_dict.items():
245
+ if k[:7] == 'module.':
246
+ k = k[7:] # strip the `module.` prefix
247
+ new_state_dict[k] = v
248
+
249
+ model.load_state_dict(new_state_dict)
250
+ model.cuda()
251
+ training_step = checkpoint['global_step']
252
+ del checkpoint
253
+ torch.cuda.empty_cache()
254
+
255
+ return training_step
256
+
257
+ ckpt_paths = get_all_ckpts(hparams['work_dir'])
258
+ for ckpt_path in ckpt_paths:
259
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
260
+ training_step = get_checkpoint(model, checkpoint)
261
+
262
+ self.global_step = training_step
263
+
264
+ for data_name in benchmark_name_list:
265
+ test_dataloader = self.build_test_my_dataloader(data_name)
266
+
267
+ self.results = {k: 0 for k in self.metric_keys + self.metric_2_keys}
268
+ self.n_samples = 0
269
+ self.gen_dir = f"{hparams['work_dir']}/results_{training_step}_{hparams['gen_dir_name']}/benchmark/{data_name}"
270
+
271
+ os.makedirs(f'{self.gen_dir}/outputs', exist_ok=True)
272
+ os.makedirs(f'{self.gen_dir}/SR', exist_ok=True)
273
+
274
+ self.model.sample_tqdm = False
275
+ torch.backends.cudnn.benchmark = False
276
+
277
+ with torch.no_grad():
278
+ model.eval()
279
+ pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
280
+ for batch_idx, batch in pbar:
281
+ move_to_cuda(batch)
282
+ gen_dir = self.gen_dir
283
+ item_names = batch['item_name']
284
+
285
+ res = self.sample_and_test(batch)
286
+ if len(res) == 3:
287
+ img_sr, rrdb_out, ret = res
288
+ else:
289
+ img_sr, ret = res
290
+ rrdb_out = img_sr
291
+
292
+ img_sr = self.tensor2img(img_sr)
293
+
294
+ for item_name, hr_p in zip(item_names, img_sr):
295
+ item_name = os.path.splitext(item_name)[0]
296
+ hr_p = Image.fromarray(hr_p)
297
+ hr_p.save(f"{gen_dir}/SR/{item_name}.png")
298
+
299
+ exp_name = hparams['work_dir'].split('/')[-1]
300
+ sr_img_dir = f"{gen_dir}/SR/"
301
+ gt_img_dir = f"{gt_path}/{data_name}/HR"
302
+ excel_path = f"{hparams['work_dir']}/IQA-val-benchmark_loop-{exp_name}.xlsx"
303
+ epoch = training_step
304
+ eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)
305
+
306
+ # utils_sr
307
+ def log_metrics(self, metrics, step):
308
+ metrics = self.metrics_to_scalars(metrics)
309
+ logger = self.logger
310
+ for k, v in metrics.items():
311
+ if isinstance(v, torch.Tensor):
312
+ v = v.item()
313
+ logger.add_scalar(k, v, step)
314
+
315
+ def metrics_to_scalars(self, metrics):
316
+ new_metrics = {}
317
+ for k, v in metrics.items():
318
+ if isinstance(v, torch.Tensor):
319
+ v = v.item()
320
+
321
+ if type(v) is dict:
322
+ v = self.metrics_to_scalars(v)
323
+
324
+ new_metrics[k] = v
325
+
326
+ return new_metrics
327
+
328
+ @staticmethod
329
+ def tensor2img(img):
330
+ img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1) * 127.5)
331
+ img = img.clip(min=0, max=255).astype(np.uint8)
332
+ return img
333
+
334
+
335
+ if __name__ == '__main__':
336
+ set_hparams()
337
+
338
+ pkg = ".".join(hparams["trainer_cls"].split(".")[:-1])
339
+ cls_name = hparams["trainer_cls"].split(".")[-1]
340
+ trainer = getattr(importlib.import_module(pkg), cls_name)()
341
+ if hparams['benchmark_loop']:
342
+ trainer.benchmark_loop(hparams['benchmark_name_list'], hparams['metric_list'], hparams['gt_img_path'])
343
+ elif hparams['benchmark']:
344
+ trainer.benchmark(hparams['benchmark_name_list'], hparams['metric_list'])
345
+ else:
346
+ trainer.train()
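
tensor2img above maps model outputs from the [-1, 1] training range back to uint8 HWC images. A tiny sketch of that conversion on a dummy batch:

import numpy as np
import torch

x = torch.full((1, 3, 2, 2), 0.5)  # dummy NCHW output in [-1, 1]
img = np.round((x.permute(0, 2, 3, 1).cpu().numpy() + 1) * 127.5)
img = img.clip(min=0, max=255).astype(np.uint8)
print(img.shape, int(img.max()))   # (1, 2, 2, 3) 191
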
sam_diffsr/tb_logs/events.out.tfevents.1709283169.wangchengchengdeMacBook-Pro.local.99018.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aba4ab7fc71e002fcd70117a9bb9ad042341fc80f7d51f5bd4bd9610c508a655
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709284054.wangchengchengdeMacBook-Pro.local.99188.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13225cd295f9d05736be890cd9b70fdf332846531c35760d37b7aae5ca02e584
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709284076.wangchengchengdeMacBook-Pro.local.99198.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf4eadf9f990294906e556c51d92636c7af10c6aec626003341eb0d7b3f5bc38
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709284101.wangchengchengdeMacBook-Pro.local.99211.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:687ac32fefdcbdb917bfa7c9197f8e64b050a506f9676d365f23484183da5629
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709284193.wangchengchengdeMacBook-Pro.local.99233.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fd3acbfec3acc9d5d582e81cd6819b2e5512f45bbf85918b89e6ea371ffbdf2
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709284415.wangchengchengdeMacBook-Pro.local.99289.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5867fd2539cfbb896867c770c5c7c0f5d9687e29ea0322646fe820dae1cae08c
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709284460.wangchengchengdeMacBook-Pro.local.99308.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7999747d297acf0c198f9b8739134ee8ba7d4bc4fae209e47fc4bbd09a7206c
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709284491.wangchengchengdeMacBook-Pro.local.99315.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:172e36eb89b0238ad5cdac335de91bc2721a1944c5e9807027a6c3eeea64f918
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709285127.wangchengchengdeMacBook-Pro.local.785.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f51161b5901b64dd694dd06243d6143281b4df0f839ec89601407aad680cfc34
3
+ size 88
sam_diffsr/tb_logs/events.out.tfevents.1709285146.wangchengchengdeMacBook-Pro.local.901.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:776fa889c3c898a1b157f45aed579d9976791c718b507508421100c03baeb401
3
+ size 88
sam_diffsr/tools/caculate_iqa.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+ import ssl
3
+ from os.path import join
4
+ from pathlib import Path
5
+ from statistics import mean
6
+
7
+ parent_path = Path(__file__).absolute().parent.parent
8
+ parent_path = os.path.abspath(parent_path)
9
+
10
+ os.environ["CURL_CA_BUNDLE"] = ""
11
+ ssl._create_default_https_context = ssl._create_unverified_context
12
+
13
+ cache_path = os.path.join(parent_path, 'cache')
14
+ os.environ["HF_DATASETS_CACHE"] = cache_path
15
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
16
+ os.environ["torch_HOME"] = cache_path
17
+
18
+ import PIL
19
+ import numpy as np
20
+ import pandas as pd
21
+ import pyiqa
22
+ import torch
23
+ from PIL import Image
24
+ from tqdm import tqdm
25
+
26
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
27
+
28
+ metric_dict = {
29
+ 'psnr-Y': pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr'),
30
+ 'ssim': pyiqa.create_metric('ssim', color_space='ycbcr'),
31
+ 'fid': pyiqa.create_metric('fid'),
32
+ }
33
+
34
+
35
+ def load_img(path, target_size=None):
36
+ image = Image.open(path).convert("RGB")
37
+ if target_size:
38
+ h, w = target_size
39
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
40
+ image = np.array(image).astype(np.float32) / 255.0
41
+ image = image[None].transpose(0, 3, 1, 2)
42
+ image = torch.from_numpy(image)
43
+ return image
44
+
45
+
46
+ def eval_img_IQA(gt_dir, sr_dir, excel_path, metric_list, exp_name, data_name):
47
+ gt_img_list = os.listdir(gt_dir)
48
+
49
+ iqa_result = {}
50
+
51
+ for metric in metric_list:
52
+ iqa_metric = metric_dict[metric].to(device)
53
+ score_fr_list = []
54
+
55
+ if metric == 'fid':
56
+ score_fr = iqa_metric(sr_dir, gt_dir)
57
+ iqa_result[metric] = float(score_fr)
58
+ print(f'{metric}: {float(score_fr)}')
59
+ else:
60
+ for img_name in tqdm(gt_img_list):
61
+ base_name = img_name.split('.')[0]
62
+ sr_img_name = f'{base_name}.png'
63
+ gt_img_path = join(gt_dir, img_name)
64
+ sr_img_path = join(sr_dir, sr_img_name)
65
+
66
+ if not os.path.exists(sr_img_path):
67
+ print(f'File not exist: {sr_img_path}')
68
+ continue
69
+
70
+ gt_img = load_img(gt_img_path, target_size=None)
71
+ target_size = gt_img.shape[2:]
72
+ sr_img = load_img(sr_img_path, target_size=target_size)
73
+
74
+ score_fr = iqa_metric(sr_img, gt_img)
75
+
76
+ if score_fr.shape == (1,):
77
+ score_fr = score_fr[0]
78
+ if isinstance(score_fr, torch.Tensor):
79
+ score_fr = float(score_fr.cpu().numpy())
80
+ else:
81
+ score_fr = float(score_fr)
82
+ score_fr_list.append(score_fr)
83
+
84
+ mean_score = mean(score_fr_list)
85
+ iqa_result[metric] = float(mean_score)
86
+ print(f'{metric}: {mean_score}')
87
+
88
+ if os.path.exists(excel_path):
89
+ df = pd.read_excel(excel_path)
90
+ else:
91
+ df = pd.DataFrame(columns=['exp'])
92
+
93
+ new_index = len(df.index)
94
+
95
+ exp_name = int(exp_name)
96
+ if exp_name in df['exp'].to_list():
97
+ new_index = df[df['exp'] == exp_name].index.tolist()[0]
98
+ else:
99
+ df.loc[new_index, 'exp'] = exp_name
100
+
101
+ for index, metric in enumerate(metric_list):
102
+ df_metric = f'{data_name}-{metric}'
103
+ if df_metric not in df.columns.tolist():
104
+ df[df_metric] = ''
105
+
106
+ df.loc[new_index, df_metric] = iqa_result[metric]
107
+
108
+ df.sort_values(by='exp', inplace=True)
109
+
110
+ df.to_excel(excel_path, startcol=0, index=False)
111
+
112
+
113
+ def main():
114
+ epoch = 400000
115
+ add_name = ''
116
+ exp_root = '/home/ma-user/work/code/SRDiff-main/checkpoints'
117
+
118
+ model_type_list = ['diffsr_df2k4x_sam-pl_qs-zero']
119
+
120
+ metric_list = ['psnr-Y', 'ssim', 'fid']
121
+ benchmark_name_list = ['test_Set5', 'test_Set14', 'test_Urban100', 'test_Manga109', 'test_BSDS100']
122
+
123
+ # if benchmark:
124
+ for model_type in model_type_list:
125
+ excel_path = join(exp_root, model_type, f'IQA-val-{model_type}.xls')
126
+ for benchmark_name in benchmark_name_list:
127
+ exp_dir = join(exp_root, f'{model_type}/results_{epoch}_{add_name}/benchmark/{benchmark_name}')
128
+ gt_img_dir = join(exp_dir, 'HR')
129
+ sr_img_dir = join(exp_dir, 'SR')
130
+
131
+ data_name = benchmark_name[5:]
132
+ eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)
133
+
134
+
135
+ if __name__ == '__main__':
136
+ main()
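
eval_img_IQA scores SR/GT pairs with pyiqa metrics applied to NCHW tensors in [0, 1]. A hedged sketch of scoring a single pair the same way, assuming pyiqa is installed; the random tensors stand in for load_img outputs:

import pyiqa
import torch

psnr_y = pyiqa.create_metric('psnr', test_y_channel=True, color_space='ycbcr')
sr = torch.rand(1, 3, 128, 128)  # stand-in for load_img(sr_img_path, ...)
gt = torch.rand(1, 3, 128, 128)  # stand-in for load_img(gt_img_path)
print(float(psnr_y(sr, gt)))
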
sam_diffsr/tools/visualize_sam_mask.py ADDED
@@ -0,0 +1,20 @@
1
+ import glob
2
+ import os
3
+
4
+ import numpy as np
5
+ from matplotlib import pyplot as plt
6
+ from tqdm import tqdm
7
+
8
+ num = '0824'
9
+
10
+ sam_npy = '/home/ma-user/work/data/sr_sam/merge_RoPE/DF2K/DF2K_train_HR'
11
+ save_dir = '/home/ma-user/work/data/sr_sam/merge_RoPE/vis/DF2K/DF2K_train_HR'
12
+
13
+ os.makedirs(save_dir, exist_ok=True)
14
+
15
+ for file in tqdm(glob.glob(f'{sam_npy}/*.npy')):
16
+ name = os.path.basename(file).split('.')[0]
17
+ save_path = os.path.join(save_dir, f'{name}.png')
18
+ img = np.load(file)
19
+ plt.imshow(img)
20
+ plt.savefig(save_path)
sam_diffsr/utils_sr/__init__.py ADDED
File without changes
sam_diffsr/utils_sr/dataset.py ADDED
@@ -0,0 +1,50 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ from torch.utils.data import Dataset
4
+ from torchvision import transforms
5
+
6
+ from .hparams import hparams
7
+ from .indexed_datasets import IndexedDataset
8
+ from .matlab_resize import imresize
9
+
10
+
11
+ class SRDataSet(Dataset):
12
+ def __init__(self, prefix='train'):
13
+ self.hparams = hparams
14
+ self.data_dir = hparams['binary_data_dir']
15
+ self.prefix = prefix
16
+ self.len = len(IndexedDataset(f'{self.data_dir}/{self.prefix}'))
17
+ self.to_tensor_norm = transforms.Compose([
18
+ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
19
+ ])
20
+ assert hparams['data_interp'] in ['bilinear', 'bicubic']
21
+ self.data_augmentation = hparams['data_augmentation']
22
+ self.indexed_ds = None
23
+ if self.prefix == 'valid':
24
+ self.len = hparams['eval_batch_size'] * hparams['valid_steps']
25
+
26
+ def _get_item(self, index):
27
+ if self.indexed_ds is None:
28
+ self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
29
+ return self.indexed_ds[index]
30
+
31
+ def __getitem__(self, index):
32
+ item = self._get_item(index)
33
+ hparams = self.hparams
34
+ img_hr = item['img']
35
+ img_hr = Image.fromarray(np.uint8(img_hr))
36
+ img_hr = self.pre_process(img_hr) # PIL
37
+ img_hr = np.asarray(img_hr) # np.uint8 [H, W, C]
38
+ img_lr = imresize(img_hr, 1 / hparams['sr_scale'], method=hparams['data_interp']) # np.uint8 [H, W, C]
39
+ img_lr_up = imresize(img_lr / 256, hparams['sr_scale']) # np.float [H, W, C]
40
+ img_hr, img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_hr, img_lr, img_lr_up]]
41
+ return {
42
+ 'img_hr': img_hr, 'img_lr': img_lr, 'img_lr_up': img_lr_up,
43
+ 'item_name': item['item_name']
44
+ }
45
+
46
+ def pre_process(self, img_hr):
47
+ return img_hr
48
+
49
+ def __len__(self):
50
+ return self.len
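
to_tensor_norm above turns uint8 HWC images in [0, 255] into float CHW tensors in [-1, 1]. A minimal check on a random patch:

import numpy as np
from torchvision import transforms

to_tensor_norm = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # dummy HR patch
t = to_tensor_norm(img).float()
print(t.shape, float(t.min()), float(t.max()))  # torch.Size([3, 32, 32]), values within [-1, 1]
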
sam_diffsr/utils_sr/hparams.py ADDED
@@ -0,0 +1,157 @@
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import yaml
6
+
7
+ global_print_hparams = True
8
+ hparams = {}
9
+
10
+
11
+ class Args:
12
+ def __init__(self, **kwargs):
13
+ for k, v in kwargs.items():
14
+ self.__setattr__(k, v)
15
+
16
+
17
+ def override_config(old_config: dict, new_config: dict):
18
+ for k, v in new_config.items():
19
+ if isinstance(v, dict) and k in old_config:
20
+ override_config(old_config[k], new_config[k])
21
+ else:
22
+ old_config[k] = v
23
+
24
+
25
+ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
26
+ parent_path = Path(__file__).absolute().parent.parent
27
+ fill_root = os.path.abspath(parent_path)
28
+
29
+ if config == '' and exp_name == '':
30
+ parser = argparse.ArgumentParser(description='')
31
+ parser.add_argument('--config', type=str, default=os.path.join(fill_root, 'configs/sam/sam_diffsr_df2k4x.yaml'),
32
+ help='location of the data corpus')
33
+ parser.add_argument('--exp_name', type=str, default='', help='exp_name')
34
+ parser.add_argument('--work_dir', type=str, default='', help='work dir')
35
+ parser.add_argument('--gt_img_path', type=str, default='data/sr_diff/benchmark', help='gt_img_path')
36
+ parser.add_argument('-hp', '--hparams', type=str, default='',
37
+ help='override hparams from the command line, e.g. "a=1,b.c=2"')
38
+ parser.add_argument('--infer', action='store_true', help='infer')
39
+ parser.add_argument('--benchmark', action='store_true', help='test benchmark')
40
+ parser.add_argument('--benchmark_loop', action='store_true', help='loop test benchmark for all checkpoint')
41
+ parser.add_argument('--benchmark_name_list', nargs='+',
42
+ default=['test_Set5', 'test_Set14', 'test_Urban100', 'test_Manga109', 'test_BSDS100'])
43
+ parser.add_argument('--metric_list', nargs='+', default=['psnr-Y', 'ssim', 'fid'])
44
+ parser.add_argument('--validate', action='store_true', help='validate')
45
+ parser.add_argument('--val_steps', type=int, default=None, help='validate steps')
46
+ parser.add_argument('--reset', action='store_true', help='reset hparams')
47
+ parser.add_argument('--debug', action='store_true', help='debug')
48
+
49
+ parser.add_argument('--img_dir', type=str, default='', help='infer input image dir')
50
+ parser.add_argument('--save_dir', type=str, default='', help='infer output image dir')
51
+ parser.add_argument('--ckpt_path', type=str, default='', help='infer ckpt path')
52
+
53
+
54
+ args, unknown = parser.parse_known_args()
55
+ print("| Unknow hparams: ", unknown)
56
+ else:
57
+ args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
58
+ infer=False, validate=False, reset=False, debug=False)
59
+ global hparams
60
+ assert args.config != '' or args.exp_name != ''
61
+ if args.config != '':
62
+ assert os.path.exists(args.config)
63
+
64
+ config_chains = []
65
+ loaded_config = set()
66
+
67
+ def load_config(config_fn):
68
+ # deep first inheritance and avoid the second visit of one node
69
+ if not os.path.exists(config_fn):
70
+ return {}
71
+ with open(config_fn) as f:
72
+ hparams_ = yaml.safe_load(f)
73
+ loaded_config.add(config_fn)
74
+ if 'base_config' in hparams_:
75
+ ret_hparams = {}
76
+ if not isinstance(hparams_['base_config'], list):
77
+ hparams_['base_config'] = [hparams_['base_config']]
78
+ for c in hparams_['base_config']:
79
+ if c.startswith('.'):
80
+ c = f'{os.path.dirname(config_fn)}/{c}'
81
+ c = os.path.normpath(c)
82
+ if c not in loaded_config:
83
+ override_config(ret_hparams, load_config(c))
84
+ override_config(ret_hparams, hparams_)
85
+ else:
86
+ ret_hparams = hparams_
87
+ config_chains.append(config_fn)
88
+ return ret_hparams
89
+
90
+ saved_hparams = {}
91
+ args_work_dir = ''
92
+ if args.exp_name != '':
93
+ args_work_dir = os.path.join(args.work_dir, 'checkpoints', args.exp_name)
94
+ ckpt_config_path = f'{args_work_dir}/config.yaml'
95
+ if os.path.exists(ckpt_config_path):
96
+ with open(ckpt_config_path) as f:
97
+ saved_hparams_ = yaml.safe_load(f)
98
+ if saved_hparams_ is not None:
99
+ saved_hparams.update(saved_hparams_)
100
+ hparams_ = {}
101
+ if args.config != '':
102
+ hparams_.update(load_config(args.config))
103
+ if not args.reset:
104
+ hparams_.update(saved_hparams)
105
+ hparams_['work_dir'] = args_work_dir
106
+
107
+ # Support config overriding in command line. Support list type config overriding.
108
+ # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
109
+ if args.hparams != "":
110
+ for new_hparam in args.hparams.split(","):
111
+ k, v = new_hparam.split("=")
112
+ v = v.strip("\'\" ")
113
+ config_node = hparams_
114
+ for k_ in k.split(".")[:-1]:
115
+ config_node = config_node[k_]
116
+ k = k.split(".")[-1]
117
+ if k not in config_node:
118
+ config_node[k] = v
119
+
120
+ elif v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
121
+ if type(config_node[k]) == list:
122
+ v = v.replace(" ", ",")
123
+ config_node[k] = eval(v)
124
+ else:
125
+ config_node[k] = type(config_node[k])(v)
126
+ if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
127
+ os.makedirs(hparams_['work_dir'], exist_ok=True)
128
+ with open(ckpt_config_path, 'w') as f:
129
+ yaml.safe_dump(hparams_, f)
130
+
131
+ hparams_['infer'] = args.infer
132
+ hparams_['debug'] = args.debug
133
+ hparams_['validate'] = args.validate
134
+ hparams_['exp_name'] = args.exp_name
135
+ hparams_['val_steps'] = args.val_steps
136
+ hparams_['benchmark'] = args.benchmark
137
+ hparams_['benchmark_loop'] = args.benchmark_loop
138
+ hparams_['benchmark_name_list'] = args.benchmark_name_list
139
+ hparams_['gt_img_path'] = args.gt_img_path
140
+ hparams_['metric_list'] = args.metric_list
141
+
142
+ hparams_['img_dir'] = args.img_dir
143
+ hparams_['save_dir'] = args.save_dir
144
+ hparams_['ckpt_path'] = args.ckpt_path
145
+
146
+ global global_print_hparams
147
+ if global_hparams:
148
+ hparams.clear()
149
+ hparams.update(hparams_)
150
+ if print_hparams and global_print_hparams and global_hparams:
151
+ print('| Hparams chains: ', config_chains)
152
+ print('| Hparams: ')
153
+ for i, (k, v) in enumerate(sorted(hparams_.items())):
154
+ print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
155
+ print("")
156
+ global_print_hparams = False
157
+ return hparams_
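
The --hparams="a=1,b.c=2" override walks dotted keys into the nested config and casts each value to the type already stored there. A simplified sketch of that parsing (it omits the list handling and the new-key branch of the real code; the config values are made up):

config = {'lr': 2e-4, 'sam_config': {'mask_RoPE': False}}
override = "sam_config.mask_RoPE=True,lr=1e-4"

for new_hparam in override.split(","):
    k, v = new_hparam.split("=")
    node = config
    for k_ in k.split(".")[:-1]:
        node = node[k_]
    k = k.split(".")[-1]
    node[k] = eval(v) if v in ['True', 'False'] else type(node[k])(v)

print(config)  # {'lr': 0.0001, 'sam_config': {'mask_RoPE': True}}
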
sam_diffsr/utils_sr/indexed_datasets.py ADDED
@@ -0,0 +1,72 @@
1
+ import pickle
2
+
3
+ import numpy as np
4
+
5
+
6
+ class IndexedDataset:
7
+ def __init__(self, path):
8
+ super().__init__()
9
+ self.path = path
10
+ self.data_file = None
11
+ index_data = np.load(f"{path}.idx", allow_pickle=True).item()
12
+ self.byte_offsets = index_data['offsets']
13
+ self.id2pos = index_data.get('id2pos', {})
14
+ self.data_file = open(f"{path}.data", 'rb', buffering=-1)
15
+
16
+ def check_index(self, i):
17
+ if i < 0 or i >= len(self.byte_offsets) - 1:
18
+ raise IndexError('index out of range')
19
+
20
+ def __del__(self):
21
+ if self.data_file:
22
+ self.data_file.close()
23
+
24
+ def __getitem__(self, i):
25
+ if self.id2pos is not None and len(self.id2pos) > 0:
26
+ i = self.id2pos[i]
27
+ self.check_index(i)
28
+ self.data_file.seek(self.byte_offsets[i])
29
+ b = self.data_file.read(self.byte_offsets[i + 1] - self.byte_offsets[i])
30
+ item = pickle.loads(b)
31
+ return item
32
+
33
+ def __len__(self):
34
+ return len(self.byte_offsets) - 1
35
+
36
+ def __iter__(self):
37
+ self.iter_i = 0
38
+ return self
39
+
40
+ def __next__(self):
41
+ if self.iter_i == len(self):
42
+ raise StopIteration
43
+ else:
44
+ item = self[self.iter_i]
45
+ self.iter_i += 1
46
+ return item
47
+
48
+
49
+ class IndexedDatasetBuilder:
50
+ def __init__(self, path, append=False):
51
+ self.path = path
52
+ if append:
53
+ self.data_file = open(f"{path}.data", 'ab')
54
+ index_data = np.load(f"{path}.idx", allow_pickle=True).item()
55
+ self.byte_offsets = index_data['offsets']
56
+ self.id2pos = index_data.get('id2pos', {})
57
+ else:
58
+ self.data_file = open(f"{path}.data", 'wb')
59
+ self.byte_offsets = [0]
60
+ self.id2pos = {}
61
+
62
+ def add_item(self, item, id=None):
63
+ s = pickle.dumps(item)
64
+ bytes = self.data_file.write(s)
65
+ if id is not None:
66
+ self.id2pos[id] = len(self.byte_offsets) - 1
67
+ self.byte_offsets.append(self.byte_offsets[-1] + bytes)
68
+
69
+ def finalize(self):
70
+ self.data_file.close()
71
+ np.save(open(f"{self.path}.idx", 'wb'),
72
+ {'offsets': self.byte_offsets, 'id2pos': self.id2pos})
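
A write-then-read round trip for the builder/reader pair above, assuming both classes are importable in the current scope and using a throwaway temporary path:

import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), 'train')

builder = IndexedDatasetBuilder(path)
builder.add_item({'item_name': '0001', 'img': [1, 2, 3]})
builder.add_item({'item_name': '0002', 'img': [4, 5, 6]})
builder.finalize()

ds = IndexedDataset(path)
print(len(ds), ds[0]['item_name'])  # 2 '0001'
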
sam_diffsr/utils_sr/matlab_resize.py ADDED
@@ -0,0 +1,181 @@
1
+ # https://github.com/fatheral/matlab_imresize
2
+ #
3
+ # MIT License
4
+ #
5
+ # Copyright (c) 2020 Alex
6
+ #
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+ #
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+ #
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ # SOFTWARE.
24
+
25
+
26
+ from __future__ import print_function
27
+
28
+ import numpy as np
29
+ from math import ceil
30
+
31
+
32
+ def deriveSizeFromScale(img_shape, scale):
33
+ output_shape = []
34
+ for k in range(2):
35
+ output_shape.append(int(ceil(scale[k] * img_shape[k])))
36
+ return output_shape
37
+
38
+
39
+ def deriveScaleFromSize(img_shape_in, img_shape_out):
40
+ scale = []
41
+ for k in range(2):
42
+ scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
43
+ return scale
44
+
45
+
46
+ def triangle(x):
47
+ x = np.array(x).astype(np.float64)
48
+ lessthanzero = np.logical_and((x >= -1), x < 0)
49
+ greaterthanzero = np.logical_and((x <= 1), x >= 0)
50
+ f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero)
51
+ return f
52
+
53
+
54
+ def cubic(x):
55
+ x = np.array(x).astype(np.float64)
56
+ absx = np.absolute(x)
57
+ absx2 = np.multiply(absx, absx)
58
+ absx3 = np.multiply(absx2, absx)
59
+ f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2,
60
+ (1 < absx) & (absx <= 2))
61
+ return f
62
+
63
+
64
+ def contributions(in_length, out_length, scale, kernel, k_width):
65
+ if scale < 1:
66
+ h = lambda x: scale * kernel(scale * x)
67
+ kernel_width = 1.0 * k_width / scale
68
+ else:
69
+ h = kernel
70
+ kernel_width = k_width
71
+ x = np.arange(1, out_length + 1).astype(np.float64)
72
+ u = x / scale + 0.5 * (1 - 1 / scale)
73
+ left = np.floor(u - kernel_width / 2)
74
+ P = int(ceil(kernel_width)) + 2
75
+ ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0
76
+ indices = ind.astype(np.int32)
77
+ weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0
78
+ weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
79
+ aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
80
+ indices = aux[np.mod(indices, aux.size)]
81
+ ind2store = np.nonzero(np.any(weights, axis=0))
82
+ weights = weights[:, ind2store]
83
+ indices = indices[:, ind2store]
84
+ return weights, indices
85
+
86
+
87
+ def imresizemex(inimg, weights, indices, dim):
88
+ in_shape = inimg.shape
89
+ w_shape = weights.shape
90
+ out_shape = list(in_shape)
91
+ out_shape[dim] = w_shape[0]
92
+ outimg = np.zeros(out_shape)
93
+ if dim == 0:
94
+ for i_img in range(in_shape[1]):
95
+ for i_w in range(w_shape[0]):
96
+ w = weights[i_w, :]
97
+ ind = indices[i_w, :]
98
+ im_slice = inimg[ind, i_img].astype(np.float64)
99
+ outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
100
+ elif dim == 1:
101
+ for i_img in range(in_shape[0]):
102
+ for i_w in range(w_shape[0]):
103
+ w = weights[i_w, :]
104
+ ind = indices[i_w, :]
105
+ im_slice = inimg[i_img, ind].astype(np.float64)
106
+ outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
107
+ if inimg.dtype == np.uint8:
108
+ outimg = np.clip(outimg, 0, 255)
109
+ return np.around(outimg).astype(np.uint8)
110
+ else:
111
+ return outimg
112
+
113
+
114
+ def imresizevec(inimg, weights, indices, dim):
115
+ wshape = weights.shape
116
+ if dim == 0:
117
+ weights = weights.reshape((wshape[0], wshape[2], 1, 1))
118
+ outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
119
+ elif dim == 1:
120
+ weights = weights.reshape((1, wshape[0], wshape[2], 1))
121
+ outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
122
+ if inimg.dtype == np.uint8:
123
+ outimg = np.clip(outimg, 0, 255)
124
+ return np.around(outimg).astype(np.uint8)
125
+ else:
126
+ return outimg
127
+
128
+
129
+ def resizeAlongDim(A, dim, weights, indices, mode="vec"):
130
+ if mode == "org":
131
+ out = imresizemex(A, weights, indices, dim)
132
+ else:
133
+ out = imresizevec(A, weights, indices, dim)
134
+ return out
135
+
136
+
137
+ def imresize(I, scale=None, method='bicubic', sizes=None, mode="vec"):
138
+ if method == 'bicubic':
139
+ kernel = cubic
140
+ elif method == 'bilinear':
141
+ kernel = triangle
142
+ else:
143
+ print('Error: Unidentified method supplied')
144
+
145
+ kernel_width = 4.0
146
+ # Fill scale and output_size
147
+ if scale is not None:
148
+ scale = float(scale)
149
+ scale = [scale, scale]
150
+ output_size = deriveSizeFromScale(I.shape, scale)
151
+ elif sizes is not None:
152
+ scale = deriveScaleFromSize(I.shape, sizes)
153
+ output_size = list(sizes)
154
+ else:
155
+ print('Error: scalar_scale OR output_shape should be defined!')
156
+ return
157
+ scale_np = np.array(scale)
158
+ order = np.argsort(scale_np)
159
+ weights = []
160
+ indices = []
161
+ for k in range(2):
162
+ w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
163
+ weights.append(w)
164
+ indices.append(ind)
165
+ B = np.copy(I)
166
+ flag2D = False
167
+ if B.ndim == 2:
168
+ B = np.expand_dims(B, axis=2)
169
+ flag2D = True
170
+ for k in range(2):
171
+ dim = order[k]
172
+ B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
173
+ if flag2D:
174
+ B = np.squeeze(B, axis=2)
175
+ return B
176
+
177
+
178
+ def convertDouble2Byte(I):
179
+ B = np.clip(I, 0.0, 1.0)
180
+ B = 255 * B
181
+ return np.around(B).astype(np.uint8)
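
A quick shape check of imresize above, mirroring how the datasets use it (uint8 input for downscaling, roughly [0, 1) float input for the bicubic upsample):

import numpy as np

img = np.random.randint(0, 256, size=(64, 48, 3), dtype=np.uint8)  # dummy HR image
lr = imresize(img, 1 / 4)   # uint8 output, shape (16, 12, 3)
up = imresize(lr / 256, 4)  # float output, shape (64, 48, 3)
print(lr.shape, lr.dtype, up.shape, up.dtype)
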
sam_diffsr/utils_sr/plt_img.py ADDED
@@ -0,0 +1,109 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from matplotlib import pyplot as plt
5
+ from torchvision.utils import make_grid
6
+
7
+
8
+ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
9
+ """Convert torch Tensors into image numpy arrays.
10
+
11
+ After clamping to (min, max), image values will be normalized to [0, 1].
12
+
13
+ For different tensor shapes, this function will have different behaviors:
14
+
15
+ 1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
16
+ Use `make_grid` to stitch images in the batch dimension, and then
17
+ convert it to numpy array.
18
+ 2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
19
+ Directly change to numpy array.
20
+
21
+ Note that the image channel in input tensors should be RGB order. This
22
+ function will convert it to cv2 convention, i.e., (H x W x C) with BGR
23
+ order.
24
+
25
+ Args:
26
+ tensor (Tensor | list[Tensor]): Input tensors.
27
+ out_type (numpy type): Output types. If ``np.uint8``, transform outputs
28
+ to uint8 type with range [0, 255]; otherwise, float type with
29
+ range [0, 1]. Default: ``np.uint8``.
30
+ min_max (tuple): min and max values for clamp.
31
+
32
+ Returns:
33
+ (Tensor | list[Tensor]): 3D ndarray of shape (H x W x C) or 2D ndarray
34
+ of shape (H x W).
35
+ """
36
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
37
+ raise TypeError(
38
+ f'tensor or list of tensors expected, got {type(tensor)}')
39
+
40
+ if torch.is_tensor(tensor):
41
+ tensor = [tensor]
42
+ result = []
43
+ for _tensor in tensor:
44
+ # Squeeze two times so that:
45
+ # 1. (1, 1, h, w) -> (h, w) or
46
+ # 3. (1, 3, h, w) -> (3, h, w) or
47
+ # 2. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
48
+ _tensor = _tensor.squeeze(0).squeeze(0)
49
+ _tensor = _tensor.float().detach().cpu().clamp_(*min_max)
50
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
51
+ n_dim = _tensor.dim()
52
+ if n_dim == 4:
53
+ img_np = make_grid(
54
+ _tensor, nrow=int(math.sqrt(_tensor.size(0))),
55
+ normalize=False).numpy()
56
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
57
+ elif n_dim == 3:
58
+ img_np = _tensor.numpy()
59
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
60
+ elif n_dim == 2:
61
+ img_np = _tensor.numpy()
62
+ else:
63
+ raise ValueError('Only support 4D, 3D or 2D tensor. '
64
+ f'But received with dimension: {n_dim}')
65
+ if out_type == np.uint8:
66
+ # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
67
+ img_np = (img_np * 255.0).round()
68
+ img_np = img_np.astype(out_type)
69
+ result.append(img_np)
70
+ result = result[0] if len(result) == 1 else result
71
+ return result
72
+
73
+
74
+ def plt_tensor_img(tensor, save_path=None):
75
+ plt.imshow(tensor2img(tensor))
76
+ plt.show()
77
+
78
+ if save_path:
79
+ plt.savefig(save_path)
80
+
81
+
82
+ def plt_tensor_img_one(tensor, t_dim=1):
83
+ if isinstance(tensor, list):
84
+ tensor = torch.cat(tensor, dim=t_dim)
85
+ nums = tensor.shape[t_dim]
86
+
87
+ mash = math.ceil(math.sqrt(nums))
88
+
89
+ plt.figure(dpi=300)
90
+ plt_range = min(nums, mash ** 2)
91
+ for i in range(plt_range):
92
+ plt.subplot(mash, mash, i + 1)
93
+ if t_dim == 1:
94
+ img = tensor2img(tensor[:, i, ...])
95
+ elif t_dim == 0:
96
+ img = tensor2img(tensor[i, ...])
97
+ plt.imshow(img)
98
+ plt.xticks([])
99
+ plt.yticks([])
100
+ plt.subplots_adjust(wspace=0, hspace=0)
101
+ plt.tight_layout()
102
+ plt.show()
103
+
104
+
105
+ def plt_img(img, save_path=None):
106
+ plt.imshow(img)
107
+ plt.show()
108
+ if save_path:
109
+ plt.savefig(save_path)
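
tensor2img in this module clamps to [0, 1], rescales to [0, 255], and swaps the channel order from RGB to BGR. A short sketch of calling it on a single CHW tensor, assuming the function above is in scope:

import torch

t = torch.rand(3, 8, 8)      # dummy CHW image in [0, 1]
img = tensor2img(t)
print(img.shape, img.dtype)  # (8, 8, 3) uint8, channels in BGR order
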
sam_diffsr/utils_sr/sr_utils.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ import torch.nn.functional as F
+ import torchvision
+ from torch.autograd import Variable
+ import numpy as np
+ from math import exp
+ import torch.nn as nn
+
+
+ class ImgMerger:
+     def __init__(self, eval_fn):
+         self.eval_fn = eval_fn
+         self.loc2imgs = {}
+         self.max_x = 0
+         self.max_y = 0
+         self.clear()
+
+     def clear(self):
+         self.loc2imgs = {}
+         self.max_x = 0
+         self.max_y = 0
+
+     def push(self, imgs, loc, loc_bdr):
+         """
+
+         Args:
+             imgs: each img is a [C, H, W] np.array, range: [0, 255]
+             loc: (x, y) tuple giving the tile location, e.g., (0, 0), (0, 1) ...
+             loc_bdr: (max_x, max_y) tuple giving the tile grid size
+         """
+         self.max_x, self.max_y = loc_bdr
+         x, y = loc
+         self.loc2imgs[f'{x},{y}'] = imgs
+         if len(self.loc2imgs) == self.max_x * self.max_y:
+             return self.compute()
+
+     def compute(self):
+         img_inputs = []
+         for i in range(len(self.loc2imgs['0,0'])):
+             img_full = []
+             for x in range(self.max_x):
+                 imgx = []
+                 for y in range(self.max_y):
+                     imgx.append(self.loc2imgs[f'{x},{y}'][i])
+                 img_full.append(np.concatenate(imgx, 2))
+             img_inputs.append(np.concatenate(img_full, 1))
+         self.clear()
+         return self.eval_fn(*img_inputs)
+
+
+ ##########
+ # SSIM
+ ##########
+ def gaussian(window_size, sigma):
+     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+     return gauss / gauss.sum()
+
+
+ def create_window(window_size, channel):
+     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+     return window
+
+
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
+     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+     mu1_sq = mu1.pow(2)
+     mu2_sq = mu2.pow(2)
+     mu1_mu2 = mu1 * mu2
+
+     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+     C1 = 0.01 ** 2
+     C2 = 0.03 ** 2
+
+     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+     if size_average:
+         return ssim_map.mean()
+     else:
+         return ssim_map.mean(1).mean(1).mean(1)
+
+
+ class SSIM(torch.nn.Module):
+     def __init__(self, window_size=11, size_average=True):
+         super(SSIM, self).__init__()
+         self.window_size = window_size
+         self.size_average = size_average
+         self.channel = 1
+         self.window = create_window(window_size, self.channel)
+
+     def forward(self, img1, img2):
+         img1 = img1 * 0.5 + 0.5
+         img2 = img2 * 0.5 + 0.5
+         (_, channel, _, _) = img1.size()
+
+         if channel == self.channel and self.window.data.type() == img1.data.type():
+             window = self.window
+         else:
+             window = create_window(self.window_size, channel)
+
+             if img1.is_cuda:
+                 window = window.cuda(img1.get_device())
+             window = window.type_as(img1)
+
+             self.window = window
+             self.channel = channel
+
+         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+
+ def ssim(img1, img2, window_size=11, size_average=True):
+     (_, channel, _, _) = img1.size()
+     window = create_window(window_size, channel)
+
+     if img1.is_cuda:
+         window = window.cuda(img1.get_device())
+     window = window.type_as(img1)
+
+     return _ssim(img1, img2, window, window_size, channel, size_average)
+
+
+ class VGGFeatureExtractor(nn.Module):
+     def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True):
+         super(VGGFeatureExtractor, self).__init__()
+         self.use_input_norm = use_input_norm
+         if use_bn:
+             model = torchvision.models.vgg19_bn(pretrained=True)
+         else:
+             model = torchvision.models.vgg19(pretrained=True)
+         if self.use_input_norm:
+             mean = torch.Tensor([0.485 - 1, 0.456 - 1, 0.406 - 1]).view(1, 3, 1, 1)
+             # mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+             # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input is in range [-1, 1]
+             std = torch.Tensor([0.229 * 2, 0.224 * 2, 0.225 * 2]).view(1, 3, 1, 1)
+             # std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+             # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input is in range [-1, 1]
+             self.register_buffer('mean', mean)
+             self.register_buffer('std', std)
+         self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
+         # No need to backpropagate through the feature extractor
+         for k, v in self.features.named_parameters():
+             v.requires_grad = False
+
+     def forward(self, x):
+         if self.use_input_norm:
+             x = (x - self.mean) / self.std
+         output = self.features(x)
+         return output
+
+
+ class PerceptualLoss(nn.Module):
+     def __init__(self):
+         super(PerceptualLoss, self).__init__()
+         loss_network = VGGFeatureExtractor()
+         for param in loss_network.parameters():
+             param.requires_grad = False
+         self.loss_network = loss_network
+         self.l1_loss = nn.L1Loss()
+
+     def forward(self, high_resolution, fake_high_resolution):
+         if next(self.loss_network.parameters()).device != high_resolution.device:
+             self.loss_network.to(high_resolution.device)
+             self.loss_network.eval()
+         perception_loss = self.l1_loss(self.loss_network(high_resolution), self.loss_network(fake_high_resolution))
+         return perception_loss
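
For reference, a minimal usage sketch for the SSIM and perceptual-loss helpers (not part of the commit; the random tensors stand in for SR output and ground truth, and both modules' normalization is configured for inputs in [-1, 1]):

    import torch

    from sam_diffsr.utils_sr.sr_utils import SSIM, PerceptualLoss

    hr = torch.rand(1, 3, 128, 128) * 2 - 1   # fake ground truth in [-1, 1]
    sr = torch.rand(1, 3, 128, 128) * 2 - 1   # fake SR output in [-1, 1]

    ssim_score = SSIM()(hr, sr)          # forward() rescales [-1, 1] -> [0, 1] before computing SSIM
    perc_loss = PerceptualLoss()(hr, sr)  # VGG19 feature L1; downloads torchvision weights on first use
    print(float(ssim_score), float(perc_loss))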
sam_diffsr/utils_sr/utils.py ADDED
@@ -0,0 +1,269 @@
+ import glob
+ import os
+ import re
+ import subprocess
+ from collections import OrderedDict
+
+ import lpips
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ from skimage.metrics import peak_signal_noise_ratio as psnr
+ from skimage.metrics import structural_similarity as ssim
+
+ from .matlab_resize import imresize
+
+
+ def reduce_tensors(metrics):
+     new_metrics = {}
+     for k, v in metrics.items():
+         if isinstance(v, torch.Tensor):
+             dist.all_reduce(v)
+             v = v / dist.get_world_size()
+         if type(v) is dict:
+             v = reduce_tensors(v)
+         new_metrics[k] = v
+     return new_metrics
+
+
+ def tensors_to_scalars(tensors):
+     if isinstance(tensors, torch.Tensor):
+         tensors = tensors.item()
+         return tensors
+     elif isinstance(tensors, dict):
+         new_tensors = {}
+         for k, v in tensors.items():
+             v = tensors_to_scalars(v)
+             new_tensors[k] = v
+         return new_tensors
+     elif isinstance(tensors, list):
+         return [tensors_to_scalars(v) for v in tensors]
+     else:
+         return tensors
+
+
+ def tensors_to_np(tensors):
+     if isinstance(tensors, dict):
+         new_np = {}
+         for k, v in tensors.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.cpu().numpy()
+             if type(v) is dict:
+                 v = tensors_to_np(v)
+             new_np[k] = v
+     elif isinstance(tensors, list):
+         new_np = []
+         for v in tensors:
+             if isinstance(v, torch.Tensor):
+                 v = v.cpu().numpy()
+             if type(v) is dict:
+                 v = tensors_to_np(v)
+             new_np.append(v)
+     elif isinstance(tensors, torch.Tensor):
+         v = tensors
+         if isinstance(v, torch.Tensor):
+             v = v.cpu().numpy()
+         if type(v) is dict:
+             v = tensors_to_np(v)
+         new_np = v
+     else:
+         raise Exception(f'tensors_to_np does not support type {type(tensors)}.')
+     return new_np
+
+
+ def move_to_cpu(tensors):
+     ret = {}
+     for k, v in tensors.items():
+         if isinstance(v, torch.Tensor):
+             v = v.cpu()
+         if type(v) is dict:
+             v = move_to_cpu(v)
+         ret[k] = v
+     return ret
+
+
+ def move_to_cuda(batch, gpu_id=0):
+     # base case: object can be directly moved using `cuda` or `to`
+     if callable(getattr(batch, 'cuda', None)):
+         return batch.cuda(gpu_id, non_blocking=True)
+     elif callable(getattr(batch, 'to', None)):
+         return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
+     elif isinstance(batch, list):
+         for i, x in enumerate(batch):
+             batch[i] = move_to_cuda(x, gpu_id)
+         return batch
+     elif isinstance(batch, tuple):
+         batch = list(batch)
+         for i, x in enumerate(batch):
+             batch[i] = move_to_cuda(x, gpu_id)
+         return tuple(batch)
+     elif isinstance(batch, dict):
+         for k, v in batch.items():
+             batch[k] = move_to_cuda(v, gpu_id)
+         return batch
+     return batch
+
+
+ def get_last_checkpoint(work_dir, steps=None):
+     checkpoint = None
+     last_ckpt_path = None
+     ckpt_paths = get_all_ckpts(work_dir, steps)
+     if len(ckpt_paths) > 0:
+         last_ckpt_path = ckpt_paths[0]
+         checkpoint = torch.load(last_ckpt_path, map_location='cpu')
+     return checkpoint, last_ckpt_path
+
+
+ def get_all_ckpts(work_dir, steps=None):
+     if steps is None:
+         ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
+     else:
+         ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
+     return sorted(glob.glob(ckpt_path_pattern),
+                   key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))
+
+
+ def load_checkpoint(model, optimizer, work_dir, steps=None):
+     checkpoint, last_ckpt_path = get_last_checkpoint(work_dir, steps)
+     print(f'loading checkpoint from: {last_ckpt_path}')
+     if checkpoint is not None:
+         stat_dict = checkpoint['state_dict']['model']
+
+         new_state_dict = OrderedDict()
+         for k, v in stat_dict.items():
+             if k[:7] == 'module.':
+                 k = k[7:]  # strip the `module.` prefix added by DataParallel
+             new_state_dict[k] = v
+
+         model.load_state_dict(new_state_dict)
+         model.cuda()
+         optimizer.load_state_dict(checkpoint['optimizer_states'][0])
+         training_step = checkpoint['global_step']
+         del checkpoint
+         torch.cuda.empty_cache()
+     else:
+         training_step = 0
+         model.cuda()
+     return training_step
+
+
+ def save_checkpoint(model, optimizer, work_dir, global_step, num_ckpt_keep):
+     ckpt_path = f'{work_dir}/model_ckpt_steps_{global_step}.ckpt'
+     print(f'Step@{global_step}: saving model to {ckpt_path}')
+     checkpoint = {'global_step': global_step}
+     optimizer_states = []
+     optimizer_states.append(optimizer.state_dict())
+     checkpoint['optimizer_states'] = optimizer_states
+     checkpoint['state_dict'] = {'model': model.state_dict()}
+     torch.save(checkpoint, ckpt_path, _use_new_zipfile_serialization=False)
+     for old_ckpt in get_all_ckpts(work_dir)[num_ckpt_keep:]:
+         remove_file(old_ckpt)
+         print(f'Delete ckpt: {os.path.basename(old_ckpt)}')
+
+
+ def remove_file(*fns):
+     for f in fns:
+         subprocess.check_call(f'rm -rf "{f}"', shell=True)
+
+
+ def plot_img(img):
+     img = img.data.cpu().numpy()
+     return np.clip(img, 0, 1)
+
+
+ def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
+     if os.path.isfile(ckpt_base_dir):
+         base_dir = os.path.dirname(ckpt_base_dir)
+         ckpt_path = ckpt_base_dir
+         checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
+     else:
+         base_dir = ckpt_base_dir
+         checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)
+     if checkpoint is not None:
+         state_dict = checkpoint["state_dict"]
+         if len([k for k in state_dict.keys() if '.' in k]) > 0:
+             state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
+                           if k.startswith(f'{model_name}.')}
+         else:
+             state_dict = state_dict[model_name]
+         if not strict:
+             cur_model_state_dict = cur_model.state_dict()
+             unmatched_keys = []
+             for key, param in state_dict.items():
+                 if key in cur_model_state_dict:
+                     new_param = cur_model_state_dict[key]
+                     if new_param.shape != param.shape:
+                         unmatched_keys.append(key)
+                         print("| Unmatched keys: ", key, new_param.shape, param.shape)
+             for key in unmatched_keys:
+                 del state_dict[key]
+         cur_model.load_state_dict(state_dict, strict=strict)
+         print(f"| load '{model_name}' from '{ckpt_path}'.")
+     else:
+         e_msg = f"| ckpt not found in {base_dir}."
+         if force:
+             assert False, e_msg
+         else:
+             print(e_msg)
+
+
+ class Measure:
+     def __init__(self, net='alex'):
+         self.model = lpips.LPIPS(net=net)
+
+     def measure(self, imgA, imgB, img_lr, sr_scale):
+         """
+
+         Args:
+             imgA: [C, H, W] uint8 or torch.FloatTensor in [-1, 1]
+             imgB: [C, H, W] uint8 or torch.FloatTensor in [-1, 1]
+             img_lr: [C, H, W] uint8 or torch.FloatTensor in [-1, 1]
+             sr_scale: super-resolution scale factor
+
+         Returns: dict of metrics
+         """
+         if isinstance(imgA, torch.Tensor):
+             imgA = np.round((imgA.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
+             imgB = np.round((imgB.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
+             img_lr = np.round((img_lr.cpu().numpy() + 1) * 127.5).clip(min=0, max=255).astype(np.uint8)
+         imgA = imgA.transpose(1, 2, 0)
+         imgA_lr = imresize(imgA, 1 / sr_scale)
+         imgB = imgB.transpose(1, 2, 0)
+         img_lr = img_lr.transpose(1, 2, 0)
+         psnr = self.psnr(imgA, imgB)
+         ssim = self.ssim(imgA, imgB)
+         lpips = self.lpips(imgA, imgB)
+         lr_psnr = self.psnr(imgA_lr, img_lr)
+         res = {'psnr': psnr, 'ssim': ssim, 'lpips': lpips, 'lr_psnr': lr_psnr}
+         return {k: float(v) for k, v in res.items()}
+
+     def lpips(self, imgA, imgB, model=None):
+         device = next(self.model.parameters()).device
+         tA = t(imgA).to(device)
+         tB = t(imgB).to(device)
+         dist01 = self.model.forward(tA, tB).item()
+         return dist01
+
+     def ssim(self, imgA, imgB):
+         score, diff = ssim(imgA, imgB, full=True, channel_axis=2, data_range=255)
+         return score
+
+     def psnr(self, imgA, imgB):
+         return psnr(imgA, imgB, data_range=255)
+
+
+ def t(img):
+     def to_4d(img):
+         assert len(img.shape) == 3
+         img_new = np.expand_dims(img, axis=0)
+         assert len(img_new.shape) == 4
+         return img_new
+
+     def to_CHW(img):
+         return np.transpose(img, [2, 0, 1])
+
+     def to_tensor(img):
+         return torch.Tensor(img)
+
+     return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
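
For reference, a minimal usage sketch for the Measure helper (not part of the commit; the random uint8 arrays are stand-ins for SR output, ground truth, and the low-resolution input, and LPIPS weights are downloaded on first use):

    import numpy as np

    from sam_diffsr.utils_sr.utils import Measure

    measure = Measure(net='alex')

    hr = np.random.randint(0, 256, (3, 256, 256), dtype=np.uint8)  # fake ground truth, [C, H, W]
    sr = np.random.randint(0, 256, (3, 256, 256), dtype=np.uint8)  # fake SR output, [C, H, W]
    lr = np.random.randint(0, 256, (3, 64, 64), dtype=np.uint8)    # fake x4 low-resolution input, [C, H, W]

    metrics = measure.measure(sr, hr, lr, sr_scale=4)
    print(metrics)  # {'psnr': ..., 'ssim': ..., 'lpips': ..., 'lr_psnr': ...}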
sam_diffsr/weight/model_ckpt_steps_400000.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab89ee4160be868422459918eb69880042dc12544b1bf7807aa479c7eb329e55
+ size 204945145