|
- # Copyright (c) OpenMMLab. All rights reserved.
- import numpy as np
- import pytest
- import torch
- import torch.nn as nn
- from mmcv.runner import auto_fp16, force_fp32
- from mmcv.runner.fp16_utils import cast_tensor_type
-
-
- def test_cast_tensor_type():
- inputs = torch.FloatTensor([5.])
- src_type = torch.float32
- dst_type = torch.int32
- outputs = cast_tensor_type(inputs, src_type, dst_type)
- assert isinstance(outputs, torch.Tensor)
- assert outputs.dtype == dst_type
-
- inputs = 'tensor'
- src_type = str
- dst_type = str
- outputs = cast_tensor_type(inputs, src_type, dst_type)
- assert isinstance(outputs, str)
-
- inputs = np.array([5.])
- src_type = np.ndarray
- dst_type = np.ndarray
- outputs = cast_tensor_type(inputs, src_type, dst_type)
- assert isinstance(outputs, np.ndarray)
-
- inputs = dict(
- tensor_a=torch.FloatTensor([1.]), tensor_b=torch.FloatTensor([2.]))
- src_type = torch.float32
- dst_type = torch.int32
- outputs = cast_tensor_type(inputs, src_type, dst_type)
- assert isinstance(outputs, dict)
- assert outputs['tensor_a'].dtype == dst_type
- assert outputs['tensor_b'].dtype == dst_type
-
- inputs = [torch.FloatTensor([1.]), torch.FloatTensor([2.])]
- src_type = torch.float32
- dst_type = torch.int32
- outputs = cast_tensor_type(inputs, src_type, dst_type)
- assert isinstance(outputs, list)
- assert outputs[0].dtype == dst_type
- assert outputs[1].dtype == dst_type
-
- inputs = 5
- outputs = cast_tensor_type(inputs, None, None)
- assert isinstance(outputs, int)
-
-
- def test_auto_fp16():
-
- with pytest.raises(TypeError):
- # ExampleObject is not a subclass of nn.Module
-
- class ExampleObject:
-
- @auto_fp16()
- def __call__(self, x):
- return x
-
- model = ExampleObject()
- input_x = torch.ones(1, dtype=torch.float32)
- model(input_x)
-
- # apply to all input args
- class ExampleModule(nn.Module):
-
- @auto_fp16()
- def forward(self, x, y):
- return x, y
-
- model = ExampleModule()
- input_x = torch.ones(1, dtype=torch.float32)
- input_y = torch.ones(1, dtype=torch.float32)
- output_x, output_y = model(input_x, input_y)
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
-
- model.fp16_enabled = True
- output_x, output_y = model(input_x, input_y)
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
-
- if torch.cuda.is_available():
- model.cuda()
- output_x, output_y = model(input_x.cuda(), input_y.cuda())
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
-
- # apply to specified input args
- class ExampleModule(nn.Module):
-
- @auto_fp16(apply_to=('x', ))
- def forward(self, x, y):
- return x, y
-
- model = ExampleModule()
- input_x = torch.ones(1, dtype=torch.float32)
- input_y = torch.ones(1, dtype=torch.float32)
- output_x, output_y = model(input_x, input_y)
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
-
- model.fp16_enabled = True
- output_x, output_y = model(input_x, input_y)
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.float32
-
- if torch.cuda.is_available():
- model.cuda()
- output_x, output_y = model(input_x.cuda(), input_y.cuda())
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.float32
-
- # apply to optional input args
- class ExampleModule(nn.Module):
-
- @auto_fp16(apply_to=('x', 'y'))
- def forward(self, x, y=None, z=None):
- return x, y, z
-
- model = ExampleModule()
- input_x = torch.ones(1, dtype=torch.float32)
- input_y = torch.ones(1, dtype=torch.float32)
- input_z = torch.ones(1, dtype=torch.float32)
- output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
- assert output_z.dtype == torch.float32
-
- model.fp16_enabled = True
- output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
- assert output_z.dtype == torch.float32
-
- if torch.cuda.is_available():
- model.cuda()
- output_x, output_y, output_z = model(
- input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
- assert output_z.dtype == torch.float32
-
- # out_fp32=True
- class ExampleModule(nn.Module):
-
- @auto_fp16(apply_to=('x', 'y'), out_fp32=True)
- def forward(self, x, y=None, z=None):
- return x, y, z
-
- model = ExampleModule()
- input_x = torch.ones(1, dtype=torch.half)
- input_y = torch.ones(1, dtype=torch.float32)
- input_z = torch.ones(1, dtype=torch.float32)
- output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.float32
- assert output_z.dtype == torch.float32
-
- model.fp16_enabled = True
- output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
- assert output_z.dtype == torch.float32
-
- if torch.cuda.is_available():
- model.cuda()
- output_x, output_y, output_z = model(
- input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
- assert output_z.dtype == torch.float32
-
-
- def test_force_fp32():
-
- with pytest.raises(TypeError):
- # ExampleObject is not a subclass of nn.Module
-
- class ExampleObject:
-
- @force_fp32()
- def __call__(self, x):
- return x
-
- model = ExampleObject()
- input_x = torch.ones(1, dtype=torch.float32)
- model(input_x)
-
- # apply to all input args
- class ExampleModule(nn.Module):
-
- @force_fp32()
- def forward(self, x, y):
- return x, y
-
- model = ExampleModule()
- input_x = torch.ones(1, dtype=torch.half)
- input_y = torch.ones(1, dtype=torch.half)
- output_x, output_y = model(input_x, input_y)
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
-
- model.fp16_enabled = True
- output_x, output_y = model(input_x, input_y)
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
-
- if torch.cuda.is_available():
- model.cuda()
- output_x, output_y = model(input_x.cuda(), input_y.cuda())
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
-
- # apply to specified input args
- class ExampleModule(nn.Module):
-
- @force_fp32(apply_to=('x', ))
- def forward(self, x, y):
- return x, y
-
- model = ExampleModule()
- input_x = torch.ones(1, dtype=torch.half)
- input_y = torch.ones(1, dtype=torch.half)
- output_x, output_y = model(input_x, input_y)
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
-
- model.fp16_enabled = True
- output_x, output_y = model(input_x, input_y)
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.half
-
- if torch.cuda.is_available():
- model.cuda()
- output_x, output_y = model(input_x.cuda(), input_y.cuda())
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.half
-
- # apply to optional input args
- class ExampleModule(nn.Module):
-
- @force_fp32(apply_to=('x', 'y'))
- def forward(self, x, y=None, z=None):
- return x, y, z
-
- model = ExampleModule()
- input_x = torch.ones(1, dtype=torch.half)
- input_y = torch.ones(1, dtype=torch.half)
- input_z = torch.ones(1, dtype=torch.half)
- output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
- assert output_z.dtype == torch.half
-
- model.fp16_enabled = True
- output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
- assert output_z.dtype == torch.half
-
- if torch.cuda.is_available():
- model.cuda()
- output_x, output_y, output_z = model(
- input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.float32
- assert output_z.dtype == torch.half
-
- # out_fp16=True
- class ExampleModule(nn.Module):
-
- @force_fp32(apply_to=('x', 'y'), out_fp16=True)
- def forward(self, x, y=None, z=None):
- return x, y, z
-
- model = ExampleModule()
- input_x = torch.ones(1, dtype=torch.float32)
- input_y = torch.ones(1, dtype=torch.half)
- input_z = torch.ones(1, dtype=torch.half)
- output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
- assert output_x.dtype == torch.float32
- assert output_y.dtype == torch.half
- assert output_z.dtype == torch.half
-
- model.fp16_enabled = True
- output_x, output_y, output_z = model(input_x, y=input_y, z=input_z)
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
- assert output_z.dtype == torch.half
-
- if torch.cuda.is_available():
- model.cuda()
- output_x, output_y, output_z = model(
- input_x.cuda(), y=input_y.cuda(), z=input_z.cuda())
- assert output_x.dtype == torch.half
- assert output_y.dtype == torch.half
- assert output_z.dtype == torch.half
|