|
- # Copyright (c) OpenMMLab. All rights reserved.
- import os.path as osp
-
- import mmcv
- import pytest
- import torch
-
- from mmdet import digit_version
- from mmdet.models.necks import FPN, YOLOV3Neck
- from .utils import ort_validate
-
# Module-level guard: skip every test in this file on old PyTorch releases.
# NOTE: the comparison is inclusive, so torch 1.5.0 itself is skipped as well,
# even though the message only mentions versions "below 1.5.0".
if digit_version('1.5.0') >= digit_version(torch.__version__):
    pytest.skip(
        'ort backend does not support version below 1.5.0',
        allow_module_level=True)
-
# Maps each FPN test-step name to the integer index that fpn_neck_config()
# uses to pick the model variant. The index is the position in this tuple.
fpn_test_step_names = {
    step_name: step_index
    for step_index, step_name in enumerate((
        'fpn_normal',
        'fpn_wo_extra_convs',
        'fpn_lateral_bns',
        'fpn_bilinear_upsample',
        'fpn_scale_factor',
        'fpn_extra_convs_inputs',
        'fpn_extra_convs_laterals',
        'fpn_extra_convs_outputs',
    ))
}
-
# Maps each YOLOv3-neck test-step name to the integer index that
# yolo_neck_config() uses to pick the model variant.
yolo_test_step_names = dict(yolo_normal=0)

# Directory holding the stored test inputs: the 'data' folder located next to
# this test module.
data_path = osp.join(osp.dirname(__file__), 'data')
-
-
def fpn_neck_config(test_step_name):
    """Build the FPN variant and random input features for one test step.

    Args:
        test_step_name (str): A key of ``fpn_test_step_names`` selecting
            which FPN configuration to build.

    Returns:
        tuple: ``(fpn_model, feats)``. ``fpn_model`` is the configured
        :class:`FPN`; ``feats`` is a list of four ``torch.rand`` tensors of
        shape ``(1, C, S, S)`` with ``C`` taken from ``[8, 16, 32, 64]`` and
        ``S`` from ``[64, 32, 16, 8]`` respectively.

    Raises:
        KeyError: If ``test_step_name`` is not a key of
            ``fpn_test_step_names``.
        ValueError: If the step index has no FPN configuration defined in
            this function.
    """
    s = 64
    in_channels = [8, 16, 32, 64]
    feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
    out_channels = 8

    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]

    # Keyword arguments that differ between variants, keyed by the step index
    # stored in fpn_test_step_names. The arguments shared by every variant
    # (in_channels, out_channels, num_outs) are passed once below.
    variant_kwargs = {
        # fpn_normal
        0: dict(add_extra_convs=True),
        # fpn_wo_extra_convs
        1: dict(add_extra_convs=False),
        # fpn_lateral_bns
        2: dict(
            add_extra_convs=True,
            no_norm_on_lateral=False,
            norm_cfg=dict(type='BN', requires_grad=True)),
        # fpn_bilinear_upsample
        3: dict(
            add_extra_convs=True,
            upsample_cfg=dict(mode='bilinear', align_corners=True)),
        # fpn_scale_factor
        4: dict(add_extra_convs=True, upsample_cfg=dict(scale_factor=2)),
        # fpn_extra_convs_inputs
        5: dict(add_extra_convs='on_input'),
        # fpn_extra_convs_laterals
        6: dict(add_extra_convs='on_lateral'),
        # fpn_extra_convs_outputs
        7: dict(add_extra_convs='on_output'),
    }

    step_index = fpn_test_step_names[test_step_name]
    if step_index not in variant_kwargs:
        # Fail with a clear message instead of an unbound-local error when a
        # step is registered in fpn_test_step_names but not configured here.
        raise ValueError(
            f'No FPN configuration is defined for test step '
            f'{test_step_name!r} (index {step_index!r})')

    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        num_outs=5,
        **variant_kwargs[step_index])
    return fpn_model, feats
-
-
def yolo_neck_config(test_step_name):
    """Build the YOLOV3Neck and load its stored inputs for one test step.

    Args:
        test_step_name (str): A key of ``yolo_test_step_names``.

    Returns:
        tuple: ``(yolo_model, feats)`` where ``feats`` is whatever
        ``mmcv.load`` returns for ``yolov3_neck.pkl`` under ``data_path``.
    """
    # Per the original note on this fixture: yolov3_neck.pkl holds a list of
    # torch.Tensor generated by torch.rand, with sizes
    # (1, 4, 64, 64), (1, 8, 32, 32), (1, 16, 16, 16).
    pkl_path = osp.join(data_path, 'yolov3_neck.pkl')
    feats = mmcv.load(pkl_path)

    step_index = yolo_test_step_names[test_step_name]
    if step_index == 0:
        yolo_model = YOLOV3Neck(
            in_channels=[16, 8, 4], out_channels=[8, 4, 2], num_scales=3)
    return yolo_model, feats
-
-
def test_fpn_normal():
    """Run ``ort_validate`` on the FPN built for the 'fpn_normal' step."""
    ort_validate(*fpn_neck_config('fpn_normal'))
-
-
def test_fpn_wo_extra_convs():
    """Run ``ort_validate`` on the FPN built for 'fpn_wo_extra_convs'."""
    ort_validate(*fpn_neck_config('fpn_wo_extra_convs'))
-
-
def test_fpn_lateral_bns():
    """Run ``ort_validate`` on the FPN built for 'fpn_lateral_bns'."""
    ort_validate(*fpn_neck_config('fpn_lateral_bns'))
-
-
def test_fpn_bilinear_upsample():
    """Run ``ort_validate`` on the FPN built for 'fpn_bilinear_upsample'."""
    ort_validate(*fpn_neck_config('fpn_bilinear_upsample'))
-
-
def test_fpn_scale_factor():
    """Run ``ort_validate`` on the FPN built for 'fpn_scale_factor'."""
    ort_validate(*fpn_neck_config('fpn_scale_factor'))
-
-
def test_fpn_extra_convs_inputs():
    """Run ``ort_validate`` on the FPN built for 'fpn_extra_convs_inputs'."""
    ort_validate(*fpn_neck_config('fpn_extra_convs_inputs'))
-
-
def test_fpn_extra_convs_laterals():
    """Run ``ort_validate`` on the FPN for 'fpn_extra_convs_laterals'."""
    ort_validate(*fpn_neck_config('fpn_extra_convs_laterals'))
-
-
def test_fpn_extra_convs_outputs():
    """Run ``ort_validate`` on the FPN for 'fpn_extra_convs_outputs'."""
    ort_validate(*fpn_neck_config('fpn_extra_convs_outputs'))
-
-
def test_yolo_normal():
    """Run ``ort_validate`` on the YOLOV3Neck built for 'yolo_normal'."""
    ort_validate(*yolo_neck_config('yolo_normal'))
|