|
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- from typing import Any, Iterator, Optional, Tuple, Type, Union
-
- import numpy as np
- import torch
-
-
class BaseDataElement:
    """A base data interface that supports Tensor-like and dict-like
    operations.

    A typical data element refers to predicted results or ground truth labels
    on a task, such as predicted bboxes, instance masks, semantic
    segmentation masks, etc. Because ground truth labels and predicted results
    often have similar properties (for example, the predicted bboxes and the
    ground truth bboxes), MMEngine uses the same abstract data interface to
    encapsulate predicted results and ground truth labels, and it is
    recommended to use different name conventions to distinguish them, such as
    using ``gt_instances`` and ``pred_instances`` to distinguish between
    labels and predicted results. Additionally, we distinguish data elements
    at instance level, pixel level, and label level. Each of these types has
    its own characteristics. Therefore, MMEngine defines the base class
    ``BaseDataElement``, and implements ``InstanceData``, ``PixelData``, and
    ``LabelData`` inheriting from ``BaseDataElement`` to represent different
    types of ground truth labels or predictions.

    Another common data element is the data example. A data example consists
    of input data (such as an image) and its annotations and predictions. In
    general, an image can have multiple types of annotations and/or
    predictions at the same time (for example, both pixel-level semantic
    segmentation annotations and instance-level detection bboxes
    annotations). All labels and predictions of a training example are often
    passed between Dataset, Model, Visualizer, and Evaluator components. In
    order to simplify the interface between components, we can treat them as
    a large data element and encapsulate them. Such data elements are
    generally called XXDataSample in OpenMMLab. Therefore, similar to
    ``nn.Module``, a ``BaseDataElement`` allows another ``BaseDataElement``
    as its attribute. Such a class generally encapsulates all the data of an
    example in the algorithm library, and its attributes generally are
    various types of data elements. For example, MMDetection uses a
    ``BaseDataElement`` subclass to encapsulate all the labels and
    predictions of an example.

    The attributes in ``BaseDataElement`` are divided into two parts,
    the ``metainfo`` and the ``data`` respectively.

        - ``metainfo``: Usually contains the
          information about the image such as filename,
          image_shape, pad_shape, etc. The attributes can be accessed or
          modified by dict-like or object-like operations, such as
          ``.`` (for data access and modification), ``in``, ``del``,
          ``pop(str)``, ``get(str)``, ``metainfo_keys()``,
          ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for
          set or change key-value pairs in metainfo).

        - ``data``: Annotations or model predictions are
          stored. The attributes can be accessed or modified by
          dict-like or object-like operations, such as
          ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``,
          ``values()``, ``items()``. Users can also apply tensor-like
          methods to all :obj:`torch.Tensor` in the ``data_fields``,
          such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``,
          ``to_tensor()``, ``.detach()``.

    Args:
        metainfo (dict, optional): A dict contains the meta information
            of single image, such as ``dict(img_shape=(512, 512, 3),
            scale_factor=(1, 1, 1, 1))``. Defaults to None.
        kwargs (dict, optional): A dict contains annotations of single image
            or model predictions. Defaults to None.

    Examples:
        >>> import torch
        >>> from mmengine.structures import BaseDataElement
        >>> gt_instances = BaseDataElement()
        >>> bboxes = torch.rand((5, 4))
        >>> scores = torch.rand((5,))
        >>> img_id = 0
        >>> img_shape = (800, 1333)
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=img_id, img_shape=img_shape),
        ...     bboxes=bboxes, scores=scores)
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=img_id, img_shape=(640, 640)))

        >>> # new
        >>> gt_instances1 = gt_instances.new(
        ...     metainfo=dict(img_id=1, img_shape=(640, 640)),
        ...     bboxes=torch.rand((5, 4)),
        ...     scores=torch.rand((5,)))
        >>> gt_instances2 = gt_instances1.new()

        >>> # add and process property
        >>> gt_instances = BaseDataElement()
        >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
        >>> assert 'img_shape' in gt_instances.metainfo_keys()
        >>> assert 'img_shape' in gt_instances
        >>> assert 'img_shape' not in gt_instances.keys()
        >>> assert 'img_shape' in gt_instances.all_keys()
        >>> print(gt_instances.img_shape)
        (100, 100)
        >>> gt_instances.scores = torch.rand((5,))
        >>> assert 'scores' in gt_instances.keys()
        >>> assert 'scores' in gt_instances
        >>> assert 'scores' in gt_instances.all_keys()
        >>> assert 'scores' not in gt_instances.metainfo_keys()
        >>> print(gt_instances.scores)
        tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876])
        >>> gt_instances.bboxes = torch.rand((5, 4))
        >>> assert 'bboxes' in gt_instances.keys()
        >>> assert 'bboxes' in gt_instances
        >>> assert 'bboxes' in gt_instances.all_keys()
        >>> assert 'bboxes' not in gt_instances.metainfo_keys()
        >>> print(gt_instances.bboxes)
        tensor([[0.0900, 0.0424, 0.1755, 0.4469],
                [0.8648, 0.0592, 0.3484, 0.0913],
                [0.5808, 0.1909, 0.6165, 0.7088],
                [0.5490, 0.4209, 0.9416, 0.2374],
                [0.3652, 0.1218, 0.8805, 0.7523]])

        >>> # delete and change property
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=0, img_shape=(640, 640)),
        ...     bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))
        >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280)))
        >>> gt_instances.img_shape  # (1280, 1280)
        >>> gt_instances.bboxes = gt_instances.bboxes * 2
        >>> gt_instances.get('img_shape', None)  # (1280, 1280)
        >>> gt_instances.get('bboxes', None)  # 6x4 tensor
        >>> del gt_instances.img_shape
        >>> del gt_instances.bboxes
        >>> assert 'img_shape' not in gt_instances
        >>> assert 'bboxes' not in gt_instances
        >>> gt_instances.pop('img_shape', None)  # None
        >>> gt_instances.pop('bboxes', None)  # None

        >>> # Tensor-like
        >>> cuda_instances = gt_instances.cuda()
        >>> cuda_instances = gt_instances.to('cuda:0')
        >>> cpu_instances = cuda_instances.cpu()
        >>> cpu_instances = cuda_instances.to('cpu')
        >>> fp16_instances = cuda_instances.to(
        ...     device=None, dtype=torch.float16, non_blocking=False,
        ...     copy=False, memory_format=torch.preserve_format)
        >>> cpu_instances = cuda_instances.detach()
        >>> np_instances = cpu_instances.numpy()

        >>> # print
        >>> metainfo = dict(img_shape=(800, 1196, 3))
        >>> gt_instances = BaseDataElement(
        ...     metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
        >>> example = BaseDataElement(metainfo=metainfo,
        ...                           gt_instances=gt_instances)
        >>> print(example)
        <BaseDataElement(

            META INFORMATION
            img_shape: (800, 1196, 3)

            DATA FIELDS
            gt_instances: <BaseDataElement(

                    META INFORMATION
                    img_shape: (800, 1196, 3)

                    DATA FIELDS
                    det_labels: tensor([0, 1, 2, 3])
                ) at 0x7f0ec5eadc70>
        ) at 0x7f0fea49e130>

        >>> # inheritance
        >>> class DetDataSample(BaseDataElement):
        ...     @property
        ...     def proposals(self):
        ...         return self._proposals
        ...     @proposals.setter
        ...     def proposals(self, value):
        ...         self.set_field(value, '_proposals', dtype=BaseDataElement)
        ...     @proposals.deleter
        ...     def proposals(self):
        ...         del self._proposals
        ...     @property
        ...     def gt_instances(self):
        ...         return self._gt_instances
        ...     @gt_instances.setter
        ...     def gt_instances(self, value):
        ...         self.set_field(value, '_gt_instances',
        ...                        dtype=BaseDataElement)
        ...     @gt_instances.deleter
        ...     def gt_instances(self):
        ...         del self._gt_instances
        ...     @property
        ...     def pred_instances(self):
        ...         return self._pred_instances
        ...     @pred_instances.setter
        ...     def pred_instances(self, value):
        ...         self.set_field(value, '_pred_instances',
        ...                        dtype=BaseDataElement)
        ...     @pred_instances.deleter
        ...     def pred_instances(self):
        ...         del self._pred_instances
        >>> det_example = DetDataSample()
        >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4)))
        >>> det_example.proposals = proposals
        >>> assert 'proposals' in det_example
        >>> assert det_example.proposals == proposals
        >>> del det_example.proposals
        >>> assert 'proposals' not in det_example
        >>> with self.assertRaises(AssertionError):
        ...     det_example.proposals = torch.rand((5, 4))
    """

    def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
        # The two bookkeeping sets record which attribute names belong to
        # metainfo and which belong to data. They go through ``__setattr__``,
        # which lets them be created once and then refuses to rebind them.
        self._metainfo_fields: set = set()
        self._data_fields: set = set()

        if metainfo is not None:
            self.set_metainfo(metainfo=metainfo)
        if kwargs:
            self.set_data(kwargs)

    def set_metainfo(self, metainfo: dict) -> None:
        """Set or change key-value pairs in ``metainfo_field`` by parameter
        ``metainfo``.

        The input dict is deep-copied before its values are stored, so later
        changes to the caller's dict do not affect this element.

        Args:
            metainfo (dict): A dict contains the meta information
                of image, such as ``img_shape``, ``scale_factor``, etc.
        """
        assert isinstance(metainfo, dict), f"metainfo should be a ``dict`` but got {type(metainfo)}"
        meta = copy.deepcopy(metainfo)
        for k, v in meta.items():
            self.set_field(name=k, value=v, field_type="metainfo", dtype=None)

    def set_data(self, data: dict) -> None:
        """Set or change key-value pairs in ``data_field`` by parameter
        ``data``.

        Values are stored by reference (no copy is made).

        Args:
            data (dict): A dict contains annotations of image or
                model predictions.
        """
        # Report the offending type (as ``set_metainfo`` does) rather than
        # dumping the whole object into the message.
        assert isinstance(data, dict), f"data should be a `dict` but got {type(data)}"
        for k, v in data.items():
            # Use `setattr()` rather than `self.set_field` to allow `set_data`
            # to set property method.
            setattr(self, k, v)

    def update(self, instance: "BaseDataElement") -> None:
        """Update this element with the metainfo and data of another
        ``BaseDataElement``.

        Args:
            instance (BaseDataElement): Another BaseDataElement object used
                to update the current object.
        """
        assert isinstance(
            instance, BaseDataElement
        ), f"instance should be a `BaseDataElement` but got {type(instance)}"
        self.set_metainfo(dict(instance.metainfo_items()))
        self.set_data(dict(instance.items()))

    def new(self, *, metainfo: Optional[dict] = None, **kwargs) -> "BaseDataElement":
        """Return a new data element with same type. If ``metainfo`` and
        ``data`` are None, the new data element will have same metainfo and
        data. If metainfo or data is not None, the new result will overwrite it
        with the input value.

        Args:
            metainfo (dict, optional): A dict contains the meta information
                of image, such as ``img_shape``, ``scale_factor``, etc.
                Defaults to None.
            kwargs (dict): A dict contains annotations of image or
                model predictions.

        Returns:
            BaseDataElement: A new data element with same type.
        """
        new_data = self.__class__()

        if metainfo is not None:
            new_data.set_metainfo(metainfo)
        else:
            new_data.set_metainfo(dict(self.metainfo_items()))
        if kwargs:
            new_data.set_data(kwargs)
        else:
            new_data.set_data(dict(self.items()))
        return new_data

    def clone(self) -> "BaseDataElement":
        """Copy the current data element into a new one of the same type.

        Metainfo is deep-copied (``set_metainfo`` deep-copies its input).
        Data values are re-assigned by reference, so the clone shares them
        with the original.

        Returns:
            BaseDataElement: The copy of current data element.
        """
        clone_data = self.__class__()
        clone_data.set_metainfo(dict(self.metainfo_items()))
        clone_data.set_data(dict(self.items()))
        return clone_data

    def keys(self) -> list:
        """
        Returns:
            list: Contains all keys in data_fields.
        """
        # We assume that the name of the attribute related to property is
        # '_' + the name of the property. We use this rule to filter out
        # private keys.
        # TODO: Use a more robust way to solve this problem
        private_keys = {
            "_" + key
            for key in self._data_fields
            if isinstance(getattr(type(self), key, None), property)
        }
        return list(self._data_fields - private_keys)

    def metainfo_keys(self) -> list:
        """
        Returns:
            list: Contains all keys in metainfo_fields.
        """
        return list(self._metainfo_fields)

    def values(self) -> list:
        """
        Returns:
            list: Contains all values in data.
        """
        return [getattr(self, k) for k in self.keys()]

    def metainfo_values(self) -> list:
        """
        Returns:
            list: Contains all values in metainfo.
        """
        return [getattr(self, k) for k in self.metainfo_keys()]

    def all_keys(self) -> list:
        """
        Returns:
            list: Contains all keys in metainfo and data.
        """
        return self.metainfo_keys() + self.keys()

    def all_values(self) -> list:
        """
        Returns:
            list: Contains all values in metainfo and data.
        """
        return self.metainfo_values() + self.values()

    def all_items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``metainfo`` and ``data``.
        """
        for k in self.all_keys():
            yield (k, getattr(self, k))

    def items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``data``.
        """
        for k in self.keys():
            yield (k, getattr(self, k))

    def metainfo_items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``metainfo``.
        """
        for k in self.metainfo_keys():
            yield (k, getattr(self, k))

    @property
    def metainfo(self) -> dict:
        """dict: A dict contains metainfo of current data element."""
        return dict(self.metainfo_items())

    def __setattr__(self, name: str, value: Any):
        """Set a data field; the bookkeeping sets may only be created once.

        Raises:
            AttributeError: If ``name`` is one of the bookkeeping attributes
                and it already exists, or if ``name`` is already a metainfo
                field.
        """
        if name in ("_metainfo_fields", "_data_fields"):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f"{name} has been used as a private attribute, which is immutable."
                )
        else:
            self.set_field(name=name, value=value, field_type="data", dtype=None)

    def __delattr__(self, item: str):
        """Delete the item in dataelement.

        Args:
            item (str): The key to delete.

        Raises:
            AttributeError: If ``item`` is one of the bookkeeping attributes.
        """
        if item in ("_metainfo_fields", "_data_fields"):
            raise AttributeError(
                f"{item} has been used as a private attribute, which is immutable."
            )
        super().__delattr__(item)
        if item in self._metainfo_fields:
            self._metainfo_fields.remove(item)
        elif item in self._data_fields:
            self._data_fields.remove(item)

    # dict-like methods
    __delitem__ = __delattr__

    def get(self, key, default=None) -> Any:
        """Get property in data and metainfo as the same as python."""
        # Use `getattr()` rather than `self.__dict__.get()` to allow getting
        # properties.
        return getattr(self, key, default)

    def pop(self, *args) -> Any:
        """Pop property in data and metainfo as the same as python.

        Args:
            *args: ``(name,)`` or ``(name, default)``.

        Returns:
            Any: The removed value, or ``default`` when ``name`` is neither a
            metainfo nor a data field and a default was given.

        Raises:
            KeyError: If ``name`` is not a field and no default was given.
        """
        assert len(args) < 3, "``pop`` get more than 2 arguments"
        name = args[0]
        for fields in (self._metainfo_fields, self._data_fields):
            if name in fields:
                fields.remove(name)
                return self.__dict__.pop(*args)

        # with default value
        if len(args) == 2:
            return args[1]
        # don't just use 'self.__dict__.pop(*args)' for only popping key in
        # metainfo or data
        raise KeyError(f"{name} is not contained in metainfo or data")

    def __contains__(self, item: str) -> bool:
        """Whether the item is in dataelement.

        Args:
            item (str): The key to inquire.
        """
        return item in self._data_fields or item in self._metainfo_fields

    def set_field(
        self,
        value: Any,
        name: str,
        dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
        field_type: str = "data",
    ) -> None:
        """Special method for set union field, used as property.setter
        functions.

        Args:
            value (Any): The value to store.
            name (str): The attribute name to store ``value`` under.
            dtype (type or tuple of types, optional): If given, ``value``
                must be an instance of it. Defaults to None.
            field_type (str): ``"metainfo"`` or ``"data"``. Defaults to
                ``"data"``.

        Raises:
            AssertionError: If ``field_type`` is invalid or ``value`` is not
                an instance of ``dtype``.
            AttributeError: If ``name`` is a bookkeeping attribute, or is
                already registered in the other field group.
        """
        assert field_type in ["metainfo", "data"]
        if dtype is not None:
            assert isinstance(value, dtype), f"{value} should be a {dtype} but got {type(value)}"

        # ``__setattr__``/``__delattr__`` protect the bookkeeping sets; apply
        # the same guard here so ``set_metainfo``/direct calls cannot clobber
        # them either.
        if name in ("_metainfo_fields", "_data_fields"):
            raise AttributeError(
                f"{name} has been used as a private attribute, which is immutable."
            )

        if field_type == "metainfo":
            if name in self._data_fields:
                raise AttributeError(
                    f"Cannot set {name} to be a field of metainfo "
                    f"because {name} is already a data field"
                )
            registry = self._metainfo_fields
        else:
            if name in self._metainfo_fields:
                raise AttributeError(
                    f"Cannot set {name} to be a field of data "
                    f"because {name} is already a metainfo field"
                )
            registry = self._data_fields

        # Assign first, register second: the assignment may raise (e.g. a
        # property setter rejecting the value, or a read-only property), and
        # a name must not be reported as a field unless it was really stored.
        super().__setattr__(name, value)
        registry.add(name)

    def _call_on_tensors(self, method_name: str) -> "BaseDataElement":
        """Return a copy in which the zero-argument method ``method_name``
        has been called on every ``torch.Tensor`` / ``BaseDataElement`` data
        value; other values are carried over unchanged."""
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, (torch.Tensor, BaseDataElement)):
                new_data.set_data({k: getattr(v, method_name)()})
        return new_data

    # Tensor-like methods
    def to(self, *args, **kwargs) -> "BaseDataElement":
        """Apply same name function to all tensors in data_fields."""
        new_data = self.new()
        for k, v in self.items():
            # Anything exposing ``.to`` is converted, not only tensors.
            if hasattr(v, "to"):
                new_data.set_data({k: v.to(*args, **kwargs)})
        return new_data

    # Tensor-like methods
    def cpu(self) -> "BaseDataElement":
        """Convert all tensors to CPU in data."""
        return self._call_on_tensors("cpu")

    # Tensor-like methods
    def cuda(self) -> "BaseDataElement":
        """Convert all tensors to GPU in data."""
        return self._call_on_tensors("cuda")

    # Tensor-like methods
    def npu(self) -> "BaseDataElement":
        """Convert all tensors to NPU in data."""
        return self._call_on_tensors("npu")

    def mlu(self) -> "BaseDataElement":
        """Convert all tensors to MLU in data."""
        return self._call_on_tensors("mlu")

    # Tensor-like methods
    def detach(self) -> "BaseDataElement":
        """Detach all tensors in data."""
        return self._call_on_tensors("detach")

    # Tensor-like methods
    def numpy(self) -> "BaseDataElement":
        """Convert all tensors to np.ndarray in data."""
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, (torch.Tensor, BaseDataElement)):
                # For a nested element each call returns an element, so the
                # chain ends in the nested element's own ``numpy()``.
                new_data.set_data({k: v.detach().cpu().numpy()})
        return new_data

    def to_tensor(self) -> "BaseDataElement":
        """Convert all np.ndarray to tensor in data."""
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, np.ndarray):
                new_data.set_data({k: torch.from_numpy(v)})
            elif isinstance(v, BaseDataElement):
                new_data.set_data({k: v.to_tensor()})
        return new_data

    def to_dict(self) -> dict:
        """Convert BaseDataElement to dict.

        Metainfo and data are merged into one dict; nested
        ``BaseDataElement`` values are converted recursively.
        """
        return {
            k: v.to_dict() if isinstance(v, BaseDataElement) else v for k, v in self.all_items()
        }

    def __repr__(self) -> str:
        """Represent the object."""

        def _addindent(s_: str, num_spaces: int) -> str:
            """This func is modified from `pytorch` https://github.com/pytorch/
            pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu
            les/module.py#L29.

            Args:
                s_ (str): The string to add spaces.
                num_spaces (int): The num of space to add.

            Returns:
                str: The string after add indent.
            """
            lines = s_.split("\n")
            # don't do anything for single-line stuff
            if len(lines) == 1:
                return s_
            first = lines.pop(0)
            indented = "\n".join((num_spaces * " ") + line for line in lines)
            return first + "\n" + indented

        def dump(obj: Any) -> str:
            """Represent the object.

            Args:
                obj (Any): The obj to represent.

            Returns:
                str: The represented str.
            """
            _repr = ""
            if isinstance(obj, dict):
                for k, v in obj.items():
                    _repr += f"\n{k}: {_addindent(dump(v), 4)}"
            elif isinstance(obj, BaseDataElement):
                # Section headers sit at the same 4-space indent that
                # ``_addindent`` gives the entries listed beneath them.
                _repr += "\n\n    META INFORMATION"
                metainfo_items = dict(obj.metainfo_items())
                _repr += _addindent(dump(metainfo_items), 4)
                _repr += "\n\n    DATA FIELDS"
                items = dict(obj.items())
                _repr += _addindent(dump(items), 4)
                classname = obj.__class__.__name__
                _repr = f"<{classname}({_repr}\n) at {hex(id(obj))}>"
            else:
                _repr += repr(obj)
            return _repr

        return dump(self)
|