""" |
@Date: 2021/07/17 |
@description: Training and evaluation entry point for the Panorama Layout Transformer.
""" |
import sys |
import os |
import shutil |
import argparse |
import numpy as np |
import json |
import torch |
import torch.nn.parallel |
import torch.optim |
import torch.multiprocessing as mp |
import torch.utils.data |
import torch.utils.data.distributed |
import torch.cuda |
from PIL import Image |
from tqdm import tqdm |
from torch.utils.tensorboard import SummaryWriter |
from config.defaults import get_config, get_rank_config |
from models.other.criterion import calc_criterion |
from models.build import build_model |
from models.other.init_env import init_env |
from utils.logger import build_logger |
from utils.misc import tensor2np_d, tensor2np |
from dataset.build import build_loader |
from evaluation.accuracy import calc_accuracy, show_heat_map, calc_ce, calc_pe, calc_rmse_delta_1, \ |
show_depth_normal_grad, calc_f1_score |
from postprocessing.post_process import post_process |
try: |
from apex import amp |
except ImportError: |
amp = None |
def parse_option():
    """Parse the command-line options of the training/evaluation script.

    In addition to the declared options, the returned namespace carries a
    ``debug`` attribute that is ``True`` when a trace function is installed
    (``sys.gettrace()`` is set) and ``False`` otherwise.  Every resulting
    attribute is echoed to stdout before returning.

    :return: the populated ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')

    # (flag, add_argument keywords) pairs, registered in declaration order.
    option_specs = (
        ('--cfg', dict(type=str, metavar='FILE', help='path to config file')),
        ('--mode', dict(type=str, default='train', choices=['train', 'val', 'test'],
                        help='train/val/test mode')),
        ('--val_name', dict(type=str, choices=['val', 'test'], help='val name')),
        ('--bs', dict(type=int, help='batch size')),
        ('--save_eval', dict(action='store_true', help='save eval result')),
        ('--post_processing', dict(type=str, choices=['manhattan', 'atalanta', 'manhattan_old'],
                                   help='type of postprocessing ')),
        ('--need_cpe', dict(action='store_true',
                            help='need to evaluate corner error and pixel error')),
        ('--need_f1', dict(action='store_true',
                           help='need to evaluate f1-score of corners')),
        ('--need_rmse', dict(action='store_true',
                             help='need to evaluate root mean squared error and delta error')),
        ('--force_cube', dict(action='store_true', help='force cube shape when eval')),
        ('--wall_num', dict(type=int, help='wall number')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    args.debug = bool(sys.gettrace())

    print("arguments:")
    for name, value in vars(args).items():
        print(name, ":", value)
    print("-" * 50)
    return args
def main():
    """Script entry point: prepare output directories and launch the worker(s).

    Steps:
      1. Parse the CLI options and build the config.
      2. When training from scratch, delete the existing checkpoint directory.
         If it holds a ``pkl`` checkpoint whose epoch number is greater than
         10, the user must confirm the deletion first (anything other than
         ``y``/``Y`` exits with status 0).
      3. Create the checkpoint, result and logger directories.
      4. Derive the number of worker processes from ``config.TRAIN.DEVICE``
         and fall back to CPU when CUDA is requested but unavailable.
      5. In train mode, dump the effective config to ``config.yaml`` inside
         the checkpoint directory.
      6. Run ``main_worker`` in-process, or spawn one process per device.
    """
    args = parse_option()
    config = get_config(args)

    if config.TRAIN.SCRATCH and os.path.exists(config.CKPT.DIR) and config.MODE == 'train':
        print(f"Train from scratch, delete checkpoint dir: {config.CKPT.DIR}")
        # The epoch number is the text between the last '_' and the next '.'
        # of every file name that contains 'pkl'.
        saved_epochs = [int(name.split('_')[-1].split('.')[0])
                        for name in os.listdir(config.CKPT.DIR) if 'pkl' in name]
        if saved_epochs:
            last_epoch = max(saved_epochs)
            if last_epoch > 10:
                c = input(f"delete it (last_epoch: {last_epoch})?(Y/N)\n")
                if c not in ('y', 'Y'):
                    # sys.exit rather than the site-provided `exit` helper,
                    # which is not guaranteed to exist (e.g. under `python -S`).
                    sys.exit(0)

        shutil.rmtree(config.CKPT.DIR, ignore_errors=True)

    os.makedirs(config.CKPT.DIR, exist_ok=True)
    os.makedirs(config.CKPT.RESULT_DIR, exist_ok=True)
    os.makedirs(config.LOGGER.DIR, exist_ok=True)

    # A device string without ':' carries no id list; treat it as a single
    # process.  Without this default `nprocs` would be unbound below for a
    # value such as plain "cuda".
    nprocs = 1
    if ':' in config.TRAIN.DEVICE:
        # One process per comma-separated entry after the last ':'.
        nprocs = len(config.TRAIN.DEVICE.split(':')[-1].split(','))
    if 'cuda' in config.TRAIN.DEVICE and not torch.cuda.is_available():
        print(f"Cuda is not available(config is: {config.TRAIN.DEVICE}), will use cpu ...")
        config.defrost()
        config.TRAIN.DEVICE = "cpu"
        config.freeze()
        nprocs = 1

    if config.MODE == 'train':
        with open(os.path.join(config.CKPT.DIR, "config.yaml"), "w") as f:
            f.write(config.dump(allow_unicode=True))

    if config.TRAIN.DEVICE == 'cpu' or nprocs < 2:
        print(f"Use single process, device:{config.TRAIN.DEVICE}")
        main_worker(0, config, 1)
    else:
        print(f"Use {nprocs} processes ...")
        mp.spawn(main_worker, nprocs=nprocs, args=(config, nprocs), join=True)
def main_worker(local_rank, cfg, world_size):
    """Per-process worker: build everything for one rank, then train or evaluate.

    :param local_rank: index of this process, forwarded to ``get_rank_config``
    :param cfg: base config shared by all processes
    :param world_size: total number of processes, forwarded to ``get_rank_config``

    In ``train`` mode the work is delegated to ``train``.  In any other mode a
    single evaluation pass is run and, when ``config.SAVE_EVAL`` is set, the
    merged metrics are written to ``result.json`` in ``config.CKPT.RESULT_DIR``.
    """
    config = get_rank_config(cfg, local_rank, world_size)
    logger = build_logger(config)
    writer = SummaryWriter(config.CKPT.DIR)
    logger.info(f"Comment: {config.COMMENT}")
    cur_pid = os.getpid()
    logger.info(f"Current process id: {cur_pid}")
    torch.hub._hub_dir = config.CKPT.PYTORCH
    logger.info(f"Pytorch hub dir: {torch.hub._hub_dir}")
    init_env(config.SEED, config.TRAIN.DETERMINISTIC, config.DATA.NUM_WORKERS)

    model, optimizer, criterion, scheduler = build_model(config, logger)
    train_data_loader, val_data_loader = build_loader(config, logger)

    if 'cuda' in config.TRAIN.DEVICE:
        torch.cuda.set_device(config.TRAIN.DEVICE)

    if config.MODE == 'train':
        # train() closes the writer itself once the last epoch is done.
        train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler)
        return

    try:
        eval_output = val_an_epoch(model, val_data_loader,
                                   criterion, config, logger, writer=None,
                                   epoch=config.TRAIN.START_EPOCH)
    finally:
        # Evaluation does not log to TensorBoard (writer=None is passed above),
        # but the writer created earlier must still be released.
        writer.close()

    # val_an_epoch returns early without metrics on ranks other than 0, so
    # there may be nothing to unpack or save in this process.
    iou_results, other_results = eval_output if eval_output is not None else (None, None)
    if iou_results is None or other_results is None:
        return

    results = dict(iou_results, **other_results)
    if config.SAVE_EVAL:
        save_path = os.path.join(config.CKPT.RESULT_DIR, "result.json")
        with open(save_path, 'w') as f:
            # `default=float` is only consulted for values json cannot encode
            # natively (e.g. NumPy scalars that are not `float` subclasses).
            json.dump(results, f, indent=4, default=float)
def save(model, optimizer, epoch, iou_d, logger, writer, config):
    """Checkpoint ``model`` and log its accuracy records to TensorBoard.

    ``model.save`` is called with ``iou_d['full_3d']`` as the accuracy.
    Afterwards every entry of ``model.acc_d`` is written as the scalar
    ``BestACC/<key>`` (taken from the entry's ``'acc'`` field) at step ``epoch``.

    :param iou_d: IoU metrics of the epoch; must contain the key ``'full_3d'``
    :param writer: object exposing ``add_scalar(tag, value, step)``
    """
    model.save(optimizer, epoch, accuracy=iou_d['full_3d'], logger=logger, acc_d=iou_d, config=config)
    for metric_name, record in model.acc_d.items():
        writer.add_scalar(f"BestACC/{metric_name}", record['acc'], epoch)
def train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler):
    """Run the train/validate loop from ``config.TRAIN.START_EPOCH`` up to
    (excluding) ``config.TRAIN.EPOCHS`` and close ``writer`` afterwards.

    After every epoch:
      * rank 0 saves the model (the wrapped ``model.module`` when
        ``config.WORLD_SIZE > 1``) together with the validation IoU metrics;
      * the scheduler, if any, is stepped unless it defines ``min_lr`` and the
        first parameter group's learning rate has already reached it.

    :param scheduler: learning-rate scheduler exposing ``min_lr`` and
        ``step()``, or ``None`` to keep the learning rate untouched
    """
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        logger.info("=" * 200)
        train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch)
        val_result = val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch)

        if config.LOCAL_RANK == 0:
            # Only unpack here: val_an_epoch returns early without metrics on
            # ranks other than 0, and unpacking on those ranks would raise.
            epoch_iou_d, _ = val_result
            ddp = config.WORLD_SIZE > 1
            save(model.module if ddp else model, optimizer, epoch, epoch_iou_d, logger, writer, config)

        if scheduler is not None:
            lr_floor = scheduler.min_lr
            at_floor = lr_floor is not None and optimizer.param_groups[0]['lr'] <= lr_floor
            if not at_floor:
                scheduler.step()
    writer.close()
def train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch=0):
    """Run one optimisation pass over ``train_data_loader``.

    Each batch is a dict with at least the keys ``'image'``, ``'depth'`` and
    ``'ratio'`` (optionally ``'corner_heat_map'``).  The tensors are moved to
    ``config.TRAIN.DEVICE``, the loss is computed through ``calc_criterion``
    and one optimizer step is taken.  When ``config.AMP_OPT_LEVEL`` is not
    ``"O0"`` on a CUDA device, inputs are cast to float16 and the backward
    pass goes through ``amp.scale_loss``.

    Per-batch losses are written to ``writer`` on every rank; the epoch
    summary (mean losses and learning rate) is logged on rank 0 only.

    :param epoch: index of the current epoch, used for logging and for the
        TensorBoard global step
    """
    logger.info(f'Start Train Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
    model.train()
    if len(config.MODEL.FINE_TUNE) > 0:
        model.feature_extractor.eval()

    optimizer.zero_grad()

    batches_per_epoch = len(train_data_loader)
    step_offset = batches_per_epoch * epoch * config.WORLD_SIZE
    show_progress = config.LOCAL_RANK == 0 and config.SHOW_BAR
    batches = enumerate(train_data_loader)
    if show_progress:
        batches = tqdm(batches, total=batches_per_epoch, ncols=200)

    device = config.TRAIN.DEVICE
    epoch_loss_d = {}
    for i, gt in batches:
        imgs = gt['image'].to(device, non_blocking=True)
        for name in ('depth', 'ratio'):
            gt[name] = gt[name].to(device, non_blocking=True)
        if 'corner_heat_map' in gt:
            gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)

        use_half = config.AMP_OPT_LEVEL != "O0" and 'cuda' in device
        if use_half:
            imgs = imgs.type(torch.float16)
            for name in ('depth', 'ratio'):
                gt[name] = gt[name].type(torch.float16)

        dt = model(imgs)
        loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
        if show_progress:
            batches.set_postfix(batch_loss_d)

        optimizer.zero_grad()
        if use_half:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # Interleave the ranks so every (batch, rank) pair gets its own step.
        global_step = step_offset + i * config.WORLD_SIZE + config.LOCAL_RANK
        for key, val in batch_loss_d.items():
            writer.add_scalar(f'TrainBatchLoss/{key}', val, global_step)

    if config.LOCAL_RANK != 0:
        return

    # Epoch summary: mean of every collected loss series.
    epoch_loss_d = {key: np.array(series).mean() for key, series in epoch_loss_d.items()}
    fragments = []
    for key, val in epoch_loss_d.items():
        writer.add_scalar(f'TrainEpochLoss/{key}', val, epoch)
        fragments.append(f" {key}={val}")
    logger.info('TrainEpochLoss: ' + ''.join(fragments))

    current_lr = optimizer.param_groups[0]['lr']
    writer.add_scalar('LearningRate', current_lr, epoch)
    logger.info(f"LearningRate: {current_lr}")
@torch.no_grad()
def val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch=0):
    """Evaluate ``model`` on every batch of ``val_data_loader`` with gradients disabled.

    For each batch the losses (``calc_criterion``) and the IoU scores
    (``calc_accuracy``) are collected.  Depending on ``config.EVAL`` the
    prediction is post-processed first and corner/pixel error, F1 and
    RMSE/delta_1 metrics are gathered as well.

    :param val_data_loader: yields dict batches with at least the keys
        ``'image'``, ``'depth'`` and ``'ratio'`` (optionally
        ``'corner_heat_map'``)
    :param writer: TensorBoard writer, or a falsy value (e.g. ``None``) to
        skip all TensorBoard output
    :param epoch: index of the current epoch, used for logging and for the
        TensorBoard global step
    :return: ``(epoch_iou_d, epoch_other_d)``, two dicts of per-epoch means,
        on rank 0.  Optional metrics that were not collected are reported as
        ``0``.  On every other rank ``(None, None)`` is returned, so callers
        that unpack the result work on all ranks.
    """
    model.eval()
    logger.info(f'Start Validate Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
    data_len = len(val_data_loader)
    start_i = data_len * epoch * config.WORLD_SIZE
    bar = enumerate(val_data_loader)
    if config.LOCAL_RANK == 0 and config.SHOW_BAR:
        bar = tqdm(bar, total=data_len, ncols=200)
    device = config.TRAIN.DEVICE
    epoch_loss_d = {}
    epoch_iou_d = {
        'visible_2d': [],
        'visible_3d': [],
        'full_2d': [],
        'full_3d': [],
        'height': []
    }

    epoch_other_d = {
        'ce': [],
        'pe': [],
        'f1': [],
        'precision': [],
        'recall': [],
        'rmse': [],
        'delta_1': []
    }

    # Not read anywhere below while visualisation is disabled; the draw is
    # kept so the global NumPy random stream advances exactly as before.
    show_index = np.random.randint(0, data_len)
    for i, gt in bar:
        imgs = gt['image'].to(device, non_blocking=True)
        gt['depth'] = gt['depth'].to(device, non_blocking=True)
        gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
        if 'corner_heat_map' in gt:
            gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
        dt = model(imgs)

        vis_w = config.TRAIN.VIS_WEIGHT
        # Visualisation is currently hard-disabled, so everything after the
        # `if not visualization: continue` guard below is skipped.
        visualization = False

        loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)

        if config.EVAL.POST_PROCESSING is not None:
            depth = tensor2np(dt['depth'])
            dt['processed_xyz'] = post_process(depth, type_name=config.EVAL.POST_PROCESSING,
                                               need_cube=config.EVAL.FORCE_CUBE)

        if config.EVAL.FORCE_CUBE and config.EVAL.NEED_CPE:
            ce = calc_ce(tensor2np_d(dt), tensor2np_d(gt))
            pe = calc_pe(tensor2np_d(dt), tensor2np_d(gt))

            epoch_other_d['ce'].append(ce)
            epoch_other_d['pe'].append(pe)

        if config.EVAL.NEED_F1:
            f1, precision, recall = calc_f1_score(tensor2np_d(dt), tensor2np_d(gt))
            epoch_other_d['f1'].append(f1)
            epoch_other_d['precision'].append(precision)
            epoch_other_d['recall'].append(recall)

        if config.EVAL.NEED_RMSE:
            rmse, delta_1 = calc_rmse_delta_1(tensor2np_d(dt), tensor2np_d(gt))
            epoch_other_d['rmse'].append(rmse)
            epoch_other_d['delta_1'].append(delta_1)

        visb_iou, full_iou, iou_height, pano_bds, full_iou_2ds = calc_accuracy(tensor2np_d(dt), tensor2np_d(gt),
                                                                                visualization, h=vis_w // 2)
        epoch_iou_d['visible_2d'].append(visb_iou[0])
        epoch_iou_d['visible_3d'].append(visb_iou[1])
        epoch_iou_d['full_2d'].append(full_iou[0])
        epoch_iou_d['full_3d'].append(full_iou[1])
        epoch_iou_d['height'].append(iou_height)

        if config.LOCAL_RANK == 0 and config.SHOW_BAR:
            bar.set_postfix(batch_loss_d)

        # Interleave the ranks so every (batch, rank) pair gets its own step.
        global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK

        if writer:
            for key, val in batch_loss_d.items():
                writer.add_scalar(f'ValBatchLoss/{key}', val, global_step)

        if not visualization:
            continue

        gt_grad_imgs, dt_grad_imgs = show_depth_normal_grad(dt, gt, device, vis_w)

        dt_heat_map_imgs = None
        gt_heat_map_imgs = None
        if 'corner_heat_map' in gt:
            dt_heat_map_imgs, gt_heat_map_imgs = show_heat_map(dt, gt, vis_w)

        if config.TRAIN.VIS_MERGE or config.SAVE_EVAL:
            imgs = []
            for j in range(len(pano_bds)):
                floorplan = full_iou[2][j]
                margin_w = int(floorplan.shape[-1] * (60/512))
                floorplan = floorplan[:, :, margin_w:-margin_w]

                grad_h = dt_grad_imgs[0].shape[1]
                vis_merge = [
                    gt_grad_imgs[j],
                    pano_bds[j][:, grad_h:-grad_h],
                    dt_grad_imgs[j]
                ]
                if 'corner_heat_map' in gt:
                    vis_merge = [dt_heat_map_imgs[j], gt_heat_map_imgs[j]] + vis_merge
                img = np.concatenate(vis_merge, axis=-2)

                img = np.concatenate([img, ], axis=-1)
                imgs.append(img)
            if writer:
                writer.add_images('VIS/Merge', np.array(imgs), global_step)

            if config.SAVE_EVAL:
                for k in range(len(imgs)):
                    img = imgs[k] * 255.0
                    save_path = os.path.join(config.CKPT.RESULT_DIR, f"{gt['id'][k]}_{full_iou_2ds[k]:.5f}.png")
                    Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8)).save(save_path)

        elif writer:
            writer.add_images('IoU/Visible_Floorplan', visb_iou[2], global_step)
            writer.add_images('IoU/Full_Floorplan', full_iou[2], global_step)
            writer.add_images('IoU/Boundary', pano_bds, global_step)
            writer.add_images('Grad/gt', gt_grad_imgs, global_step)
            writer.add_images('Grad/dt', dt_grad_imgs, global_step)

    if config.LOCAL_RANK != 0:
        # Only rank 0 aggregates and reports.  A pair is still returned so
        # that callers unpacking two values do not fail on the other ranks.
        return None, None

    epoch_loss_d = {key: np.array(series).mean() for key, series in epoch_loss_d.items()}
    s = 'ValEpochLoss: '
    for key, val in epoch_loss_d.items():
        if writer:
            writer.add_scalar(f'ValEpochLoss/{key}', val, epoch)
        s += f" {key}={val}"
    logger.info(s)

    epoch_iou_d = {key: np.array(series).mean() for key, series in epoch_iou_d.items()}
    s = 'ValEpochIoU: '
    for key, val in epoch_iou_d.items():
        if writer:
            writer.add_scalar(f'ValEpochIoU/{key}', val, epoch)
        s += f" {key}={val}"
    logger.info(s)

    # Optional metrics that were never collected are reported as 0.
    epoch_other_d = {key: (np.array(series).mean() if len(series) > 0 else 0)
                     for key, series in epoch_other_d.items()}
    logger.info(f'other acc: {epoch_other_d}')
    return epoch_iou_d, epoch_other_d
# Run the training/evaluation entry point only when executed as a script.
if __name__ == '__main__':
    main()