# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch

from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class DETR(SingleStageDetector):
    r"""Implementation of `DETR: End-to-End Object Detection with
    Transformers <https://arxiv.org/pdf/2005.12872>`_.

    The detector is a ``SingleStageDetector`` built without a neck; it
    overrides ``forward_dummy`` and ``onnx_export`` because its
    ``bbox_head`` needs ``img_metas`` during the forward pass.
    """

    def __init__(self,
                 backbone,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Build the detector.

        Args:
            backbone: Backbone config, passed unchanged to
                ``SingleStageDetector``.
            bbox_head: Head config, passed unchanged to
                ``SingleStageDetector``.
            train_cfg: Training config. Defaults to None.
            test_cfg: Testing config. Defaults to None.
            pretrained: Pretrained weights spec. Defaults to None.
            init_cfg: Initialization config. Defaults to None.
        """
        # DETR uses no neck: ``None`` fills the second positional slot of
        # ``SingleStageDetector.__init__`` (assumed to be ``neck`` -- the
        # base-class signature is defined elsewhere; confirm if it changes).
        super().__init__(backbone, None, bbox_head, train_cfg, test_cfg,
                         pretrained, init_cfg)

    # over-write `forward_dummy` because:
    # the forward of bbox_head requires img_metas
    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/analysis_tools/get_flops.py`.

        Args:
            img (torch.Tensor): Input batch; its shape is unpacked as
                ``(batch_size, channels, height, width)``.

        Returns:
            Whatever ``self.bbox_head`` returns for the extracted features.
        """
        warnings.warn('Warning! MultiheadAttention in DETR does not '
                      'support flops computation! Do not use the '
                      'results in your papers!')

        batch_size, _, height, width = img.shape
        # Synthesize one meta dict per image, because the head's forward
        # cannot run without img_metas. Every image is treated as
        # un-padded, filling the whole input; the channel count in
        # ``img_shape`` is fixed at 3 regardless of ``img``.
        dummy_img_metas = [
            dict(
                batch_input_shape=(height, width),
                img_shape=(height, width, 3)) for _ in range(batch_size)
        ]
        x = self.extract_feat(img)
        outs = self.bbox_head(x, dummy_img_metas)
        return outs

    # over-write `onnx_export` because:
    # (1) the forward of bbox_head requires img_metas
    # (2) the different behavior (e.g. construction of `masks`) between
    # torch and ONNX model, during the forward of bbox_head
    def onnx_export(self, img, img_metas):
        """Test function for exporting to ONNX, without test time augmentation.

        Note:
            This mutates the caller's ``img_metas[0]`` by adding the key
            ``'img_shape_for_onnx'``, which the head's ``onnx_export``
            is given along with ``img_metas``.

        Args:
            img (torch.Tensor): input images.
            img_metas (list[dict]): List of image information.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        x = self.extract_feat(img)
        # forward of this head requires img_metas
        outs = self.bbox_head.forward_onnx(x, img_metas)
        # Get the spatial shape as a tensor rather than Python ints so that
        # it stays dynamic in the traced ONNX graph.
        img_shape = torch._shape_as_tensor(img)[2:]
        img_metas[0]['img_shape_for_onnx'] = img_shape

        det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)

        return det_bboxes, det_labels