|
"""
@Date: 2021/09/19

@description: Inference script: loads a trained panorama layout model, predicts
room layouts for images matched by a glob (or for a dataset), and writes 2D
visualizations, layout JSON files and optional 3D meshes to the output directory.
"""
|
import json |
|
import os |
|
import argparse |
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import glob |
|
|
|
from tqdm import tqdm |
|
from PIL import Image |
|
from config.defaults import merge_from_file, get_config |
|
from dataset.mp3d_dataset import MP3DDataset |
|
from dataset.zind_dataset import ZindDataset |
|
from models.build import build_model |
|
from loss import GradLoss |
|
from postprocessing.post_process import post_process |
|
from preprocessing.pano_lsd_align import panoEdgeDetection, rotatePanorama |
|
from utils.boundary import corners2boundaries, layout2depth |
|
from utils.conversion import depth2xyz |
|
from utils.logger import get_logger |
|
from utils.misc import tensor2np_d, tensor2np |
|
from evaluation.accuracy import show_grad |
|
from models.lgt_net import LGT_Net |
|
from utils.writer import xyz2json |
|
from visualization.boundary import draw_boundaries |
|
from visualization.floorplan import draw_floorplan, draw_iou_floorplan |
|
from visualization.obj3d import create_3d_obj |
|
|
|
|
|
def parse_option():
    """Parse command-line arguments for inference.

    Returns:
        argparse.Namespace with the parsed options; ``mode`` is forced to
        ``'test'`` since this script only performs inference.
    """
    parser = argparse.ArgumentParser(description='Panorama Layout Transformer training and evaluation script')

    parser.add_argument('--img_glob', type=str, required=True,
                        help='image glob path')
    parser.add_argument('--cfg', type=str, required=True, metavar='FILE',
                        help='path of config file')
    parser.add_argument('--post_processing', type=str, default='manhattan',
                        choices=['manhattan', 'atalanta', 'original'],
                        help='post-processing type')
    parser.add_argument('--output_dir', type=str, default='src/output',
                        help='path of output')
    parser.add_argument('--visualize_3d', action='store_true',
                        help='visualize_3d')
    parser.add_argument('--output_3d', action='store_true',
                        help='output_3d')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device')

    args = parser.parse_args()
    # This script is inference-only; downstream config code keys off the mode.
    args.mode = 'test'

    # Echo all parsed options for reproducibility of a run.
    print("arguments:")
    for key, value in vars(args).items():
        print(key, ":", value)
    print("-" * 50)
    return args
|
|
|
|
|
def visualize_2d(img, dt, show_depth=True, show_floorplan=True, show=False, save_path=None):
    """Build a 2D visualization of a layout prediction.

    Draws the predicted boundaries on the panorama (green), overlays the
    post-processed boundaries when present (red), and optionally appends a
    depth-gradient strip below and a floorplan panel to the right.

    Args:
        img: panorama image as a float array in [0, 1], shape (H, W, 3).
        dt: model output dict with 'depth' and 'ratio' tensors; may also
            contain 'processed_xyz' from post-processing.
        show_depth: append the depth-gradient strip at the bottom.
        show_floorplan: append the floorplan panel on the right.
        show: display the result with matplotlib.
        save_path: optional path; when given the image is saved as 8-bit PNG.

    Returns:
        The composed visualization as a float array in [0, 1].
    """
    pred = tensor2np_d(dt)
    depth_map = pred['depth'][0]
    xyz = depth2xyz(np.abs(depth_map))
    ratio = pred['ratio'][0][0]

    boundaries = corners2boundaries(ratio, corners_xyz=xyz, step=None, visible=False, length=img.shape[1])
    canvas = draw_boundaries(img, boundary_list=boundaries, boundary_color=[0, 1, 0])

    if 'processed_xyz' in dt:
        # Post-processed layout drawn in red on top of the raw prediction.
        refined = corners2boundaries(ratio, corners_xyz=dt['processed_xyz'][0], step=None,
                                     visible=False, length=img.shape[1])
        canvas = draw_boundaries(canvas, boundary_list=refined, boundary_color=[1, 0, 0])

    if show_depth:
        grad_strip = show_depth_normal_grad(dt)
        strip_h = grad_strip.shape[0]
        # Replace the bottom strip_h rows of the panorama with the gradient strip.
        canvas = np.concatenate([canvas[0:-strip_h, :, :], grad_strip], axis=0)

    if show_floorplan:
        if 'processed_xyz' in dt:
            floorplan = draw_iou_floorplan(dt['processed_xyz'][0][..., ::2], xyz[..., ::2],
                                           dt_board_color=[1, 0, 0, 1], gt_board_color=[0, 1, 0, 1])
        else:
            floorplan = show_alpha_floorplan(xyz, border_color=[0, 1, 0, 1])

        # Trim 60 px of margin from each side of the floorplan before pasting.
        canvas = np.concatenate([canvas, floorplan[:, 60:-60, :]], axis=1)
    if show:
        plt.imshow(canvas)
        plt.show()
    if save_path:
        Image.fromarray((canvas * 255).astype(np.uint8)).save(save_path)
    return canvas
