You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

registry.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import inspect
  3. from email.policy import default
  4. from maas_lib.utils.logger import get_logger
  5. default_group = 'default'
  6. logger = get_logger()
  7. class Registry(object):
  8. """ Registry which support registering modules and group them by a keyname
  9. If group name is not provided, modules will be registered to default group.
  10. """
  11. def __init__(self, name: str):
  12. self._name = name
  13. self._modules = {default_group: {}}
  14. def __repr__(self):
  15. format_str = self.__class__.__name__ + f' ({self._name})\n'
  16. for group_name, group in self._modules.items():
  17. format_str += f'group_name={group_name}, '\
  18. f'modules={list(group.keys())}\n'
  19. return format_str
  20. @property
  21. def name(self):
  22. return self._name
  23. @property
  24. def modules(self):
  25. return self._modules
  26. def list(self):
  27. """ logging the list of module in current registry
  28. """
  29. for group_name, group in self._modules.items():
  30. logger.info(f'group_name={group_name}')
  31. for m in group.keys():
  32. logger.info(f'\t{m}')
  33. logger.info('')
  34. def get(self, module_key, group_key=default_group):
  35. if group_key not in self._modules:
  36. return None
  37. else:
  38. return self._modules[group_key].get(module_key, None)
  39. def _register_module(self,
  40. group_key=default_group,
  41. module_name=None,
  42. module_cls=None):
  43. assert isinstance(group_key,
  44. str), 'group_key is required and must be str'
  45. if group_key not in self._modules:
  46. self._modules[group_key] = dict()
  47. if not inspect.isclass(module_cls):
  48. raise TypeError(f'module is not a class type: {type(module_cls)}')
  49. if module_name is None:
  50. module_name = module_cls.__name__
  51. if module_name in self._modules[group_key]:
  52. raise KeyError(f'{module_name} is already registered in '
  53. f'{self._name}[{group_key}]')
  54. self._modules[group_key][module_name] = module_cls
  55. if module_name in self._modules[default_group]:
  56. if id(self._modules[default_group][module_name]) == id(module_cls):
  57. return
  58. else:
  59. logger.warning(f'{module_name} is already registered in '
  60. f'{self._name}[{default_group}] and will '
  61. 'be overwritten')
  62. logger.warning(f'{self._modules[default_group][module_name]}'
  63. 'to {module_cls}')
  64. # also register module in the default group for faster access
  65. # only by module name
  66. self._modules[default_group][module_name] = module_cls
  67. def register_module(self,
  68. group_key: str = default_group,
  69. module_name: str = None,
  70. module_cls: type = None):
  71. """ Register module
  72. Example:
  73. >>> models = Registry('models')
  74. >>> @models.register_module('image-classification', 'SwinT')
  75. >>> class SwinTransformer:
  76. >>> pass
  77. >>> @models.register_module('SwinDefault')
  78. >>> class SwinTransformerDefaultGroup:
  79. >>> pass
  80. Args:
  81. group_key: Group name of which module will be registered,
  82. default group name is 'default'
  83. module_name: Module name
  84. module_cls: Module class object
  85. """
  86. if not (module_name is None or isinstance(module_name, str)):
  87. raise TypeError(f'module_name must be either of None, str,'
  88. f'got {type(module_name)}')
  89. if module_cls is not None:
  90. self._register_module(
  91. group_key=group_key,
  92. module_name=module_name,
  93. module_cls=module_cls)
  94. return module_cls
  95. # if module_cls is None, should return a decorator function
  96. def _register(module_cls):
  97. self._register_module(
  98. group_key=group_key,
  99. module_name=module_name,
  100. module_cls=module_cls)
  101. return module_cls
  102. return _register
  103. def build_from_cfg(cfg,
  104. registry: Registry,
  105. group_key: str = default_group,
  106. default_args: dict = None) -> object:
  107. """Build a module from config dict when it is a class configuration, or
  108. call a function from config dict when it is a function configuration.
  109. Example:
  110. >>> models = Registry('models')
  111. >>> @models.register_module('image-classification', 'SwinT')
  112. >>> class SwinTransformer:
  113. >>> pass
  114. >>> swint = build_from_cfg(dict(type='SwinT'), MODELS,
  115. >>> 'image-classification')
  116. >>> # Returns an instantiated object
  117. >>>
  118. >>> @MODELS.register_module()
  119. >>> def swin_transformer():
  120. >>> pass
  121. >>> = build_from_cfg(dict(type='swin_transformer'), MODELS)
  122. >>> # Return a result of the calling function
  123. Args:
  124. cfg (dict): Config dict. It should at least contain the key "type".
  125. registry (:obj:`Registry`): The registry to search the type from.
  126. group_key (str, optional): The name of registry group from which
  127. module should be searched.
  128. default_args (dict, optional): Default initialization arguments.
  129. Returns:
  130. object: The constructed object.
  131. """
  132. if not isinstance(cfg, dict):
  133. raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
  134. if 'type' not in cfg:
  135. if default_args is None or 'type' not in default_args:
  136. raise KeyError(
  137. '`cfg` or `default_args` must contain the key "type", '
  138. f'but got {cfg}\n{default_args}')
  139. if not isinstance(registry, Registry):
  140. raise TypeError('registry must be an maas_lib.Registry object, '
  141. f'but got {type(registry)}')
  142. if not (isinstance(default_args, dict) or default_args is None):
  143. raise TypeError('default_args must be a dict or None, '
  144. f'but got {type(default_args)}')
  145. args = cfg.copy()
  146. if default_args is not None:
  147. for name, value in default_args.items():
  148. args.setdefault(name, value)
  149. if group_key is None:
  150. group_key = default_group
  151. obj_type = args.pop('type')
  152. if isinstance(obj_type, str):
  153. obj_cls = registry.get(obj_type, group_key=group_key)
  154. if obj_cls is None:
  155. raise KeyError(f'{obj_type} is not in the {registry.name}'
  156. f' registry group {group_key}')
  157. elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
  158. obj_cls = obj_type
  159. else:
  160. raise TypeError(
  161. f'type must be a str or valid type, but got {type(obj_type)}')
  162. try:
  163. return obj_cls(**args)
  164. except Exception as e:
  165. # Normal TypeError does not print class name.
  166. raise type(e)(f'{obj_cls.__name__}: {e}')

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展