
list_data.py 10 kB

# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
from typing import List, Union

import numpy as np
import torch

from ...utils import flatten as flatten_list
from ...utils import to_hashable
from .base_data_element import BaseDataElement

BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]

class ListData(BaseDataElement):
    """
    Abstract Data Interface used throughout the ABL Kit.

    ``ListData`` is the underlying data structure used in the ABL Kit,
    designed to manage diverse forms of data dynamically generated throughout the
    Abductive Learning (ABL) framework. This includes handling raw data, predicted
    pseudo-labels, abduced pseudo-labels, pseudo-label indices, etc.

    As a fundamental data structure in ABL, ``ListData`` is essential for the smooth
    transfer and manipulation of data across various components of the ABL framework,
    such as the prediction, abductive reasoning, and training phases. It provides a
    unified data format across these stages, ensuring compatibility and flexibility
    in handling diverse data forms in the ABL framework.

    The attributes in ``ListData`` are divided into two parts:
    ``metainfo`` and ``data``.

    - ``metainfo``: Usually used to store basic information about data examples,
      such as symbol number, image size, etc. The attributes can be accessed or
      modified by dict-like or object-like operations, such as ``.`` (for data
      access and modification), ``in``, ``del``, ``pop(str)``, ``get(str)``,
      ``metainfo_keys()``, ``metainfo_values()``, ``metainfo_items()``, and
      ``set_metainfo()`` (for setting or changing key-value pairs in metainfo).
    - ``data``: Stores raw data, labels, predictions, and abduced results.
      The attributes can be accessed or modified by dict-like or object-like operations,
      such as ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``,
      ``values()``, ``items()``. Users can also apply tensor-like
      methods to all :obj:`torch.Tensor` in the ``data_fields``, such as ``.cuda()``,
      ``.cpu()``, ``.numpy()``, ``.to()``, ``to_tensor()``, ``.detach()``.

    ``ListData`` supports ``index`` and ``slice`` for data fields. The value of a
    data field can be either ``None`` or a ``list`` of base data structures such as
    ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``.

    This design is inspired by and extends the functionalities of the ``BaseDataElement``
    class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_.  # noqa: E501

    Examples:
        >>> from ablkit.data.structures import ListData
        >>> import numpy as np
        >>> import torch
        >>> data_examples = ListData()
        >>> data_examples.X = [list(torch.randn(2)) for _ in range(3)]
        >>> data_examples.Y = [1, 2, 3]
        >>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
        >>> len(data_examples)
        3
        >>> print(data_examples)
        <ListData(
            META INFORMATION
            DATA FIELDS
            Y: [1, 2, 3]
            gt_pseudo_label: [[1, 2], [3, 4], [5, 6]]
            X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]]  # noqa: E501
        ) at 0x7f3bbf1991c0>
        >>> print(data_examples[:1])
        <ListData(
            META INFORMATION
            DATA FIELDS
            Y: [1]
            gt_pseudo_label: [[1, 2]]
            X: [[tensor(1.1949), tensor(-0.9378)]]
        ) at 0x7f3bbf1a3580>
        >>> print(data_examples.elements_num("X"))
        6
        >>> print(data_examples.flatten("gt_pseudo_label"))
        [1, 2, 3, 4, 5, 6]
        >>> print(data_examples.to_tuple("Y"))
        (1, 2, 3)
    """
    def __setattr__(self, name: str, value: list):
        """setattr is only used to set data.

        The value must have the attribute ``__len__`` and have the same length
        as this ``ListData``.
        """
        if name in ("_metainfo_fields", "_data_fields"):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f"{name} has been used as a private attribute, which is immutable."
                )
        else:
            # assert isinstance(value, list), "value must be of type `list`"
            # if len(self) > 0:
            #     assert len(value) == len(self), (
            #         "The length of "
            #         f"values {len(value)} is "
            #         "not consistent with "
            #         "the length of this "
            #         ":obj:`ListData` "
            #         f"{len(self)}"
            #     )
            super().__setattr__(name, value)

    __setitem__ = __setattr__
    def __getitem__(self, item: IndexType) -> "ListData":
        """
        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Get the corresponding values according to ``item``.

        Returns:
            :obj:`ListData`: Corresponding values.
        """
        assert isinstance(item, IndexType.__args__)

        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent, int32 for
            # windows and int64 for linux. `torch.Tensor` requires the index
            # should be int64, therefore we simply convert it to int64 here.
            # More details in https://github.com/numpy/numpy/issues/9464
            item = item.astype(np.int64) if item.dtype == np.int32 else item
            item = torch.from_numpy(item)
        if isinstance(item, str):
            return getattr(self, item)

        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, "Only support to get the values along the first dimension."

            for k, v in self.items():
                if v is None:
                    new_data[k] = None
                elif isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(v, (str, list, tuple)) or (
                    hasattr(v, "__getitem__") and hasattr(v, "cat")
                ):
                    # convert to indexes from BoolTensor
                    if isinstance(item, BoolTypeTensor.__args__):
                        indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()
                    # Each slice(index, None, len(v)) selects exactly one element:
                    # with a step equal to len(v), only the element at `index` fits.
                    slice_list = []
                    if indexes:
                        for index in indexes:
                            slice_list.append(slice(index, None, len(v)))
                    else:
                        slice_list.append(slice(None, 0, None))
                    r_list = [v[s] for s in slice_list]
                    if isinstance(v, (str, list, tuple)):
                        # Sequence types are concatenated element by element.
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        # Custom containers are expected to provide a `cat` method.
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f"The type of `{k}` is `{type(v)}`, which has no "
                        "attribute of `cat`, so it does not "
                        "support slice with `bool`"
                    )
        else:
            # item is a slice or an int
            for k, v in self.items():
                if v is None:
                    new_data[k] = None
                else:
                    new_data[k] = v[item]
        return new_data  # type: ignore
    def flatten(self, item: str) -> List:
        """
        Flatten the list of the attribute specified by ``item``.

        Parameters
        ----------
        item : str
            Name of the attribute to be flattened.

        Returns
        -------
        list
            The flattened list of the attribute specified by ``item``.
        """
        return flatten_list(self[item])
    def elements_num(self, item: str) -> int:
        """
        Return the number of elements in the attribute specified by ``item``.

        Parameters
        ----------
        item : str
            Name of the attribute for which the number of elements is to be determined.

        Returns
        -------
        int
            The number of elements in the attribute specified by ``item``.
        """
        return len(self.flatten(item))
    def to_tuple(self, item: str) -> tuple:
        """
        Convert the attribute specified by ``item`` to a tuple.

        Parameters
        ----------
        item : str
            Name of the attribute to be converted.

        Returns
        -------
        tuple
            The attribute after conversion to a tuple.
        """
        return to_hashable(self[item])
    def __len__(self) -> int:
        """int: The length of ListData."""
        iterator = iter(self._data_fields)
        data = next(iterator)
        # Skip over data fields that are None; the length is taken from the
        # first non-None field.
        while getattr(self, data) is None:
            try:
                data = next(iterator)
            except StopIteration:
                break
        if getattr(self, data) is None:
            raise ValueError("All data fields are None.")
        else:
            return len(getattr(self, data))
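
A minimal usage sketch, assuming the import path shown in the class docstring (``ablkit.data.structures``) and an MMEngine-style ``set_metainfo(dict)`` inherited from ``BaseDataElement``; it shows how slice and boolean-mask indexing produce a new ``ListData`` that keeps the metainfo of the original, and how the helper methods defined above behave.

# Minimal usage sketch. Assumes ablkit is installed and that set_metainfo()
# accepts a dict, as in the MMEngine BaseDataElement this class is modeled on.
import torch

from ablkit.data.structures import ListData

data_examples = ListData()
data_examples.set_metainfo(dict(symbol_num=2))               # metainfo: basic info such as symbol number
data_examples.X = [list(torch.randn(2)) for _ in range(3)]   # raw data
data_examples.Y = [1, 2, 3]                                   # ground-truth labels
data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]      # pseudo-labels

# Slicing returns a new ListData holding the selected examples; __getitem__
# copies the metainfo of the original object into the result.
head = data_examples[:2]
assert len(head) == 2
print(head.metainfo)   # {'symbol_num': 2}

# Boolean masks and index lists are also valid (see IndexType above): a mask
# keeps the examples at positions where it is True.
mask = torch.tensor([True, False, True])
subset = data_examples[mask]
assert len(subset) == 2
assert subset.Y == [1, 3]

# Helper methods used across the ABL framework.
assert data_examples.flatten("gt_pseudo_label") == [1, 2, 3, 4, 5, 6]
assert data_examples.elements_num("X") == 6
assert data_examples.to_tuple("Y") == (1, 2, 3)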

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.