|
|
|
|
|
def preprocess(img_ori, q_error=0.7, refine_iter=3, vp_cache_path=None):
    """Align a panorama to its Manhattan vanishing points.

    Args:
        img_ori: input equirectangular panorama image array.
        q_error: quantization error passed to panoEdgeDetection.
        refine_iter: refinement iterations passed to panoEdgeDetection.
        vp_cache_path: optional path of a text cache. When the file exists the
            three vanishing points are read from it instead of being
            re-estimated; the points are (re-)written to it afterwards.

    Returns:
        (i_img, vp): the rotated panorama and the 3x3 vanishing-point array.
    """
    # BUG FIX: the original passed vp_cache_path (default None) straight to
    # os.path.exists, which raises TypeError for None; guard the None case.
    if vp_cache_path is not None and os.path.exists(vp_cache_path):
        with open(vp_cache_path) as f:
            # One vanishing point per line: three space-separated floats.
            vp = [[float(v) for v in line.rstrip().split(' ')] for line in f.readlines()]
            vp = np.array(vp)
    else:
        # Estimate vanishing points from the panorama's detected line segments.
        _, vp, _, _, _, _, _ = panoEdgeDetection(img_ori,
                                                 qError=q_error,
                                                 refineIter=refine_iter)
    # Rotate the panorama so its axes align with the (reversed) vanishing points.
    i_img = rotatePanorama(img_ori, vp[2::-1])

    if vp_cache_path is not None:
        with open(vp_cache_path, 'w') as f:
            for i in range(3):
                f.write('%.6f %.6f %.6f\n' % (vp[i, 0], vp[i, 1], vp[i, 2]))

    return i_img, vp
|
|
|
|
|
def show_depth_normal_grad(dt):
    """Render the gradient of the predicted depth as a 1024x60 image strip."""
    device = dt['depth'].device
    grad_conv = GradLoss().to(device).grad_conv
    strip = show_grad(dt['depth'][0], grad_conv, 50)
    # Nearest-neighbour keeps the discrete gradient colours crisp when scaling.
    return cv2.resize(strip, (1024, 60), interpolation=cv2.INTER_NEAREST)
|
|
|
|
|
def show_alpha_floorplan(dt_xyz, side_l=512, border_color=None):
    """Draw a semi-transparent floorplan composited over a grey background.

    Args:
        dt_xyz: predicted corner coordinates; only the x/z columns are used.
        side_l: side length in pixels of the square floorplan image.
        border_color: RGBA border colour; defaults to red.

    Returns:
        RGB float array in [0, 1] of shape (side_l, side_l, 3).
    """
    if border_color is None:
        border_color = [1, 0, 0, 1]
    fill_color = [0.2, 0.2, 0.2, 0.2]
    plan = draw_floorplan(xz=dt_xyz[..., ::2], fill_color=fill_color,
                          border_color=border_color, side_l=side_l, show=False, center_color=[1, 0, 0, 1])
    plan_img = Image.fromarray((plan * 255).astype(np.uint8), mode='RGBA')
    # Opaque light-grey backdrop so the translucent floorplan stays readable.
    backdrop = np.zeros([side_l, side_l, len(fill_color)], dtype=np.float32)
    backdrop[..., :] = [0.8, 0.8, 0.8, 1]
    backdrop_img = Image.fromarray((backdrop * 255).astype(np.uint8), mode='RGBA')
    composed = Image.alpha_composite(backdrop_img, plan_img).convert("RGB")
    return np.array(composed) / 255.0
|
|
|
|
|
def save_pred_json(xyz, ration, save_path):
    """Serialize the predicted layout corners to a JSON file.

    Args:
        xyz: predicted corner coordinates.
        ration: predicted ceiling/floor height ratio (name kept for
            backward compatibility with existing callers).
        save_path: destination path of the JSON file.

    Returns:
        The JSON-serializable layout dict that was written.
    """
    layout = xyz2json(xyz, ration)
    with open(save_path, 'w') as fp:
        # print(...) appends the trailing newline, matching `dumps(...) + '\n'`.
        print(json.dumps(layout, indent=4), file=fp)
    return layout
