MMOCR / tools /kie_test_imgs.py
tomofi's picture
Add application file
2366e36
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import ast
import os
import os.path as osp
import mmcv
import numpy as np
import torch
from mmcv import Config
from mmcv.image import tensor2imgs
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmocr.datasets import build_dataloader, build_dataset
from mmocr.models import build_detector
def save_results(model, img_meta, gt_bboxes, result, out_dir):
    """Dump the KIE predictions of one image to ``<out_dir>/<name>.json``.

    The JSON file holds a list with one entry per text instance, containing
    its text, box, predicted class (label if known, else the raw index) and
    the maximum class score.

    Args:
        model (MMDataParallel): Wrapped KIE model. ``model.module.class_list``
            is either None or a path to a file whose non-empty lines are
            ``<class_idx> <class_label>``.
        img_meta (dict): Image meta; must contain ``filename`` and
            ``ori_texts``.
        gt_bboxes (list): Boxes of the text instances, aligned with
            ``img_meta['ori_texts']``.
        result (dict): Model output for this image. ``result['nodes']`` is
            iterated alongside the texts; each element is presumably a 1-D
            class-score tensor for one instance (TODO confirm).
        out_dir (str): Existing directory the JSON file is written to.

    Raises:
        AssertionError: If ``filename`` or ``ori_texts`` is missing from
            ``img_meta``.
    """
    assert 'filename' in img_meta, ('Please add "filename" '
                                    'to "meta_keys" in config.')
    assert 'ori_texts' in img_meta, ('Please add "ori_texts" '
                                     'to "meta_keys" in config.')

    out_json_file = osp.join(out_dir,
                             osp.basename(img_meta['filename']) + '.json')

    idx_to_cls = {}
    if model.module.class_list is not None:
        for line in mmcv.list_from_file(model.module.class_list):
            line = line.strip()
            if not line:
                # Tolerate blank lines such as a trailing empty line at EOF.
                continue
            # maxsplit=1 keeps class labels that themselves contain spaces.
            class_idx, class_label = line.split(maxsplit=1)
            idx_to_cls[int(class_idx)] = class_label

    json_result = []
    # NOTE(review): zip() silently stops at the shortest of texts / boxes /
    # nodes; these are assumed to have equal length.
    for text, box, pred in zip(img_meta['ori_texts'], gt_bboxes,
                               result['nodes']):
        # One reduction yields both the top score and its class index.
        max_score, max_idx = pred.max(-1)
        cls_idx = max_idx.item()
        json_result.append({
            'text': text,
            'box': box,
            # Fall back to the raw index when it has no label in the list.
            'pred': idx_to_cls.get(cls_idx, cls_idx),
            'conf': max_score.item(),
        })

    mmcv.dump(json_result, out_json_file)
def test(model, data_loader, show=False, out_dir=None):
    """Run KIE inference over a data loader and optionally visualize it.

    Args:
        model (MMDataParallel): Wrapped KIE model. ``model.module`` must
            provide ``show_result`` (and ``class_list``, read by
            ``save_results``, when ``out_dir`` is set).
        data_loader: Test data loader. Each batch is passed to the model as
            keyword arguments; when visualizing it must also carry ``img``,
            ``img_metas`` and ``gt_bboxes`` DataContainers.
        show (bool): Whether to display each rendered image.
        out_dir (str | None): If given, rendered images and per-image JSON
            predictions are written to this directory.

    Returns:
        list: Per-image inference results in data-loader order.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        batch_size = len(result)
        results.extend(result)

        if show or out_dir:
            img_tensor = data['img'].data[0]
            img_metas = data['img_metas'].data[0]
            if np.prod(img_tensor.shape) == 0:
                # The image tensor has no elements, so read the source
                # images from disk for visualization instead.
                imgs = [mmcv.imread(m['filename']) for m in img_metas]
            else:
                imgs = tensor2imgs(img_tensor,
                                   **img_metas[0]['img_norm_cfg'])
            assert len(imgs) == len(img_metas)
            # Build one box list per image so that indexing by img_idx
            # below is valid for any batch size, not only for batch size 1.
            gt_bboxes = [
                bboxes.numpy().tolist()
                for bboxes in data['gt_bboxes'].data[0]
            ]

            for img_idx, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                if 'img_shape' in img_meta:
                    # Keep only the img_shape region (presumably dropping
                    # batch padding — TODO confirm).
                    h, w, _ = img_meta['img_shape']
                    img_show = img[:h, :w, :]
                else:
                    img_show = img

                out_file = (
                    osp.join(out_dir, osp.basename(img_meta['filename']))
                    if out_dir else None)

                model.module.show_result(
                    img_show,
                    result[img_idx],
                    gt_bboxes[img_idx],
                    show=show,
                    out_file=out_file)

                if out_dir:
                    save_results(model, img_meta, gt_bboxes[img_idx],
                                 result[img_idx], out_dir)

        for _ in range(batch_size):
            prog_bar.update()
    return results
def parse_args():
    """Parse the command-line arguments of the KIE visualization script.

    As a side effect, exports ``--local_rank`` as the ``LOCAL_RANK``
    environment variable unless that variable is already set.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='MMOCR visualize for kie model.')
    parser.add_argument('config', help='Test config file path.')
    parser.add_argument('checkpoint', help='Checkpoint file.')
    parser.add_argument('--show', action='store_true', help='Show results.')
    parser.add_argument(
        '--out-dir',
        help='Directory where the output images and results will be saved.')
    parser.add_argument('--local_rank', type=int, default=0)
    # main() passes this (or None) as MMDataParallel's device_ids, so when it
    # is omitted torch DataParallel's own default applies.
    parser.add_argument(
        '--device',
        help='GPU id(s) to run on, as an int or comma separated ints, '
        'e.g. "0" or "0,1". Defaults to all visible GPUs, or CPU if '
        'CUDA is unavailable.',
        default=None)
    args = parser.parse_args()
    # Do not override a LOCAL_RANK that a launcher has already provided.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
def main():
    """Build the KIE test dataset and model from CLI args, then visualize.

    Raises:
        AssertionError: If neither ``--show`` nor ``--out-dir`` is given.
    """
    args = parse_args()
    assert args.show or args.out_dir, ('Please specify at least one '
                                       'operation (show the results / save '
                                       'the results) with the argument '
                                       '"--show" or "--out-dir".')

    device = args.device
    if device is not None:
        # "0" -> [0], "0,1" -> [0, 1]. literal_eval only evaluates Python
        # literals, so no arbitrary code from --device is executed.
        device = ast.literal_eval(f'[{device}]')

    cfg = Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    distributed = False

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    # device_ids=None defers device selection to DataParallel's default.
    model = MMDataParallel(model, device_ids=device)

    if args.out_dir:
        # save_results() dumps JSON files into out_dir; make sure it exists.
        mmcv.mkdir_or_exist(osp.abspath(args.out_dir))

    test(model, data_loader, args.show, args.out_dir)
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()