# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from .two_stage import TwoStageDetector


@DETECTORS.register_module()
class SparseRCNN(TwoStageDetector):
    r"""Implementation of `Sparse R-CNN: End-to-End Object Detection with
    Learnable Proposals <https://arxiv.org/abs/2011.12450>`_"""

    def __init__(self, *args, **kwargs):
        super(SparseRCNN, self).__init__(*args, **kwargs)
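        # Proposals are generated internally by the learnable proposal
        # (embedding) head, so externally supplied proposals are rejected.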
        assert self.with_rpn, 'Sparse R-CNN and QueryInst ' \
            'do not support external proposals'

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """Forward function of Sparse R-CNN and QueryInst in train stage.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (list[Tensor], optional): Segmentation masks for
                each box. This is required to train QueryInst.
            proposals (list[Tensor], optional): Override rpn proposals with
                custom proposals. Use when `with_rpn` is False.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        assert proposals is None, 'Sparse R-CNN and QueryInst ' \
            'do not support external proposals'
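        # Extract multi-scale feature maps from the backbone and neck.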
        x = self.extract_feat(img)
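        # The embedding-based proposal head returns the learnable proposal
        # boxes, their proposal features, and a per-image [w, h, w, h] tensor
        # (imgs_whwh) used when normalizing box coordinates.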
        proposal_boxes, proposal_features, imgs_whwh = \
            self.rpn_head.forward_train(x, img_metas)
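        # The ROI head iteratively refines the proposal boxes and features
        # over its stages and computes the training losses.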
        roi_losses = self.roi_head.forward_train(
            x,
            proposal_boxes,
            proposal_features,
            img_metas,
            gt_bboxes,
            gt_labels,
            gt_bboxes_ignore=gt_bboxes_ignore,
            gt_masks=gt_masks,
            imgs_whwh=imgs_whwh)
        return roi_losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
            img_metas (list[dict]): List of image information.
            rescale (bool): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        x = self.extract_feat(img)
        proposal_boxes, proposal_features, imgs_whwh = \
            self.rpn_head.simple_test_rpn(x, img_metas)
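        # The ROI head refines the proposals through its stages and decodes
        # the final per-class detection results.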
        results = self.roi_head.simple_test(
            x,
            proposal_boxes,
            proposal_features,
            img_metas,
            imgs_whwh=imgs_whwh,
            rescale=rescale)
        return results

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/analysis_tools/get_flops.py`
        """
        # backbone
        x = self.extract_feat(img)
        # rpn
        num_imgs = len(img)
        dummy_img_metas = [
            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)
        ]
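        # Only 'img_shape' is needed here: the proposal head uses it to scale
        # its learnable initial proposals to absolute image coordinates.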
        proposal_boxes, proposal_features, imgs_whwh = \
            self.rpn_head.simple_test_rpn(x, dummy_img_metas)
        # roi_head
        roi_outs = self.roi_head.forward_dummy(x, proposal_boxes,
                                               proposal_features,
                                               dummy_img_metas)
        return roi_outs