|
- """
- Copyright (c) OpenMMLab. All rights reserved.
- """
-
- import inspect
- import threading
- import warnings
- from collections import OrderedDict
- from typing import Type, TypeVar
-
- _lock = threading.RLock()
- T = TypeVar("T")
-
-
def _accquire_lock() -> None:
    """Take the module-level lock that serializes access to shared data.

    Every call must be paired with a later call to ``_release_lock()``.
    """
    if not _lock:
        # No lock object available: nothing to acquire.
        return
    _lock.acquire()
-
-
def _release_lock() -> None:
    """Give back the module-level lock taken by ``_accquire_lock()``."""
    if not _lock:
        # No lock object available: nothing to release.
        return
    _lock.release()
-
-
class ManagerMeta(type):
    """The metaclass for globally accessible classes.

    Every class created with ``ManagerMeta`` gets its own ``_instance_dict``
    (an ``OrderedDict`` created at class-definition time). The constructor of
    such a class must declare a ``name`` parameter, either as a regular
    (positional-or-keyword) parameter or as a keyword-only parameter.

    Examples:
        >>> class SubClass1(metaclass=ManagerMeta):
        >>>     def __init__(self, *args, **kwargs):
        >>>         pass
        AssertionError: <class '__main__.SubClass1'> must have the `name`
        argument
        >>> class SubClass2(metaclass=ManagerMeta):
        >>>     def __init__(self, name):
        >>>         pass
        >>> # valid format.
        >>> class SubClass3(metaclass=ManagerMeta):
        >>>     def __init__(self, *, name):
        >>>         pass
        >>> # valid format (keyword-only ``name``).

    Raises:
        AssertionError: If the constructor of the class being created does
            not declare a ``name`` parameter.
    """

    def __init__(cls, *args):
        # Each class gets a fresh registry so classes never share instances.
        cls._instance_dict = OrderedDict()
        spec = inspect.getfullargspec(cls)
        # Accept ``name`` both as a regular and as a keyword-only parameter.
        declared = set(spec.args or ()) | set(spec.kwonlyargs or ())
        if "name" not in declared:
            # Raised explicitly (instead of ``assert``) so the check is not
            # stripped when Python runs with ``-O``; the exception type is
            # unchanged for callers.
            raise AssertionError(f"{cls} must have the `name` argument")
        super().__init__(*args)
-
-
class ManagerMixin(metaclass=ManagerMeta):
    """``ManagerMixin`` is the base class for classes that have global access
    requirements.

    Subclasses keep a per-class registry of named instances (the
    ``_instance_dict`` created by ``ManagerMeta``) and can fetch those
    instances from anywhere through the class methods below.

    Examples:
        >>> class GlobalAccessible(ManagerMixin):
        >>>     def __init__(self, name=''):
        >>>         super().__init__(name)
        >>>
        >>> GlobalAccessible.get_instance('name')
        >>> instance_1 = GlobalAccessible.get_instance('name')
        >>> instance_2 = GlobalAccessible.get_instance('name')
        >>> assert id(instance_1) == id(instance_2)

    Args:
        name (str): Name of the instance. The signature default is ``''``,
            but an empty name is rejected, so a non-empty string must be
            passed.
    """

    def __init__(self, name: str = ""):
        assert isinstance(name, str) and name, "name argument must be an non-empty string."
        self._instance_name = name

    @classmethod
    def get_instance(cls: Type[T], name: str, **kwargs) -> T:
        """Get subclass instance by name, creating it if it does not exist.

        If no instance with the given name has been created yet,
        ``get_instance`` builds one with ``cls(name=name, **kwargs)`` and
        registers it; otherwise the already registered instance is returned
        and ``kwargs`` are ignored (with a warning).

        Examples
            >>> instance1 = GlobalAccessible.get_instance('name1')
            >>> # Create name1 instance.
            >>> instance1.instance_name
            name1
            >>> instance2 = GlobalAccessible.get_instance('name1')
            >>> # Get name1 instance.
            >>> assert id(instance1) == id(instance2)

        Args:
            name (str): Name of instance.
            **kwargs: Extra arguments forwarded to the constructor when the
                instance has to be created.

        Returns:
            object: The instance registered under ``name``.
        """
        _accquire_lock()
        try:
            # Everything that can raise (the check below, the subclass
            # constructor) runs inside ``try`` so the module lock is always
            # released and other threads cannot be blocked forever.
            assert isinstance(name, str), f"type of name should be str, but got {type(name)}"
            instance_dict = cls._instance_dict  # type: ignore
            if name not in instance_dict:
                instance_dict[name] = cls(name=name, **kwargs)  # type: ignore
            elif kwargs:
                warnings.warn(
                    f"{cls} instance named of {name} has been created, "
                    "the method `get_instance` should not accept any other "
                    "arguments"
                )
            return instance_dict[name]
        finally:
            _release_lock()

    @classmethod
    def get_current_instance(cls):
        """Get latest created instance.

        Before calling ``get_current_instance``, the subclass must have
        called ``get_instance(xxx)`` at least once.

        Examples
            >>> instance = GlobalAccessible.get_current_instance()
            RuntimeError: Before calling
            GlobalAccessible.get_current_instance(), you should call
            get_instance(name=xxx) at least once.
            >>> instance = GlobalAccessible.get_instance('name1')
            >>> instance.instance_name
            name1
            >>> instance = GlobalAccessible.get_current_instance()
            >>> instance.instance_name
            name1

        Returns:
            object: Latest created instance.

        Raises:
            RuntimeError: If no instance of this class has been created yet.
        """
        _accquire_lock()
        try:
            if not cls._instance_dict:
                raise RuntimeError(
                    f"Before calling {cls.__name__}.get_current_instance(), you "
                    "should call get_instance(name=xxx) at least once."
                )
            # The registry is insertion-ordered, so the last key is the most
            # recently created instance.
            name = next(reversed(cls._instance_dict))
            return cls._instance_dict[name]
        finally:
            # Also runs when RuntimeError is raised above, so the lock is not
            # left held.
            _release_lock()

    @classmethod
    def check_instance_created(cls, name: str) -> bool:
        """Check whether the name corresponding instance exists.

        Args:
            name (str): Name of instance.

        Returns:
            bool: Whether the name corresponding instance exists.
        """
        return name in cls._instance_dict

    @property
    def instance_name(self) -> str:
        """Get the name of instance.

        Returns:
            str: Name of instance.
        """
        return self._instance_name
|