Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import mmcv | |
import numpy as np | |
from mmdet.core import INSTANCE_OFFSET, bbox2result | |
from mmdet.core.visualization import imshow_det_bboxes | |
from ..builder import DETECTORS, build_backbone, build_head, build_neck | |
from .single_stage import SingleStageDetector | |
# NOTE(review): `DETECTORS` is imported at the top of this file, but no
# registration decorator is visible on this class -- confirm that the
# detector is registered elsewhere if it is built from a config `type`.
class MaskFormer(SingleStageDetector):
    r"""Implementation of `Per-Pixel Classification is
    NOT All You Need for Semantic Segmentation
    <https://arxiv.org/pdf/2107.06278>`_.

    Args:
        backbone (dict): Config passed to ``build_backbone``.
        neck (dict, optional): Config passed to ``build_neck``. ``self.neck``
            is only created when this is not None. Defaults to None.
        panoptic_head (dict): Config passed to ``build_head`` after
            ``train_cfg`` and ``test_cfg`` have been merged into a copy of
            it. Required in practice even though the default is None: the
            copy is updated unconditionally.
        panoptic_fusion_head (dict): Config passed to ``build_head`` after
            ``test_cfg`` has been merged into a copy of it. Required in
            practice for the same reason as ``panoptic_head``.
        train_cfg (dict, optional): Training config, stored on the detector
            and forwarded to the panoptic head. Defaults to None.
        test_cfg (dict, optional): Testing config, stored on the detector
            and forwarded to both heads. Defaults to None.
        init_cfg (dict, optional): Initialization config forwarded to the
            parent constructor. Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 panoptic_head=None,
                 panoptic_fusion_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        # The two-argument super() starts the MRO lookup *after*
        # SingleStageDetector, so SingleStageDetector.__init__ is skipped
        # on purpose; the sub-modules are built explicitly below.
        super(SingleStageDetector, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)

        # Work on deep copies so the caller's config objects are not
        # mutated when the train/test configs are merged in.
        panoptic_head_ = copy.deepcopy(panoptic_head)
        panoptic_head_.update(train_cfg=train_cfg)
        panoptic_head_.update(test_cfg=test_cfg)
        self.panoptic_head = build_head(panoptic_head_)

        panoptic_fusion_head_ = copy.deepcopy(panoptic_fusion_head)
        panoptic_fusion_head_.update(test_cfg=test_cfg)
        self.panoptic_fusion_head = build_head(panoptic_fusion_head_)

        # Class counts are taken from the built head.
        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # BaseDetector.show_result default for instance segmentation;
        # switch to the panoptic drawer when stuff classes are present.
        if self.num_stuff_classes > 0:
            self.show_result = self._show_pan_result

    def forward_dummy(self, img, img_metas):
        """Used for computing network flops. See
        `mmdetection/tools/analysis_tools/get_flops.py`

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

        Returns:
            The raw output of ``self.panoptic_head(x, img_metas)``.
        """
        # Same parent call as in `forward_train` (there it is used to add
        # batch_input_shape to img_metas); SingleStageDetector is skipped.
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        outs = self.panoptic_head(x, img_metas)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_masks,
                      gt_semantic_seg=None,
                      gt_bboxes_ignore=None,
                      **kargs):
        """Compute the training losses of the panoptic head.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[Dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_masks (list[BitmapMasks]): true segmentation masks for each box
                used if the architecture supports a segmentation task.
            gt_semantic_seg (list[tensor]): semantic segmentation mask for
                images for panoptic segmentation.
                Defaults to None for instance segmentation.
            gt_bboxes_ignore (list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
                Defaults to None.
            **kargs: Extra keyword arguments; accepted but not used here.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # add batch_input_shape in img_metas
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.panoptic_head.forward_train(x, img_metas, gt_bboxes,
                                                  gt_labels, gt_masks,
                                                  gt_semantic_seg,
                                                  gt_bboxes_ignore)
        return losses

    def simple_test(self, imgs, img_metas, **kwargs):
        """Test without augmentation.

        Args:
            imgs (Tensor): A batch of images.
            img_metas (list[dict]): List of image information.
            **kwargs: Forwarded to ``simple_test`` of both heads.

        Returns:
            list[dict[str, np.array | tuple[list]] | tuple[list]]:
                Semantic segmentation results and panoptic segmentation
                results of each image for panoptic segmentation, or formatted
                bbox and mask results of each image for instance segmentation.

            .. code-block:: none

                [
                    # panoptic segmentation
                    {
                        'pan_results': np.array, # shape = [h, w]
                        'ins_results': tuple[list],
                        # semantic segmentation results are not supported yet
                        'sem_results': np.array
                    },
                    ...
                ]

            or

            .. code-block:: none

                [
                    # instance segmentation
                    (
                        bboxes, # list[np.array]
                        masks # list[list[np.array]]
                    ),
                    ...
                ]

        Raises:
            AssertionError: If any per-image result contains 'sem_results'.
        """
        feats = self.extract_feat(imgs)
        mask_cls_results, mask_pred_results = self.panoptic_head.simple_test(
            feats, img_metas, **kwargs)
        results = self.panoptic_fusion_head.simple_test(
            mask_cls_results, mask_pred_results, img_metas, **kwargs)

        # Each per-image dict is converted in place.
        for result in results:
            if 'pan_results' in result:
                result['pan_results'] = result['pan_results'].detach().cpu(
                ).numpy()

            if 'ins_results' in result:
                labels_per_image, bboxes, mask_pred_binary = result[
                    'ins_results']
                bbox_results = bbox2result(bboxes, labels_per_image,
                                           self.num_things_classes)
                # Group the binary masks by predicted class: one list per
                # thing class, indexed by label.
                mask_results = [[] for _ in range(self.num_things_classes)]
                for j, label in enumerate(labels_per_image):
                    mask = mask_pred_binary[j].detach().cpu().numpy()
                    mask_results[label].append(mask)
                result['ins_results'] = bbox_results, mask_results

            assert 'sem_results' not in result, \
                'semantic segmentation results are not supported yet.'

        # Without stuff classes, return the bare (bboxes, masks) tuples
        # instead of the per-image dicts.
        if self.num_stuff_classes == 0:
            results = [res['ins_results'] for res in results]

        return results

    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentation; not implemented for this detector."""
        raise NotImplementedError

    def onnx_export(self, img, img_metas):
        """Export to ONNX; not implemented for this detector."""
        raise NotImplementedError

    def _show_pan_result(self,
                         img,
                         result,
                         score_thr=0.3,
                         bbox_color=(72, 101, 241),
                         text_color=(72, 101, 241),
                         mask_color=None,
                         thickness=2,
                         font_size=13,
                         win_name='',
                         show=False,
                         wait_time=0,
                         out_file=None):
        """Draw `panoptic result` over `img`.

        Args:
            img (str or np.ndarray): The image to be displayed; it is loaded
                through ``mmcv.imread`` and copied before drawing.
            result (dict): The results. Only ``result['pan_results']`` is
                read here.
            score_thr (float, optional): Minimum score of bboxes to be shown.
                Not used by this method. Default: 0.3.
            bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
                The tuple of color should be in BGR order.
                Default: (72, 101, 241).
            text_color (str or tuple(int) or :obj:`Color`):Color of texts.
                The tuple of color should be in BGR order.
                Default: (72, 101, 241).
            mask_color (None or str or tuple(int) or :obj:`Color`):
                Color of masks. The tuple of color should be in BGR order.
                Default: None.
            thickness (int): Thickness of lines. Default: 2.
            font_size (int): Font size of texts. Default: 13.
            win_name (str): The window name. Default: ''.
            show (bool): Whether to show the image. Forced to False when
                `out_file` is given. Default: False.
            wait_time (float): Value of waitKey param.
                Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            The image returned by ``imshow_det_bboxes`` when neither `show`
            nor `out_file` is set; otherwise None.
        """
        img = mmcv.imread(img)
        img = img.copy()
        pan_results = result['pan_results']

        # keep objects ahead
        ids = np.unique(pan_results)[::-1]
        legal_indices = ids != self.num_classes  # for VOID label
        ids = ids[legal_indices]

        # The class label of each segment is its id modulo INSTANCE_OFFSET.
        labels = np.asarray(ids % INSTANCE_OFFSET, dtype=np.int64)
        # One boolean mask per remaining segment id.
        segms = (pan_results[None] == ids[:, None, None])

        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False

        # draw bounding boxes
        img = imshow_det_bboxes(
            img,
            segms=segms,
            labels=labels,
            class_names=self.CLASSES,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            return img
        return None