|
- from unittest.mock import patch
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from mmdet.models.utils import AdaptiveAvgPool2d, adaptive_avg_pool2d
-
# Version string that ``torch.__version__`` is patched to (via ``@patch``)
# while the tests in this module run. Parrots builds keep their own marker;
# every other build is pinned to '1.7'.
# NOTE(review): presumably '1.7' is chosen so the wrappers take their
# legacy empty-tensor code path — confirm against the wrapper implementation.
torch_version = 'parrots' if torch.__version__ == 'parrots' else '1.7'
-
-
@patch('torch.__version__', torch_version)
def test_adaptive_avg_pool2d():
    """Check the functional ``adaptive_avg_pool2d`` wrapper.

    Two ``output_size`` forms are exercised: ``tuple[int, int]`` and
    ``int``. Each is run on an empty batch (zero samples), where only the
    output shape can be checked, and on a normal 4-D batch, where the
    wrapper's output must equal ``torch.nn.functional.adaptive_avg_pool2d``
    exactly.
    """
    output_sizes = [(2, 2), 2]  # 1. tuple[int, int]; 2. int

    # Empty batch dimension: there is no data, so only check the shape.
    x_empty = torch.randn(0, 3, 4, 5)
    for output_size in output_sizes:
        wrapper_out = adaptive_avg_pool2d(x_empty, output_size)
        assert wrapper_out.shape == (0, 3, 2, 2), output_size

    # Normal batch dimension (4-D input): must match the reference op.
    x_normal = torch.randn(3, 3, 4, 5)
    for output_size in output_sizes:
        wrapper_out = adaptive_avg_pool2d(x_normal, output_size)
        ref_out = F.adaptive_avg_pool2d(x_normal, output_size)
        assert wrapper_out.shape == (3, 3, 2, 2), output_size
        assert torch.equal(wrapper_out, ref_out), output_size
-
-
@patch('torch.__version__', torch_version)
def test_AdaptiveAvgPool2d():
    """Check the ``AdaptiveAvgPool2d`` module wrapper.

    Four ``output_size`` forms are exercised: ``tuple[int, int]``,
    ``int``, ``tuple[None, int]`` and ``tuple[int, None]``. A ``None``
    entry keeps that spatial size of the input unchanged. Each form is run
    on an empty batch (zero samples), where only the output shape can be
    checked, and on a normal batch, where the wrapper's output must equal
    ``torch.nn.AdaptiveAvgPool2d`` exactly.
    """
    # (output_size, expected trailing two dims of the output) for a
    # 4x5 spatial input.
    cases = [
        ((2, 2), (2, 2)),  # 1. tuple[int, int]
        (2, (2, 2)),  # 2. int
        ((None, 2), (4, 2)),  # 3. tuple[None, int]
        ((2, None), (2, 5)),  # 4. tuple[int, None]
    ]

    # Empty batch dimension: there is no data, so only check the shape.
    x_empty = torch.randn(0, 3, 4, 5)
    for output_size, spatial in cases:
        wrapper_out = AdaptiveAvgPool2d(output_size)(x_empty)
        assert wrapper_out.shape == (0, 3, *spatial), output_size

    # Normal batch dimension: must match the reference module.
    x_normal = torch.randn(3, 3, 4, 5)
    for output_size, spatial in cases:
        wrapper_out = AdaptiveAvgPool2d(output_size)(x_normal)
        ref_out = nn.AdaptiveAvgPool2d(output_size)(x_normal)
        assert wrapper_out.shape == (3, 3, *spatial), output_size
        assert torch.equal(wrapper_out, ref_out), output_size
|