"""
@Date: 2021/07/17
@description: Panorama Layout Transformer training and evaluation script.
"""
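# Example invocations (illustrative only; the script and config file names are assumptions,
# not paths guaranteed to exist in this repository):
#   python main.py --cfg configs/example.yaml --mode train
#   python main.py --cfg configs/example.yaml --mode test --post_processing manhattan --need_cpe --need_f1 --save_eval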
import sys
import os
import shutil
import argparse
import json

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torch.cuda

from PIL import Image
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from config.defaults import get_config, get_rank_config
from models.other.criterion import calc_criterion
from models.build import build_model
from models.other.init_env import init_env
from utils.logger import build_logger
from utils.misc import tensor2np_d, tensor2np
from dataset.build import build_loader
from evaluation.accuracy import (calc_accuracy, show_heat_map, calc_ce, calc_pe, calc_rmse_delta_1,
                                 show_depth_normal_grad, calc_f1_score)
from postprocessing.post_process import post_process

try:
    # NVIDIA apex is optional; amp stays None when it is not installed.
    from apex import amp
except ImportError:
    amp = None


def parse_option():
    debug = bool(sys.gettrace())
    parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')
    parser.add_argument('--cfg',
                        type=str,
                        metavar='FILE',
                        help='path to config file')

    parser.add_argument('--mode',
                        type=str,
                        default='train',
                        choices=['train', 'val', 'test'],
                        help='train/val/test mode')

    parser.add_argument('--val_name',
                        type=str,
                        choices=['val', 'test'],
                        help='val name')

    parser.add_argument('--bs', type=int,
                        help='batch size')

    parser.add_argument('--save_eval', action='store_true',
                        help='save eval result')

    parser.add_argument('--post_processing', type=str,
                        choices=['manhattan', 'atalanta', 'manhattan_old'],
                        help='type of postprocessing')

    parser.add_argument('--need_cpe', action='store_true',
                        help='need to evaluate corner error and pixel error')

    parser.add_argument('--need_f1', action='store_true',
                        help='need to evaluate f1-score of corners')

    parser.add_argument('--need_rmse', action='store_true',
                        help='need to evaluate root mean squared error and delta error')

    parser.add_argument('--force_cube', action='store_true',
                        help='force cube shape when eval')

    parser.add_argument('--wall_num', type=int,
                        help='wall number')

    args = parser.parse_args()
    args.debug = debug
    print("arguments:")
    for arg in vars(args):
        print(arg, ":", getattr(args, arg))
    print("-" * 50)
    return args


def main():
    args = parse_option()
    config = get_config(args)

    if config.TRAIN.SCRATCH and os.path.exists(config.CKPT.DIR) and config.MODE == 'train':
        print(f"Train from scratch, delete checkpoint dir: {config.CKPT.DIR}")
        epochs = [int(name.split('_')[-1].split('.')[0]) for name in os.listdir(config.CKPT.DIR) if 'pkl' in name]
        if len(epochs) > 0:
            last_epoch = np.array(epochs).max()
            if last_epoch > 10:
                c = input(f"delete it (last_epoch: {last_epoch})? (Y/N)\n")
                if c != 'y' and c != 'Y':
                    exit(0)

        shutil.rmtree(config.CKPT.DIR, ignore_errors=True)

    os.makedirs(config.CKPT.DIR, exist_ok=True)
    os.makedirs(config.CKPT.RESULT_DIR, exist_ok=True)
    os.makedirs(config.LOGGER.DIR, exist_ok=True)

    # Default to a single process; overridden when the device string lists GPU indices (e.g. 'cuda:0,1').
    nprocs = 1
    if ':' in config.TRAIN.DEVICE:
        nprocs = len(config.TRAIN.DEVICE.split(':')[-1].split(','))
    if 'cuda' in config.TRAIN.DEVICE:
        if not torch.cuda.is_available():
            print(f"CUDA is not available (config is: {config.TRAIN.DEVICE}), will use cpu ...")
            config.defrost()
            config.TRAIN.DEVICE = "cpu"
            config.freeze()
            nprocs = 1

    if config.MODE == 'train':
        with open(os.path.join(config.CKPT.DIR, "config.yaml"), "w") as f:
            f.write(config.dump(allow_unicode=True))

    if config.TRAIN.DEVICE == 'cpu' or nprocs < 2:
        print(f"Use single process, device: {config.TRAIN.DEVICE}")
        main_worker(0, config, 1)
    else:
        print(f"Use {nprocs} processes ...")
        mp.spawn(main_worker, nprocs=nprocs, args=(config, nprocs), join=True)


def main_worker(local_rank, cfg, world_size):
    config = get_rank_config(cfg, local_rank, world_size)
    logger = build_logger(config)
    writer = SummaryWriter(config.CKPT.DIR)
    logger.info(f"Comment: {config.COMMENT}")
    cur_pid = os.getpid()
    logger.info(f"Current process id: {cur_pid}")
    torch.hub._hub_dir = config.CKPT.PYTORCH
    logger.info(f"Pytorch hub dir: {torch.hub._hub_dir}")
    init_env(config.SEED, config.TRAIN.DETERMINISTIC, config.DATA.NUM_WORKERS)

    model, optimizer, criterion, scheduler = build_model(config, logger)
    train_data_loader, val_data_loader = build_loader(config, logger)

    if 'cuda' in config.TRAIN.DEVICE:
        torch.cuda.set_device(config.TRAIN.DEVICE)

    if config.MODE == 'train':
        train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler)
    else:
        iou_results, other_results = val_an_epoch(model, val_data_loader,
                                                  criterion, config, logger, writer=None,
                                                  epoch=config.TRAIN.START_EPOCH)
        results = dict(iou_results, **other_results)
        if config.SAVE_EVAL:
            save_path = os.path.join(config.CKPT.RESULT_DIR, "result.json")
            with open(save_path, 'w+') as f:
                json.dump(results, f, indent=4)


def save(model, optimizer, epoch, iou_d, logger, writer, config):
    model.save(optimizer, epoch, accuracy=iou_d['full_3d'], logger=logger, acc_d=iou_d, config=config)
    for k in model.acc_d:
        writer.add_scalar(f"BestACC/{k}", model.acc_d[k]['acc'], epoch)


def train(model, train_data_loader, val_data_loader, optimizer, criterion, config, logger, writer, scheduler):
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        logger.info("=" * 200)
        train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch)
        epoch_iou_d, _ = val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch)

        if config.LOCAL_RANK == 0:
            ddp = config.WORLD_SIZE > 1
            save(model.module if ddp else model, optimizer, epoch, epoch_iou_d, logger, writer, config)

        if scheduler is not None:
            # Stop stepping once the learning rate has reached the scheduler's floor.
            if scheduler.min_lr is not None and optimizer.param_groups[0]['lr'] <= scheduler.min_lr:
                continue
            scheduler.step()
    writer.close()


def train_an_epoch(model, train_data_loader, optimizer, criterion, config, logger, writer, epoch=0):
    logger.info(f'Start Train Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
    model.train()

    if len(config.MODEL.FINE_TUNE) > 0:
        # Keep the pre-trained feature extractor in eval mode while fine-tuning the rest of the model.
        model.feature_extractor.eval()

    optimizer.zero_grad()

    data_len = len(train_data_loader)
    start_i = data_len * epoch * config.WORLD_SIZE
    bar = enumerate(train_data_loader)
    if config.LOCAL_RANK == 0 and config.SHOW_BAR:
        bar = tqdm(bar, total=data_len, ncols=200)

    device = config.TRAIN.DEVICE
    epoch_loss_d = {}
    for i, gt in bar:
        imgs = gt['image'].to(device, non_blocking=True)
        gt['depth'] = gt['depth'].to(device, non_blocking=True)
        gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
        if 'corner_heat_map' in gt:
            gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
            imgs = imgs.type(torch.float16)
            gt['depth'] = gt['depth'].type(torch.float16)
            gt['ratio'] = gt['ratio'].type(torch.float16)
        dt = model(imgs)
        loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)
        if config.LOCAL_RANK == 0 and config.SHOW_BAR:
            bar.set_postfix(batch_loss_d)

        optimizer.zero_grad()
        if config.AMP_OPT_LEVEL != "O0" and 'cuda' in device:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        # One global step per batch per rank, so curves from different ranks do not collide.
        global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK
        for key, val in batch_loss_d.items():
            writer.add_scalar(f'TrainBatchLoss/{key}', val, global_step)

    if config.LOCAL_RANK != 0:
        return

    epoch_loss_d = {k: np.array(v).mean() for k, v in epoch_loss_d.items()}
    s = 'TrainEpochLoss: '
    for key, val in epoch_loss_d.items():
        writer.add_scalar(f'TrainEpochLoss/{key}', val, epoch)
        s += f" {key}={val}"
    logger.info(s)
    writer.add_scalar('LearningRate', optimizer.param_groups[0]['lr'], epoch)
    logger.info(f"LearningRate: {optimizer.param_groups[0]['lr']}")


@torch.no_grad()
def val_an_epoch(model, val_data_loader, criterion, config, logger, writer, epoch=0):
    model.eval()
    logger.info(f'Start Validate Epoch {epoch}/{config.TRAIN.EPOCHS - 1}')
    data_len = len(val_data_loader)
    start_i = data_len * epoch * config.WORLD_SIZE
    bar = enumerate(val_data_loader)
    if config.LOCAL_RANK == 0 and config.SHOW_BAR:
        bar = tqdm(bar, total=data_len, ncols=200)
    device = config.TRAIN.DEVICE
    epoch_loss_d = {}
    epoch_iou_d = {
        'visible_2d': [],
        'visible_3d': [],
        'full_2d': [],
        'full_3d': [],
        'height': []
    }

    epoch_other_d = {
        'ce': [],
        'pe': [],
        'f1': [],
        'precision': [],
        'recall': [],
        'rmse': [],
        'delta_1': []
    }

    show_index = np.random.randint(0, data_len)
    for i, gt in bar:
        imgs = gt['image'].to(device, non_blocking=True)
        gt['depth'] = gt['depth'].to(device, non_blocking=True)
        gt['ratio'] = gt['ratio'].to(device, non_blocking=True)
        if 'corner_heat_map' in gt:
            gt['corner_heat_map'] = gt['corner_heat_map'].to(device, non_blocking=True)
        dt = model(imgs)

        vis_w = config.TRAIN.VIS_WEIGHT
        # Qualitative visualization is disabled here; set this to True to build and log the merged views below.
        visualization = False

        loss, batch_loss_d, epoch_loss_d = calc_criterion(criterion, gt, dt, epoch_loss_d)

        if config.EVAL.POST_PROCESSING is not None:
            depth = tensor2np(dt['depth'])
            dt['processed_xyz'] = post_process(depth, type_name=config.EVAL.POST_PROCESSING,
                                               need_cube=config.EVAL.FORCE_CUBE)

        if config.EVAL.FORCE_CUBE and config.EVAL.NEED_CPE:
            ce = calc_ce(tensor2np_d(dt), tensor2np_d(gt))
            pe = calc_pe(tensor2np_d(dt), tensor2np_d(gt))

            epoch_other_d['ce'].append(ce)
            epoch_other_d['pe'].append(pe)

        if config.EVAL.NEED_F1:
            f1, precision, recall = calc_f1_score(tensor2np_d(dt), tensor2np_d(gt))
            epoch_other_d['f1'].append(f1)
            epoch_other_d['precision'].append(precision)
            epoch_other_d['recall'].append(recall)

        if config.EVAL.NEED_RMSE:
            rmse, delta_1 = calc_rmse_delta_1(tensor2np_d(dt), tensor2np_d(gt))
            epoch_other_d['rmse'].append(rmse)
            epoch_other_d['delta_1'].append(delta_1)

        visb_iou, full_iou, iou_height, pano_bds, full_iou_2ds = calc_accuracy(tensor2np_d(dt), tensor2np_d(gt),
                                                                               visualization, h=vis_w // 2)
        epoch_iou_d['visible_2d'].append(visb_iou[0])
        epoch_iou_d['visible_3d'].append(visb_iou[1])
        epoch_iou_d['full_2d'].append(full_iou[0])
        epoch_iou_d['full_3d'].append(full_iou[1])
        epoch_iou_d['height'].append(iou_height)

        if config.LOCAL_RANK == 0 and config.SHOW_BAR:
            bar.set_postfix(batch_loss_d)

        global_step = start_i + i * config.WORLD_SIZE + config.LOCAL_RANK

        if writer:
            for key, val in batch_loss_d.items():
                writer.add_scalar(f'ValBatchLoss/{key}', val, global_step)

        if not visualization:
            continue

        gt_grad_imgs, dt_grad_imgs = show_depth_normal_grad(dt, gt, device, vis_w)

        dt_heat_map_imgs = None
        gt_heat_map_imgs = None
        if 'corner_heat_map' in gt:
            dt_heat_map_imgs, gt_heat_map_imgs = show_heat_map(dt, gt, vis_w)

        if config.TRAIN.VIS_MERGE or config.SAVE_EVAL:
            imgs = []
            for j in range(len(pano_bds)):
                # Note: the cropped floorplan is computed here but not currently part of vis_merge.
                floorplan = full_iou[2][j]
                margin_w = int(floorplan.shape[-1] * (60 / 512))
                floorplan = floorplan[:, :, margin_w:-margin_w]

                grad_h = dt_grad_imgs[0].shape[1]
                vis_merge = [
                    gt_grad_imgs[j],
                    pano_bds[j][:, grad_h:-grad_h],
                    dt_grad_imgs[j]
                ]
                if 'corner_heat_map' in gt:
                    vis_merge = [dt_heat_map_imgs[j], gt_heat_map_imgs[j]] + vis_merge
                img = np.concatenate(vis_merge, axis=-2)

                imgs.append(img)
            if writer:
                writer.add_images('VIS/Merge', np.array(imgs), global_step)

            if config.SAVE_EVAL:
                for k in range(len(imgs)):
                    img = imgs[k] * 255.0
                    save_path = os.path.join(config.CKPT.RESULT_DIR, f"{gt['id'][k]}_{full_iou_2ds[k]:.5f}.png")
                    Image.fromarray(img.transpose(1, 2, 0).astype(np.uint8)).save(save_path)

        elif writer:
            writer.add_images('IoU/Visible_Floorplan', visb_iou[2], global_step)
            writer.add_images('IoU/Full_Floorplan', full_iou[2], global_step)
            writer.add_images('IoU/Boundary', pano_bds, global_step)
            writer.add_images('Grad/gt', gt_grad_imgs, global_step)
            writer.add_images('Grad/dt', dt_grad_imgs, global_step)

    if config.LOCAL_RANK != 0:
        return

    epoch_loss_d = {k: np.array(v).mean() for k, v in epoch_loss_d.items()}
    s = 'ValEpochLoss: '
    for key, val in epoch_loss_d.items():
        if writer:
            writer.add_scalar(f'ValEpochLoss/{key}', val, epoch)
        s += f" {key}={val}"
    logger.info(s)

    epoch_iou_d = {k: np.array(v).mean() for k, v in epoch_iou_d.items()}
    s = 'ValEpochIoU: '
    for key, val in epoch_iou_d.items():
        if writer:
            writer.add_scalar(f'ValEpochIoU/{key}', val, epoch)
        s += f" {key}={val}"
    logger.info(s)
    epoch_other_d = {k: np.array(v).mean() if len(v) > 0 else 0 for k, v in epoch_other_d.items()}

    logger.info(f'other acc: {epoch_other_d}')
    return epoch_iou_d, epoch_other_d


if __name__ == '__main__':
    main()