You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_registry.py 3.0 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from maas_lib.utils.constant import Tasks
  4. from maas_lib.utils.registry import Registry, build_from_cfg, default_group
  5. class RegistryTest(unittest.TestCase):
  6. def test_register_class_no_task(self):
  7. MODELS = Registry('models')
  8. self.assertTrue(MODELS.name == 'models')
  9. self.assertTrue(default_group in MODELS.modules)
  10. self.assertTrue(MODELS.modules[default_group] == {})
  11. self.assertEqual(len(MODELS.modules), 1)
  12. @MODELS.register_module(module_name='cls-resnet')
  13. class ResNetForCls(object):
  14. pass
  15. self.assertTrue(default_group in MODELS.modules)
  16. self.assertTrue(MODELS.get('cls-resnet') is ResNetForCls)
  17. def test_register_class_with_task(self):
  18. MODELS = Registry('models')
  19. @MODELS.register_module(Tasks.image_classification, 'SwinT')
  20. class SwinTForCls(object):
  21. pass
  22. self.assertTrue(Tasks.image_classification in MODELS.modules)
  23. self.assertTrue(
  24. MODELS.get('SwinT', Tasks.image_classification) is SwinTForCls)
  25. @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
  26. class BertForSentimentAnalysis(object):
  27. pass
  28. self.assertTrue(Tasks.sentiment_analysis in MODELS.modules)
  29. self.assertTrue(
  30. MODELS.get('Bert', Tasks.sentiment_analysis) is
  31. BertForSentimentAnalysis)
  32. @MODELS.register_module(Tasks.object_detection)
  33. class DETR(object):
  34. pass
  35. self.assertTrue(Tasks.object_detection in MODELS.modules)
  36. self.assertTrue(MODELS.get('DETR', Tasks.object_detection) is DETR)
  37. self.assertEqual(len(MODELS.modules), 4)
  38. def test_list(self):
  39. MODELS = Registry('models')
  40. @MODELS.register_module(Tasks.image_classification, 'SwinT')
  41. class SwinTForCls(object):
  42. pass
  43. @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
  44. class BertForSentimentAnalysis(object):
  45. pass
  46. MODELS.list()
  47. print(MODELS)
  48. def test_build(self):
  49. MODELS = Registry('models')
  50. @MODELS.register_module(Tasks.image_classification, 'SwinT')
  51. class SwinTForCls(object):
  52. pass
  53. @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
  54. class BertForSentimentAnalysis(object):
  55. pass
  56. cfg = dict(type='SwinT')
  57. model = build_from_cfg(cfg, MODELS, Tasks.image_classification)
  58. self.assertTrue(isinstance(model, SwinTForCls))
  59. cfg = dict(type='Bert')
  60. model = build_from_cfg(cfg, MODELS, Tasks.sentiment_analysis)
  61. self.assertTrue(isinstance(model, BertForSentimentAnalysis))
  62. with self.assertRaises(KeyError):
  63. cfg = dict(type='Bert')
  64. model = build_from_cfg(cfg, MODELS, Tasks.image_classification)
  65. if __name__ == '__main__':
  66. unittest.main()

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展