|
- # Copyright (c) OpenMMLab. All rights reserved.
- import warnings
-
- import torch
-
- from ..builder import DETECTORS
- from .single_stage import SingleStageDetector
-
-
@DETECTORS.register_module()
class DETR(SingleStageDetector):
    r"""DETR detector.

    Implements `DETR: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ on top of
    :class:`SingleStageDetector`. Only the two entry points whose head
    call needs ``img_metas`` are overridden here.
    """

    def __init__(self,
                 backbone,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Forward all arguments to the parent constructor.

        The second positional slot of the parent constructor is always
        filled with ``None``.
        """
        # NOTE(review): the `None` slot is presumably the neck, which this
        # detector does not configure -- confirm against
        # SingleStageDetector.__init__.
        super().__init__(backbone, None, bbox_head, train_cfg, test_cfg,
                         pretrained, init_cfg)

    # `forward_dummy` is overridden because the forward pass of
    # `bbox_head` needs `img_metas` in addition to the features.
    def forward_dummy(self, img):
        """Run a forward pass with synthetic image metas.

        Used for computing network flops, see
        `mmdetection/tools/analysis_tools/get_flops.py`.

        Args:
            img (torch.Tensor): input images; ``img.shape`` must unpack
                into exactly four values.

        Returns:
            The raw output of ``self.bbox_head``.
        """
        warnings.warn('Warning! MultiheadAttention in DETR does not '
                      'support flops computation! Do not use the '
                      'results in your papers!')

        num_imgs, _, img_h, img_w = img.shape

        # One independent meta dict per image. The channel entry of
        # `img_shape` is fixed at 3; the input's own channel dimension is
        # discarded above.
        dummy_img_metas = []
        for _ in range(num_imgs):
            dummy_img_metas.append({
                'batch_input_shape': (img_h, img_w),
                'img_shape': (img_h, img_w, 3),
            })

        feats = self.extract_feat(img)
        return self.bbox_head(feats, dummy_img_metas)

    # `onnx_export` is overridden because:
    # (1) the forward pass of `bbox_head` needs `img_metas`;
    # (2) the head behaves differently under torch and ONNX during its
    #     forward pass (e.g. how `masks` are constructed).
    def onnx_export(self, img, img_metas):
        """Test function for ONNX export, without test time augmentation.

        Args:
            img (torch.Tensor): input images.
            img_metas (list[dict]): List of image information. The first
                entry is modified in place: key ``'img_shape_for_onnx'``
                is set on it.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
                and class labels of shape [N, num_det].
        """
        feats = self.extract_feat(img)
        # The head's ONNX forward needs the image metas as well.
        head_outs = self.bbox_head.forward_onnx(feats, img_metas)

        # Store the trailing dimensions of `img` as a tensor rather than
        # as Python ints.
        img_metas[0]['img_shape_for_onnx'] = torch._shape_as_tensor(img)[2:]

        bboxes, labels = self.bbox_head.onnx_export(*head_outs, img_metas)
        return bboxes, labels
|