Spaces:
Runtime error
Runtime error
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import ast
import os
import os.path as osp

import mmcv
import numpy as np
import torch
from mmcv import Config
from mmcv.image import tensor2imgs
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint

from mmocr.datasets import build_dataloader, build_dataset
from mmocr.models import build_detector
def save_results(model, img_meta, gt_bboxes, result, out_dir):
    """Dump the per-node predictions of one image to a JSON file.

    The file is written to ``<out_dir>/<basename of img_meta['filename']>.json``
    and holds one entry per node with the keys ``text``, ``box``, ``pred``
    and ``conf``.

    Args:
        model (MMDataParallel): Wrapped KIE model. ``model.module.class_list``
            is either None or a path to a class-list file whose lines are
            ``<int index> <label>``.
        img_meta (dict): Image meta info. Must contain ``filename`` and
            ``ori_texts`` (both have to be listed in ``meta_keys``).
        gt_bboxes (list): Boxes of the image, zipped one-to-one with
            ``img_meta['ori_texts']`` and ``result['nodes']``.
        result (dict): Model output for the image. Every item of
            ``result['nodes']`` must support ``argmax(-1)`` / ``max(-1)``
            (presumably a per-node class-score tensor -- confirm).
        out_dir (str): Directory the JSON file is written into.
    """
    assert 'filename' in img_meta, ('Please add "filename" '
                                    'to "meta_keys" in config.')
    assert 'ori_texts' in img_meta, ('Please add "ori_texts" '
                                     'to "meta_keys" in config.')

    out_json_file = osp.join(out_dir,
                             osp.basename(img_meta['filename']) + '.json')

    idx_to_cls = {}
    if model.module.class_list is not None:
        for line in mmcv.list_from_file(model.module.class_list):
            # maxsplit=1 lets a label contain spaces.
            fields = line.strip().split(maxsplit=1)
            if not fields:
                # Blank line (e.g. a trailing newline): nothing to map, and
                # unpacking it below would raise ValueError.
                continue
            class_idx, class_label = fields
            idx_to_cls[int(class_idx)] = class_label

    json_result = []
    for text, box, pred in zip(img_meta['ori_texts'], gt_bboxes,
                               result['nodes']):
        pred_idx = pred.argmax(-1).cpu().item()
        json_result.append({
            'text': text,
            'box': box,
            # Fall back to the raw index when it has no class-list entry.
            'pred': idx_to_cls.get(pred_idx, pred_idx),
            'conf': pred.max(-1)[0].cpu().item()
        })
    mmcv.dump(json_result, out_json_file)
def test(model, data_loader, show=False, out_dir=None):
    """Run inference over ``data_loader`` and visualize / save the results.

    Args:
        model (MMDataParallel): Wrapped KIE model; ``model.module`` must
            provide ``show_result``.
        data_loader: Test data loader whose batches hold ``img``,
            ``img_metas`` and ``gt_bboxes`` data containers.
        show (bool): Passed to ``show_result`` to display each image.
        out_dir (str | None): When set, each visualization is written to
            ``<out_dir>/<image basename>`` and its predictions are dumped
            through ``save_results``.

    Returns:
        list: The model output for every image, in data-loader order.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        # Collect per-image outputs so the returned list is meaningful.
        results.extend(result)

        batch_size = len(result)
        if show or out_dir:
            img_tensor = data['img'].data[0]
            img_metas = data['img_metas'].data[0]
            if np.prod(img_tensor.shape) == 0:
                # Empty image tensor in the batch: reload originals from disk
                # instead of de-normalizing.
                imgs = [mmcv.imread(m['filename']) for m in img_metas]
            else:
                imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)
            # One box list per image in the batch; keeping only element [0]
            # would raise IndexError for batches with more than one image.
            gt_bboxes = [
                bboxes.numpy().tolist() for bboxes in data['gt_bboxes'].data[0]
            ]

            for idx, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                if 'img_shape' in img_meta:
                    # Crop to the recorded image shape.
                    h, w, _ = img_meta['img_shape']
                    img_show = img[:h, :w, :]
                else:
                    img_show = img

                out_file = None
                if out_dir:
                    out_file = osp.join(out_dir,
                                        osp.basename(img_meta['filename']))

                model.module.show_result(
                    img_show,
                    result[idx],
                    gt_bboxes[idx],
                    show=show,
                    out_file=out_file)

                if out_dir:
                    save_results(model, img_meta, gt_bboxes[idx], result[idx],
                                 out_dir)

        for _ in range(batch_size):
            prog_bar.update()
    return results
def parse_args():
    """Build the CLI parser for KIE visualization and parse ``sys.argv``.

    Side effect: if ``LOCAL_RANK`` is absent from the environment, it is set
    to the parsed ``--local_rank`` value (as a string).

    Returns:
        argparse.Namespace: Parsed options ``config``, ``checkpoint``,
        ``show``, ``out_dir``, ``local_rank`` and ``device``.
    """
    # (flags, keyword arguments) for every option, in registration order.
    option_table = (
        (('config', ), dict(help='Test config file path.')),
        (('checkpoint', ), dict(help='Checkpoint file.')),
        (('--show', ), dict(action='store_true', help='Show results.')),
        (('--out-dir', ), dict(
            help='Directory where the output images and results will be saved.')),
        (('--local_rank', ), dict(type=int, default=0)),
        (('--device', ), dict(
            help='Use int or int list for gpu. Default is cpu', default=None)),
    )
    cli = argparse.ArgumentParser(description='MMOCR visualize for kie model.')
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    namespace = cli.parse_args()
    # Keep any LOCAL_RANK already present; only fill it in when missing.
    os.environ.setdefault('LOCAL_RANK', str(namespace.local_rank))
    return namespace
def main():
    """Command-line entry point of the KIE visualization script.

    Builds the test dataset and the model from the given config, loads the
    checkpoint, wraps the model in ``MMDataParallel`` and runs ``test``.

    Raises:
        ValueError: If neither ``--show`` nor ``--out-dir`` is given.
    """
    args = parse_args()
    # Explicit check rather than ``assert`` so it still runs under -O.
    if not (args.show or args.out_dir):
        raise ValueError('Please specify at least one operation (show the '
                         'results / save the results) with the argument '
                         '"--show" or "--out-dir".')

    # "--device 0" -> [0], "--device 0,1" -> [0, 1]; None is passed through
    # to MMDataParallel unchanged.
    # NOTE(review): with device_ids=None, torch's DataParallel uses all
    # visible GPUs when CUDA is available, so the "Default is cpu" help text
    # only holds on CPU-only machines -- confirm the intended default.
    device = args.device
    if device is not None:
        device = ast.literal_eval(f'[{device}]')

    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    distributed = False

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if device and torch.cuda.is_available():
        # DataParallel moves the module itself only when given exactly one
        # device id; with several ids it requires the module to already be
        # on device_ids[0], otherwise the forward pass raises.
        model = model.cuda(device[0])
    model = MMDataParallel(model, device_ids=device)

    if args.out_dir:
        mmcv.mkdir_or_exist(args.out_dir)
    test(model, data_loader, args.show, args.out_dir)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()