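# (Residue from the web page this file was copied from: "Traly's picture",
#  commit 193c713 with message "init". Kept as a comment so the module parses.)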
import importlib
import json
import os
import subprocess
import sys
from collections import OrderedDict
from pathlib import Path
# Bootstrap: make the project root importable, use it as the working
# directory, and point the dataset/model caches at <project root>/cache.
# This runs before torch and the project modules are imported below, since
# some libraries consult these variables when they are first imported.
parent_path = Path(__file__).absolute().parent.parent
sys.path.append(os.path.abspath(parent_path))
os.chdir(parent_path)
print(f'>-------------> parent path {parent_path}')
print(f'>-------------> current work dir {os.getcwd()}')
cache_path = os.path.join(parent_path, 'cache')
os.environ["HF_DATASETS_CACHE"] = cache_path
os.environ["TRANSFORMERS_CACHE"] = cache_path
# torch looks up the upper-case TORCH_HOME; environment variable names are
# case-sensitive on POSIX, so the lower-case spelling alone had no effect
# there. The legacy key is still set so nothing relying on it changes.
os.environ["TORCH_HOME"] = cache_path
os.environ["torch_HOME"] = cache_path
import torch
from PIL import Image
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from sam_diffsr.utils_sr.hparams import hparams, set_hparams
from sam_diffsr.utils_sr.utils import plot_img, move_to_cuda, load_checkpoint, save_checkpoint, tensors_to_scalars, Measure, \
get_all_ckpts
class Trainer:
    """Base driver for training, validating and benchmarking an SR model.

    Subclasses set ``dataset_cls`` (called as ``dataset_cls(split_name)``)
    and implement ``build_model``, ``build_optimizer``, ``build_scheduler``,
    ``training_step`` and ``sample_and_test``. Configuration is read from the
    global ``hparams`` dict.
    """

    def __init__(self):
        self.logger = self.build_tensorboard(save_dir=hparams['work_dir'], name='tb_logs')
        self.measure = Measure()
        self.dataset_cls = None  # provided by subclasses
        self.metric_keys = ['psnr', 'ssim', 'lpips', 'lr_psnr']
        self.metric_2_keys = ['psnr-Y', 'ssim', 'fid']
        self.work_dir = hparams['work_dir']
        # The first validate() of a run is a short sanity pass whose metrics
        # are printed but not logged to tensorboard.
        self.first_val = True
        self.val_steps = hparams['val_steps']

    def build_tensorboard(self, save_dir, name, **kwargs):
        """Create ``save_dir/name`` if needed and return a SummaryWriter on it."""
        log_dir = os.path.join(save_dir, name)
        os.makedirs(log_dir, exist_ok=True)
        return SummaryWriter(log_dir=log_dir, **kwargs)

    def build_train_dataloader(self):
        """Shuffled DataLoader over the 'train' split."""
        dataset = self.dataset_cls('train')
        return torch.utils.data.DataLoader(
            dataset, batch_size=hparams['batch_size'], shuffle=True,
            pin_memory=False, num_workers=hparams['num_workers'])

    def build_val_dataloader(self):
        """Unshuffled DataLoader over the 'valid' split."""
        return torch.utils.data.DataLoader(
            self.dataset_cls('valid'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)

    def build_test_dataloader(self):
        """Unshuffled DataLoader over the 'test' split."""
        return torch.utils.data.DataLoader(
            self.dataset_cls('test'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)

    def build_model(self):
        """Build and return the model (subclass hook)."""
        raise NotImplementedError

    def sample_and_test(self, sample):
        """Run inference on a batch and score it (subclass hook).

        Must return ``(img_sr, rrdb_out, ret)`` or ``(img_sr, ret)``, where
        ``ret`` maps every key of ``metric_keys`` to a value that can be
        summed across batches, plus ``'n_samples'`` (the count those sums
        cover).
        """
        raise NotImplementedError

    def build_optimizer(self, model):
        """Return the optimizer for ``model`` (subclass hook)."""
        raise NotImplementedError

    def build_scheduler(self, optimizer):
        """Return the LR scheduler for ``optimizer`` (subclass hook)."""
        raise NotImplementedError

    def training_step(self, batch):
        """Return ``(losses_dict, total_loss)`` for one batch (subclass hook)."""
        raise NotImplementedError

    def train(self):
        """Optimise the model until ``hparams['max_updates']`` steps.

        Model/optimizer state is restored with ``load_checkpoint``. Training
        resumes from the step it returns. Every ``val_check_interval`` steps
        the model is validated and checkpointed. Losses are logged every 100
        steps.
        """
        model = self.build_model()
        optimizer = self.build_optimizer(model)
        self.global_step = training_step = load_checkpoint(
            model, optimizer, hparams['work_dir'], steps=self.val_steps)
        self.scheduler = scheduler = self.build_scheduler(optimizer)
        scheduler.step(training_step)
        dataloader = self.build_train_dataloader()
        train_pbar = tqdm(dataloader, initial=training_step, total=float('inf'),
                          dynamic_ncols=True, unit='step')
        max_updates = hparams['max_updates']
        while self.global_step < max_updates:
            for batch in train_pbar:
                if training_step % hparams['val_check_interval'] == 0:
                    with torch.no_grad():
                        model.eval()
                        self.validate(training_step)
                    save_checkpoint(model, optimizer, self.work_dir, training_step, hparams['num_ckpt_keep'])
                    model.train()
                # Stop at the update budget instead of finishing the epoch.
                # The check comes after the validate/save above, so a final
                # step that is a multiple of val_check_interval still gets
                # its checkpoint.
                if training_step >= max_updates:
                    break
                batch = move_to_cuda(batch)
                losses, total_loss = self.training_step(batch)
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                training_step += 1
                scheduler.step(training_step)
                self.global_step = training_step
                if training_step % 100 == 0:
                    self.log_metrics({f'tr/{k}': v for k, v in losses.items()}, training_step)
                train_pbar.set_postfix(**tensors_to_scalars(losses))

    def validate(self, training_step):
        """Evaluate on the validation split and report dataset-level metrics.

        Each batch's ``ret[k]`` is summed and divided by the summed
        ``ret['n_samples']``, the same aggregation ``benchmark`` uses.
        Previously only the last batch's values were reported. On the first
        call of a run only ``num_sanity_val_steps`` batches are evaluated,
        and the result is printed but not logged.
        """
        val_dataloader = self.build_val_dataloader()
        pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader))
        totals = {k: 0 for k in self.metric_keys}
        n_samples = 0
        metrics = {}
        for batch_idx, batch in pbar:
            # On the first validation of each run, evaluate only a few
            # batches to check that the code runs end to end.
            if self.first_val and batch_idx >= hparams['num_sanity_val_steps']:
                break
            batch = move_to_cuda(batch)
            img, rrdb_out, ret = self.sample_and_test(batch)
            img_hr = batch['img_hr']
            img_lr = batch['img_lr']
            img_lr_up = batch['img_lr_up']
            if img is not None:
                self.logger.add_image(f'Pred_{batch_idx}', plot_img(img[0]), self.global_step)
                if hparams.get('aux_l1_loss'):
                    self.logger.add_image(f'rrdb_out_{batch_idx}', plot_img(rrdb_out[0]), self.global_step)
                # Reference images only change with the data, so log them once early on.
                if self.global_step <= hparams['val_check_interval']:
                    self.logger.add_image(f'HR_{batch_idx}', plot_img(img_hr[0]), self.global_step)
                    self.logger.add_image(f'LR_{batch_idx}', plot_img(img_lr[0]), self.global_step)
                    self.logger.add_image(f'BL_{batch_idx}', plot_img(img_lr_up[0]), self.global_step)
            for k in self.metric_keys:
                totals[k] += ret[k]
            n_samples += ret['n_samples']
            metrics = self._average_metrics(totals, n_samples)
            pbar.set_postfix(**tensors_to_scalars(metrics))
        if hparams['infer']:
            print('Val results:', metrics)
        else:
            if not self.first_val:
                self.log_metrics({f'val/{k}': v for k, v in metrics.items()}, training_step)
                print('Val results:', metrics)
            else:
                print('Sanity val results:', metrics)
        self.first_val = False

    def build_test_my_dataloader(self, data_name):
        """Unshuffled DataLoader over the dataset split named ``data_name``."""
        return torch.utils.data.DataLoader(
            self.dataset_cls(data_name), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)

    def benchmark(self, benchmark_name_list, metric_list):
        """Evaluate one checkpoint on every dataset in ``benchmark_name_list``.

        The checkpoint is restored with ``load_checkpoint`` using ``val_steps``.
        For each dataset, outputs go under
        ``work_dir/results_<step>_<gen_dir_name>/benchmark/<data_name>``. When
        ``test_save_png`` is set, SR and HR PNGs are written there. IQA
        metrics (``metric_list``) go to an Excel file in ``work_dir``. The
        averages of ``sample_and_test`` metrics go to ``eval.json`` and the
        per-batch values to ``eval_per_image.json``.
        """
        import shutil

        from sam_diffsr.tools.caculate_iqa import eval_img_IQA
        model = self.build_model()
        optimizer = self.build_optimizer(model)
        training_step = load_checkpoint(model, optimizer, hparams['work_dir'], hparams['val_steps'])
        self.global_step = training_step
        optimizer = None  # not used for inference; drop the reference
        # NOTE(review): this class never assigns self.model itself; it appears
        # to be set by the subclass's build_model(). Fall back to the returned
        # model if it was not.
        net = getattr(self, 'model', model)
        exp_name = hparams['work_dir'].split('/')[-1]
        excel_path = f"{hparams['work_dir']}/IQA-val-benchmark-{exp_name}.xlsx"
        for data_name in benchmark_name_list:
            test_dataloader = self.build_test_my_dataloader(data_name)
            self.results = {k: 0 for k in self.metric_keys}
            self.n_samples = 0
            self.gen_dir = f"{hparams['work_dir']}/results_{self.global_step}_{hparams['gen_dir_name']}/benchmark/{data_name}"
            gen_dir = self.gen_dir
            if hparams['test_save_png']:
                # Start from a clean output directory. shutil is used instead
                # of `rm -rf` through a shell, so paths containing spaces or
                # shell metacharacters are handled safely.
                if os.path.isdir(gen_dir):
                    shutil.rmtree(gen_dir)
                os.makedirs(f'{gen_dir}/outputs', exist_ok=True)
                os.makedirs(f'{gen_dir}/SR', exist_ok=True)
            net.sample_tqdm = False
            torch.backends.cudnn.benchmark = False
            if hparams['test_save_png']:
                if hasattr(net.denoise_fn, 'make_generation_fast_'):
                    net.denoise_fn.make_generation_fast_()
                os.makedirs(f'{gen_dir}/HR', exist_ok=True)
            result_dict = {}
            with torch.no_grad():
                model.eval()
                pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
                for batch_idx, batch in pbar:
                    batch = move_to_cuda(batch)
                    item_names = batch['item_name']
                    img_hr = batch['img_hr']
                    res = self.sample_and_test(batch)
                    if len(res) == 3:
                        img_sr, _rrdb_out, ret = res
                    else:
                        img_sr, ret = res
                    if img_sr is None:
                        continue
                    # NOTE(review): keyed by the batch's first item only; with
                    # eval_batch_size > 1 the other item names are not recorded.
                    batch_metrics = {}
                    for k in self.metric_keys:
                        self.results[k] += ret[k]
                        batch_metrics[k] = ret[k]
                    result_dict[batch['item_name'][0]] = batch_metrics
                    self.n_samples += ret['n_samples']
                    print(self._average_metrics(self.results, self.n_samples, 3), 'total:',
                          self.n_samples)
                    if hparams['test_save_png']:
                        # Only SR and HR are written, so only those are
                        # converted (rrdb_out may be None for some models).
                        sr_imgs = self.tensor2img(img_sr)
                        hr_imgs = self.tensor2img(img_hr)
                        for item_name, hr_p, hr_g in zip(item_names, sr_imgs, hr_imgs):
                            item_name = os.path.splitext(item_name)[0]
                            Image.fromarray(hr_p).save(f"{gen_dir}/SR/{item_name}.png")
                            Image.fromarray(hr_g).save(f"{gen_dir}/HR/{item_name}.png")
            # NOTE(review): IQA runs even when test_save_png is off, in which
            # case SR/HR may not have been (re)written for this run.
            eval_img_IQA(f"{gen_dir}/HR/", f"{gen_dir}/SR/", excel_path, metric_list, training_step, data_name)
            avg_result = self._average_metrics(self.results, self.n_samples, 4)
            self._write_eval_json(gen_dir, avg_result, result_dict)

    def benchmark_loop(self, benchmark_name_list, metric_list, gt_path):
        """Run inference and IQA for every checkpoint in ``work_dir``.

        For each checkpoint from ``get_all_ckpts`` and each dataset in
        ``benchmark_name_list``, SR outputs are saved as PNGs and scored by
        ``eval_img_IQA`` against ground truth in ``<gt_path>/<data_name>/HR``.
        """
        from sam_diffsr.tools.caculate_iqa import eval_img_IQA
        model = self.build_model()
        # NOTE(review): self.model appears to be set by the subclass's
        # build_model(); fall back to the returned model if it was not.
        net = getattr(self, 'model', model)
        exp_name = hparams['work_dir'].split('/')[-1]
        excel_path = f"{hparams['work_dir']}/IQA-val-benchmark_loop-{exp_name}.xlsx"
        for ckpt_path in get_all_ckpts(hparams['work_dir']):
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            model.load_state_dict(self._strip_module_prefix(checkpoint['state_dict']['model']))
            model.cuda()
            training_step = checkpoint['global_step']
            # Release the CPU copy of the weights before the next checkpoint
            # is loaded.
            del checkpoint
            torch.cuda.empty_cache()
            self.global_step = training_step
            for data_name in benchmark_name_list:
                test_dataloader = self.build_test_my_dataloader(data_name)
                self.results = {k: 0 for k in self.metric_keys + self.metric_2_keys}
                self.n_samples = 0
                self.gen_dir = f"{hparams['work_dir']}/results_{training_step}_{hparams['gen_dir_name']}/benchmark/{data_name}"
                gen_dir = self.gen_dir
                os.makedirs(f'{gen_dir}/outputs', exist_ok=True)
                os.makedirs(f'{gen_dir}/SR', exist_ok=True)
                net.sample_tqdm = False
                torch.backends.cudnn.benchmark = False
                with torch.no_grad():
                    model.eval()
                    pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
                    for batch_idx, batch in pbar:
                        batch = move_to_cuda(batch)
                        item_names = batch['item_name']
                        # sample_and_test returns (img_sr, rrdb_out, ret) or (img_sr, ret).
                        img_sr = self.sample_and_test(batch)[0]
                        if img_sr is None:
                            continue
                        for item_name, hr_p in zip(item_names, self.tensor2img(img_sr)):
                            item_name = os.path.splitext(item_name)[0]
                            Image.fromarray(hr_p).save(f"{gen_dir}/SR/{item_name}.png")
                sr_img_dir = f"{gen_dir}/SR/"
                gt_img_dir = f"{gt_path}/{data_name}/HR"
                eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, training_step, data_name)

    # utils_sr
    def log_metrics(self, metrics, step):
        """Write each entry of ``metrics`` to tensorboard as a scalar at ``step``."""
        for k, v in self.metrics_to_scalars(metrics).items():
            self.logger.add_scalar(k, v, step)

    def metrics_to_scalars(self, metrics):
        """Return a copy of ``metrics`` with tensors replaced by ``.item()``.

        Nested dicts are converted recursively.
        """
        new_metrics = {}
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if isinstance(v, dict):
                v = self.metrics_to_scalars(v)
            new_metrics[k] = v
        return new_metrics

    @staticmethod
    def _average_metrics(totals, n_samples, ndigits=None):
        """Return ``{k: totals[k] / n_samples}``, rounded to ``ndigits`` if given.

        Returns an empty dict when ``n_samples`` is not positive, so callers
        never divide by zero.
        """
        if n_samples <= 0:
            return {}
        if ndigits is None:
            return {k: v / n_samples for k, v in totals.items()}
        return {k: round(v / n_samples, ndigits) for k, v in totals.items()}

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return an OrderedDict copy of ``state_dict`` without leading ``module.`` on keys.

        NOTE(review): the prefix is presumably added when the model was saved
        while wrapped in (Distributed)DataParallel.
        """
        prefix = 'module.'
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items())

    @staticmethod
    def _write_eval_json(gen_dir, avg_result, result_dict):
        """Write benchmark averages to ``eval.json`` and per-item results to ``eval_per_image.json``.

        They are kept as two files because two JSON documents concatenated
        into one file cannot be parsed by ``json.load``.
        """
        dump_kwargs = dict(sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False)
        os.makedirs(gen_dir, exist_ok=True)
        with open(os.path.join(gen_dir, 'eval.json'), 'w', encoding='utf-8') as file:
            json.dump(avg_result, file, **dump_kwargs)
        with open(os.path.join(gen_dir, 'eval_per_image.json'), 'w', encoding='utf-8') as file:
            json.dump(result_dict, file, **dump_kwargs)

    @staticmethod
    def tensor2img(img):
        """Convert a rank-4 tensor to a uint8 numpy array, moving dim 1 last.

        Values are mapped by ``(x + 1) * 127.5``, so -1 becomes 0 and 1
        becomes 255. Results are rounded and clipped to [0, 255]. Assumes
        (N, C, H, W) input, giving (N, H, W, C) output (TODO confirm
        against callers).
        """
        img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1) * 127.5)
        img = img.clip(min=0, max=255).astype(np.uint8)
        return img
if __name__ == '__main__':
    set_hparams()
    # hparams['trainer_cls'] is a dotted path of the form "<module>.<ClassName>".
    module_name, _, class_name = hparams["trainer_cls"].rpartition(".")
    trainer_cls = getattr(importlib.import_module(module_name), class_name)
    trainer = trainer_cls()
    # Pick the run mode: evaluate all checkpoints, evaluate one, or train.
    if hparams['benchmark_loop']:
        trainer.benchmark_loop(hparams['benchmark_name_list'], hparams['metric_list'], hparams['gt_img_path'])
    elif hparams['benchmark']:
        trainer.benchmark(hparams['benchmark_name_list'], hparams['metric_list'])
    else:
        trainer.train()