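"""Training and benchmarking entry point for sam_diffsr.

Defines a task-agnostic `Trainer` skeleton; the concrete trainer class is
resolved at runtime from the `trainer_cls` hparam.
"""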
import importlib
import json
import os
import subprocess
import sys
from collections import OrderedDict
from pathlib import Path

# Make the repository root importable and the working directory, so relative
# paths in configs resolve consistently.
parent_path = Path(__file__).absolute().parent.parent
sys.path.append(os.path.abspath(parent_path))
os.chdir(parent_path)
print(f'>-------------> parent path {parent_path}')
print(f'>-------------> current work dir {os.getcwd()}')

# Redirect the HuggingFace and torch caches into a local `cache` directory.
cache_path = os.path.join(parent_path, 'cache')
os.environ["HF_DATASETS_CACHE"] = cache_path
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["TORCH_HOME"] = cache_path

import torch
from PIL import Image
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from sam_diffsr.utils_sr.hparams import hparams, set_hparams
from sam_diffsr.utils_sr.utils import plot_img, move_to_cuda, load_checkpoint, save_checkpoint, \
    tensors_to_scalars, Measure, get_all_ckpts


class Trainer:
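    """Skeleton training/evaluation loop for diffusion super-resolution.

    Subclasses are expected to implement `build_model`, `sample_and_test`,
    `build_optimizer`, `build_scheduler`, and `training_step`.
    """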

    def __init__(self):
        self.logger = self.build_tensorboard(save_dir=hparams['work_dir'], name='tb_logs')
        self.measure = Measure()
        self.dataset_cls = None
        self.metric_keys = ['psnr', 'ssim', 'lpips', 'lr_psnr']
        self.metric_2_keys = ['psnr-Y', 'ssim', 'fid']
        self.work_dir = hparams['work_dir']
        self.first_val = True
        self.val_steps = hparams['val_steps']

    def build_tensorboard(self, save_dir, name, **kwargs):
        log_dir = os.path.join(save_dir, name)
        os.makedirs(log_dir, exist_ok=True)
        return SummaryWriter(log_dir=log_dir, **kwargs)

    def build_train_dataloader(self):
        dataset = self.dataset_cls('train')
        return torch.utils.data.DataLoader(
            dataset, batch_size=hparams['batch_size'], shuffle=True,
            pin_memory=False, num_workers=hparams['num_workers'])

    def build_val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.dataset_cls('valid'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)

    def build_test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.dataset_cls('test'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
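
    # Subclass hooks: task-specific trainers implement these methods.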
    def build_model(self):
        raise NotImplementedError

    def sample_and_test(self, sample):
        raise NotImplementedError

    def build_optimizer(self, model):
        raise NotImplementedError

    def build_scheduler(self, optimizer):
        raise NotImplementedError

    def training_step(self, batch):
        raise NotImplementedError
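
    # The training loop resumes from the latest checkpoint in `work_dir` (if
    # any), then alternates optimization steps with validation plus
    # checkpointing every `val_check_interval` steps.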
    def train(self):
        model = self.build_model()
        optimizer = self.build_optimizer(model)
        self.global_step = training_step = load_checkpoint(model, optimizer, hparams['work_dir'], steps=self.val_steps)
        self.scheduler = scheduler = self.build_scheduler(optimizer)
        scheduler.step(training_step)
        dataloader = self.build_train_dataloader()
        train_pbar = tqdm(dataloader, initial=training_step, total=float('inf'),
                          dynamic_ncols=True, unit='step')
        while self.global_step < hparams['max_updates']:
            for batch in train_pbar:
                if training_step % hparams['val_check_interval'] == 0:
                    with torch.no_grad():
                        model.eval()
                        self.validate(training_step)
                    save_checkpoint(model, optimizer, self.work_dir, training_step, hparams['num_ckpt_keep'])
                model.train()
                batch = move_to_cuda(batch)
                losses, total_loss = self.training_step(batch)
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                training_step += 1
                scheduler.step(training_step)
                self.global_step = training_step
                if training_step % 100 == 0:
                    self.log_metrics({f'tr/{k}': v for k, v in losses.items()}, training_step)
                train_pbar.set_postfix(**tensors_to_scalars(losses))

    def validate(self, training_step):
        val_dataloader = self.build_val_dataloader()
        pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader))
        metrics = {}
        for batch_idx, batch in pbar:
            # On the first validation of a run, only process a few batches as
            # a quick sanity check that the pipeline runs end to end.
            if self.first_val and batch_idx > hparams['num_sanity_val_steps'] - 1:
                break
            batch = move_to_cuda(batch)
            img, rrdb_out, ret = self.sample_and_test(batch)
            img_hr = batch['img_hr']
            img_lr = batch['img_lr']
            img_lr_up = batch['img_lr_up']
            if img is not None:
                self.logger.add_image(f'Pred_{batch_idx}', plot_img(img[0]), self.global_step)
                if hparams.get('aux_l1_loss'):
                    self.logger.add_image(f'rrdb_out_{batch_idx}', plot_img(rrdb_out[0]), self.global_step)
                if self.global_step <= hparams['val_check_interval']:
                    self.logger.add_image(f'HR_{batch_idx}', plot_img(img_hr[0]), self.global_step)
                    self.logger.add_image(f'LR_{batch_idx}', plot_img(img_lr[0]), self.global_step)
                    self.logger.add_image(f'BL_{batch_idx}', plot_img(img_lr_up[0]), self.global_step)
            metrics = {k: np.mean(ret[k]) for k in self.metric_keys}
            pbar.set_postfix(**tensors_to_scalars(metrics))
        if hparams['infer']:
            print('Val results:', metrics)
        else:
            if not self.first_val:
                self.log_metrics({f'val/{k}': v for k, v in metrics.items()}, training_step)
                print('Val results:', metrics)
            else:
                print('Sanity val results:', metrics)
            self.first_val = False

    def build_test_my_dataloader(self, data_name):
        return torch.utils.data.DataLoader(
            self.dataset_cls(data_name), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False)
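
    # `benchmark` evaluates the checkpoint restored by `load_checkpoint` on
    # each named test set, accumulates per-image metrics, and optionally saves
    # SR/HR PNGs for the external IQA pass (`eval_img_IQA`).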
    def benchmark(self, benchmark_name_list, metric_list):
        from sam_diffsr.tools.caculate_iqa import eval_img_IQA

        model = self.build_model()
        optimizer = self.build_optimizer(model)
        training_step = load_checkpoint(model, optimizer, hparams['work_dir'], hparams['val_steps'])
        self.global_step = training_step
        optimizer = None  # only needed to restore state from the checkpoint

        for data_name in benchmark_name_list:
            test_dataloader = self.build_test_my_dataloader(data_name)

            self.results = {k: 0 for k in self.metric_keys}
            self.n_samples = 0
            self.gen_dir = f"{hparams['work_dir']}/results_{self.global_step}_{hparams['gen_dir_name']}/benchmark/{data_name}"
            if hparams['test_save_png']:
                subprocess.check_call(f'rm -rf {self.gen_dir}', shell=True)
                os.makedirs(f'{self.gen_dir}/outputs', exist_ok=True)
                os.makedirs(f'{self.gen_dir}/SR', exist_ok=True)

            # Subclasses are expected to expose the model as `self.model`.
            self.model.sample_tqdm = False
            torch.backends.cudnn.benchmark = False
            if hparams['test_save_png']:
                if hasattr(self.model.denoise_fn, 'make_generation_fast_'):
                    self.model.denoise_fn.make_generation_fast_()
                os.makedirs(f'{self.gen_dir}/HR', exist_ok=True)

            result_dict = {}
            with torch.no_grad():
                model.eval()
                pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
                for batch_idx, batch in pbar:
                    batch = move_to_cuda(batch)
                    gen_dir = self.gen_dir
                    item_names = batch['item_name']
                    img_hr = batch['img_hr']
                    img_lr = batch['img_lr']
                    img_lr_up = batch['img_lr_up']

                    # `sample_and_test` may or may not return the RRDB output.
                    res = self.sample_and_test(batch)
                    if len(res) == 3:
                        img_sr, rrdb_out, ret = res
                    else:
                        img_sr, ret = res
                        rrdb_out = img_sr

                    img_lr_up = batch.get('img_lr_up', img_lr_up)
                    if img_sr is not None:
                        # Keyed by the first item name (eval batches are assumed to hold one image).
                        result_dict[batch['item_name'][0]] = {}
                        for k in self.metric_keys:
                            self.results[k] += ret[k]
                            result_dict[batch['item_name'][0]][k] = ret[k]
                        self.n_samples += ret['n_samples']
                        print({k: round(self.results[k] / self.n_samples, 3) for k in self.results},
                              'total:', self.n_samples)

                        if hparams['test_save_png']:
                            img_sr = self.tensor2img(img_sr)
                            img_hr = self.tensor2img(img_hr)
                            img_lr = self.tensor2img(img_lr)
                            img_lr_up = self.tensor2img(img_lr_up)
                            rrdb_out = self.tensor2img(rrdb_out)
                            for item_name, hr_p, hr_g, lr, lr_up, rrdb_o in zip(
                                    item_names, img_sr, img_hr, img_lr, img_lr_up, rrdb_out):
                                item_name = os.path.splitext(item_name)[0]
                                hr_p = Image.fromarray(hr_p)
                                hr_g = Image.fromarray(hr_g)
                                hr_p.save(f"{gen_dir}/SR/{item_name}.png")
                                hr_g.save(f"{gen_dir}/HR/{item_name}.png")

            exp_name = hparams['work_dir'].split('/')[-1]
            sr_img_dir = f"{gen_dir}/SR/"
            gt_img_dir = f"{gen_dir}/HR/"
            excel_path = f"{hparams['work_dir']}/IQA-val-benchmark-{exp_name}.xlsx"
            epoch = training_step
            eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)

            # Persist averaged and per-image metrics as a single JSON object
            # so the file stays parseable.
            os.makedirs(self.gen_dir, exist_ok=True)
            eval_json_path = os.path.join(self.gen_dir, 'eval.json')
            avg_result = {k: round(self.results[k] / self.n_samples, 4) for k in self.results}
            with open(eval_json_path, 'w') as file:
                json.dump({'average': avg_result, 'per_image': result_dict}, file,
                          sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False)

    def benchmark_loop(self, benchmark_name_list, metric_list, gt_path):
        # Run inference and evaluation for every saved checkpoint.
        from sam_diffsr.tools.caculate_iqa import eval_img_IQA

        model = self.build_model()

        def get_checkpoint(model, checkpoint):
            state_dict = checkpoint['state_dict']['model']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if k[:7] == 'module.':
                    k = k[7:]  # strip the `module.` prefix added by DataParallel
                new_state_dict[k] = v
            model.load_state_dict(new_state_dict)
            model.cuda()
            training_step = checkpoint['global_step']
            del checkpoint
            torch.cuda.empty_cache()
            return training_step

        ckpt_paths = get_all_ckpts(hparams['work_dir'])
        for ckpt_path in ckpt_paths:
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            training_step = get_checkpoint(model, checkpoint)
            self.global_step = training_step

            for data_name in benchmark_name_list:
                test_dataloader = self.build_test_my_dataloader(data_name)

                self.results = {k: 0 for k in self.metric_keys + self.metric_2_keys}
                self.n_samples = 0
                self.gen_dir = f"{hparams['work_dir']}/results_{training_step}_{hparams['gen_dir_name']}/benchmark/{data_name}"
                os.makedirs(f'{self.gen_dir}/outputs', exist_ok=True)
                os.makedirs(f'{self.gen_dir}/SR', exist_ok=True)

                self.model.sample_tqdm = False
                torch.backends.cudnn.benchmark = False

                with torch.no_grad():
                    model.eval()
                    pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
                    for batch_idx, batch in pbar:
                        batch = move_to_cuda(batch)
                        gen_dir = self.gen_dir
                        item_names = batch['item_name']

                        res = self.sample_and_test(batch)
                        if len(res) == 3:
                            img_sr, rrdb_out, ret = res
                        else:
                            img_sr, ret = res
                            rrdb_out = img_sr

                        img_sr = self.tensor2img(img_sr)
                        for item_name, hr_p in zip(item_names, img_sr):
                            item_name = os.path.splitext(item_name)[0]
                            hr_p = Image.fromarray(hr_p)
                            hr_p.save(f"{gen_dir}/SR/{item_name}.png")

                exp_name = hparams['work_dir'].split('/')[-1]
                sr_img_dir = f"{gen_dir}/SR/"
                gt_img_dir = f"{gt_path}/{data_name}/HR"
                excel_path = f"{hparams['work_dir']}/IQA-val-benchmark_loop-{exp_name}.xlsx"
                epoch = training_step
                eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name)

    # utils_sr
    def log_metrics(self, metrics, step):
        metrics = self.metrics_to_scalars(metrics)
        logger = self.logger
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            logger.add_scalar(k, v, step)

    def metrics_to_scalars(self, metrics):
        new_metrics = {}
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if type(v) is dict:
                v = self.metrics_to_scalars(v)
            new_metrics[k] = v
        return new_metrics

    @staticmethod
    def tensor2img(img):
        # Map an NCHW tensor in [-1, 1] to a batch of uint8 HWC images:
        # (x + 1) * 127.5 rescales [-1, 1] to [0, 255] before rounding and clipping.
        img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1) * 127.5)
        img = img.clip(min=0, max=255).astype(np.uint8)
        return img


if __name__ == '__main__':
    set_hparams()

    # `trainer_cls` is a dotted path to the concrete trainer class (the exact
    # value comes from the config): import its module, fetch the class,
    # instantiate it.
    pkg = ".".join(hparams["trainer_cls"].split(".")[:-1])
    cls_name = hparams["trainer_cls"].split(".")[-1]
    trainer = getattr(importlib.import_module(pkg), cls_name)()

    if hparams['benchmark_loop']:
        trainer.benchmark_loop(hparams['benchmark_name_list'], hparams['metric_list'], hparams['gt_img_path'])
    elif hparams['benchmark']:
        trainer.benchmark(hparams['benchmark_name_list'], hparams['metric_list'])
    else:
        trainer.train()
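
# A minimal launch sketch (illustrative only: the flag names follow the
# SRDiff-style CLI that `set_hparams` typically exposes and are not verified
# against this repo's `utils_sr.hparams`):
#   python trainer.py --config <path/to/config.yaml> --exp_name <run_name>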