|
- # Copyright (c) OpenMMLab. All rights reserved.
- # Copyright (c) 2019 Western Digital Corporation or its affiliates.
- import torch
-
- from ..builder import DETECTORS
- from .single_stage import SingleStageDetector
-
-
@DETECTORS.register_module()
class YOLOV3(SingleStageDetector):
    """YOLOv3 single-stage detector.

    Construction, training and standard inference are inherited unchanged
    from :class:`SingleStageDetector`; this class only overrides
    :meth:`onnx_export`.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        # Pure pass-through: every argument is forwarded positionally, in the
        # same order, to SingleStageDetector.__init__.
        super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg,
                                     test_cfg, pretrained, init_cfg)

    def onnx_export(self, img, img_metas):
        """Test function for exporting to ONNX, without test time augmentation.

        Args:
            img (torch.Tensor): Input images. Assumed to be a batched
                ``(N, C, H, W)`` tensor (TODO confirm); every dimension from
                index 2 onward is recorded as the image shape.
            img_metas (list[dict]): List of image information. This method
                writes only to ``img_metas[0]``. It sets the
                ``'img_shape_for_onnx'`` key **in place**, so the caller's
                dict is mutated. The whole list is then passed to
                ``self.bbox_head.onnx_export``.

        Returns:
            tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
            and class labels of shape [N, num_det].
        """
        x = self.extract_feat(img)
        # Call the module instance rather than ``.forward`` so that any
        # registered forward hooks on the head are run, as PyTorch advises.
        outs = self.bbox_head(x)
        # Keep the spatial shape as a tensor rather than Python ints.
        # Presumably this lets the traced ONNX graph treat the input size
        # as dynamic instead of baking in constants.
        img_shape = torch._shape_as_tensor(img)[2:]
        img_metas[0]['img_shape_for_onnx'] = img_shape

        # The head's outputs are unpacked positionally, followed by the
        # (now annotated) meta list.
        det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)

        return det_bboxes, det_labels
|