|
|
|
|
|
def inference():
    """Run layout inference for every image matched by ``args.img_glob``.

    Relies on the module-level globals ``img_paths``, ``logger``, ``args`` and
    ``model`` that are set up under ``__main__``.
    """
    if len(img_paths) == 0:
        logger.error('No images found')
        return

    bar = tqdm(img_paths, ncols=100)
    for img_path in bar:
        if not os.path.isfile(img_path):
            logger.error(f'The {img_path} is not a file')
            continue
        name = os.path.basename(img_path).split('.')[0]
        bar.set_description(name)
        # The model expects a 1024x512 equirectangular RGB image; [..., :3]
        # drops an alpha channel if present.
        img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]
        if args.post_processing is not None and 'manhattan' in args.post_processing:
            # Manhattan post-processing requires the panorama to be aligned
            # to its vanishing points first; the VPs are cached per image.
            bar.set_description("Preprocessing")
            img, vp = preprocess(img, vp_cache_path=os.path.join(args.output_dir, f"{name}_vp.txt"))

        img = (img / 255.0).astype(np.float32)
        # BUG FIX: run_one_inference takes a `logger` parameter; the original
        # call omitted it, raising TypeError on every image.
        run_one_inference(img, model, args, name, logger=logger)
|
|
|
|
|
def inference_dataset(dataset):
    """Run layout inference for each sample of a dataset.

    Each sample is expected to provide an 'id' and a CHW 'image'; uses the
    module-level globals ``model``, ``args`` and ``logger``.
    """
    progress = tqdm(dataset, ncols=100)
    for sample in progress:
        progress.set_description(sample['id'])
        # Model input is HWC, datasets yield CHW — transpose before inference.
        run_one_inference(sample['image'].transpose(1, 2, 0), model, args,
                          name=sample['id'], logger=logger)
|
|
|
|
|
@torch.no_grad()
def run_one_inference(img, model, args, name, logger=None, show=True, show_depth=True,
                      show_floorplan=True, mesh_format='.gltf', mesh_resolution=512):
    """Run the model on one panorama and write all requested outputs.

    Produces a 2D visualization PNG and a layout JSON in ``args.output_dir``,
    and optionally shows/saves a 3D mesh.

    Args:
        img: panorama as an HWC float array in [0, 1].
        model: the layout network (set to eval mode here).
        args: parsed options (device, output_dir, post_processing, 3D flags).
        name: base name used for all output files.
        logger: optional logger; falls back to the default logger when None
            (keeps callers that omit it from crashing).
        show / show_depth / show_floorplan: 2D visualization toggles.
        mesh_format: file extension of the exported 3D mesh.
        mesh_resolution: boundary sampling length for the 3D mesh.
    """
    if logger is None:
        # Backward-compatible fallback: logger used to be required positional.
        logger = get_logger()
    model.eval()
    logger.info("model inference...")
    # HWC -> NCHW tensor on the inference device.
    dt = model(torch.from_numpy(img.transpose(2, 0, 1)[None]).to(args.device))
    if args.post_processing != 'original':
        logger.info(f"post-processing, type:{args.post_processing}...")
        dt['processed_xyz'] = post_process(tensor2np(dt['depth']), type_name=args.post_processing)

    visualize_2d(img, dt,
                 show_depth=show_depth,
                 show_floorplan=show_floorplan,
                 show=show,
                 save_path=os.path.join(args.output_dir, f"{name}_pred.png"))
    # Prefer the post-processed layout when available.
    output_xyz = dt['processed_xyz'][0] if 'processed_xyz' in dt else depth2xyz(tensor2np(dt['depth'][0]))

    logger.info("saving predicted layout json...")
    save_pred_json(output_xyz, tensor2np(dt['ratio'][0])[0],
                   save_path=os.path.join(args.output_dir, f"{name}_pred.json"))

    if args.visualize_3d or args.output_3d:
        # Post-processed corners are sparse, so sample visible boundaries at
        # mesh_resolution; raw depth output already has dense boundaries.
        dt_boundaries = corners2boundaries(tensor2np(dt['ratio'][0])[0], corners_xyz=output_xyz, step=None,
                                           length=mesh_resolution if 'processed_xyz' in dt else None,
                                           visible='processed_xyz' in dt)
        dt_layout_depth = layout2depth(dt_boundaries, show=False)

        logger.info("creating 3d mesh ...")
        create_3d_obj(cv2.resize(img, dt_layout_depth.shape[::-1]), dt_layout_depth,
                      save_path=os.path.join(args.output_dir, f"{name}_3d{mesh_format}") if args.output_3d else None,
                      mesh=True, show=args.visualize_3d)
|
|
|
|
|
if __name__ == '__main__':
    # The names bound below (logger, args, config, model, img_paths) are
    # module globals consumed by inference()/run_one_inference() above.
    logger = get_logger()
    args = parse_option()
    config = get_config(args)

    # Fall back to CPU when CUDA was requested but is unavailable; the yacs
    # config must be defrosted before mutation and re-frozen afterwards.
    if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
        logger.info(f'The {args.device} is not available, will use cpu ...')
        config.defrost()
        args.device = "cpu"
        config.TRAIN.DEVICE = "cpu"
        config.freeze()

    # build_model returns (model, optimizer, criterion, scheduler); only the
    # model is needed for inference.
    model, _, _, _ = build_model(config, logger)
    os.makedirs(args.output_dir, exist_ok=True)
    # Sorted for a deterministic processing order.
    img_paths = sorted(glob.glob(args.img_glob))

    inference()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|