|
- # Copyright (c) Alibaba, Inc. and its affiliates.
-
- import unittest
-
- import torch
-
- from modelscope.preprocessors import (PREPROCESSORS, Compose, Filter,
- Preprocessor, ToTensor)
-
-
class ComposeTest(unittest.TestCase):
    """Tests for ``Compose`` built from preprocessor config dicts."""

    def test_compose(self):
        """Each step of a ``Compose`` pipeline is applied to the input dict.

        Two throwaway preprocessors are registered with ``PREPROCESSORS``
        and referenced from the pipeline config by ``type`` (their class
        names). Each one adds a single marker key to the dict it receives,
        so both markers must be present in the final output.
        """

        # NOTE(review): these registrations mutate the global PREPROCESSORS
        # registry and are never undone; registering the same name again
        # (e.g. a re-run in the same process) may be rejected — confirm
        # against the registry implementation.
        @PREPROCESSORS.register_module()
        class Tmp1(Preprocessor):

            def __call__(self, data):
                # Mark that this step ran.
                data['tmp1'] = 'tmp1'
                return data

        @PREPROCESSORS.register_module()
        class Tmp2(Preprocessor):

            def __call__(self, data):
                # Mark that this step ran.
                data['tmp2'] = 'tmp2'
                return data

        pipeline = [
            dict(type='Tmp1'),
            dict(type='Tmp2'),
        ]
        trans = Compose(pipeline)

        # Named `sample` rather than `input` to avoid shadowing the builtin.
        sample = {}
        output = trans(sample)
        self.assertEqual(output['tmp1'], 'tmp1')
        self.assertEqual(output['tmp2'], 'tmp2')
-
-
class ToTensorTest(unittest.TestCase):
    """Tests for the ``ToTensor`` preprocessor."""

    def test_totensor(self):
        """The key named in ``keys`` becomes a ``torch.Tensor``.

        The remaining entries must still compare equal to their original
        values after the op runs.
        """
        sample = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'}
        op = ToTensor(keys=['img'])
        result = op(sample)

        # Requested key is converted ...
        self.assertIsInstance(result['img'], torch.Tensor)
        # ... other entries are still equal to what was passed in.
        self.assertEqual(result['label'], 1)
        self.assertEqual(result['path'], 'test.jpg')
-
-
class FilterTest(unittest.TestCase):
    """Tests for the ``Filter`` preprocessor."""

    def test_filter(self):
        """Keys listed in ``reserved_keys`` survive; all others are dropped."""
        sample = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'}
        kept = ['img', 'label']
        result = Filter(reserved_keys=kept)(sample)

        # Every reserved key is still present ...
        for key in kept:
            self.assertIn(key, result)
        # ... and the unreserved one is gone.
        self.assertNotIn('path', result)
-
-
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script
    # (module='__main__' is unittest.main's default, spelled out here).
    unittest.main(module='__main__')
|