You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_registry.py 3.0 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.utils.constant import Tasks
  4. from modelscope.utils.registry import Registry, build_from_cfg, default_group
  5. class RegistryTest(unittest.TestCase):
  6. def test_register_class_no_task(self):
  7. MODELS = Registry('models')
  8. self.assertTrue(MODELS.name == 'models')
  9. self.assertTrue(default_group in MODELS.modules)
  10. self.assertTrue(MODELS.modules[default_group] == {})
  11. self.assertEqual(len(MODELS.modules), 1)
  12. @MODELS.register_module(module_name='cls-resnet')
  13. class ResNetForCls(object):
  14. pass
  15. self.assertTrue(default_group in MODELS.modules)
  16. self.assertTrue(MODELS.get('cls-resnet') is ResNetForCls)
  17. def test_register_class_with_task(self):
  18. MODELS = Registry('models')
  19. @MODELS.register_module(Tasks.image_classification, 'SwinT')
  20. class SwinTForCls(object):
  21. pass
  22. self.assertTrue(Tasks.image_classification in MODELS.modules)
  23. self.assertTrue(
  24. MODELS.get('SwinT', Tasks.image_classification) is SwinTForCls)
  25. @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
  26. class BertForSentimentAnalysis(object):
  27. pass
  28. self.assertTrue(Tasks.sentiment_analysis in MODELS.modules)
  29. self.assertTrue(
  30. MODELS.get('Bert', Tasks.sentiment_analysis) is
  31. BertForSentimentAnalysis)
  32. @MODELS.register_module(Tasks.image_object_detection)
  33. class DETR(object):
  34. pass
  35. self.assertTrue(Tasks.image_object_detection in MODELS.modules)
  36. self.assertTrue(
  37. MODELS.get('DETR', Tasks.image_object_detection) is DETR)
  38. self.assertEqual(len(MODELS.modules), 4)
  39. def test_list(self):
  40. MODELS = Registry('models')
  41. @MODELS.register_module(Tasks.image_classification, 'SwinT')
  42. class SwinTForCls(object):
  43. pass
  44. @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
  45. class BertForSentimentAnalysis(object):
  46. pass
  47. MODELS.list()
  48. print(MODELS)
  49. def test_build(self):
  50. MODELS = Registry('models')
  51. @MODELS.register_module(Tasks.image_classification, 'SwinT')
  52. class SwinTForCls(object):
  53. pass
  54. @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
  55. class BertForSentimentAnalysis(object):
  56. pass
  57. cfg = dict(type='SwinT')
  58. model = build_from_cfg(cfg, MODELS, Tasks.image_classification)
  59. self.assertTrue(isinstance(model, SwinTForCls))
  60. cfg = dict(type='Bert')
  61. model = build_from_cfg(cfg, MODELS, Tasks.sentiment_analysis)
  62. self.assertTrue(isinstance(model, BertForSentimentAnalysis))
  63. with self.assertRaises(KeyError):
  64. cfg = dict(type='Bert')
  65. model = build_from_cfg(cfg, MODELS, Tasks.image_classification)
  66. if __name__ == '__main__':
  67. unittest.main()