You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_common.py 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. import torch
  4. from modelscope.preprocessors import (PREPROCESSORS, Compose, Filter,
  5. Preprocessor, ToTensor)
  6. class ComposeTest(unittest.TestCase):
  7. def test_compose(self):
  8. @PREPROCESSORS.register_module()
  9. class Tmp1(Preprocessor):
  10. def __call__(self, input):
  11. input['tmp1'] = 'tmp1'
  12. return input
  13. @PREPROCESSORS.register_module()
  14. class Tmp2(Preprocessor):
  15. def __call__(self, input):
  16. input['tmp2'] = 'tmp2'
  17. return input
  18. pipeline = [
  19. dict(type='Tmp1'),
  20. dict(type='Tmp2'),
  21. ]
  22. trans = Compose(pipeline)
  23. input = {}
  24. output = trans(input)
  25. self.assertEqual(output['tmp1'], 'tmp1')
  26. self.assertEqual(output['tmp2'], 'tmp2')
  27. class ToTensorTest(unittest.TestCase):
  28. def test_totensor(self):
  29. to_tensor_op = ToTensor(keys=['img'])
  30. inputs = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'}
  31. inputs = to_tensor_op(inputs)
  32. self.assertIsInstance(inputs['img'], torch.Tensor)
  33. self.assertEqual(inputs['label'], 1)
  34. self.assertEqual(inputs['path'], 'test.jpg')
  35. class FilterTest(unittest.TestCase):
  36. def test_filter(self):
  37. filter_op = Filter(reserved_keys=['img', 'label'])
  38. inputs = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'}
  39. inputs = filter_op(inputs)
  40. self.assertIn('img', inputs)
  41. self.assertIn('label', inputs)
  42. self.assertNotIn('path', inputs)
  43. if __name__ == '__main__':
  44. unittest.main()