# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
from typing import List, Union

import numpy as np
import torch

from ...utils import flatten as flatten_list
from ...utils import to_hashable
from .base_data_element import BaseDataElement

BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]


class ListData(BaseDataElement):
    """
    Abstract Data Interface used throughout the ABL Kit.

    ``ListData`` is the underlying data structure used in the ABL Kit. It is
    designed to manage the diverse forms of data dynamically generated
    throughout the Abductive Learning (ABL) framework. These include raw data,
    predicted pseudo-labels, abduced pseudo-labels, pseudo-label indices, etc.

    As a fundamental data structure in ABL, ``ListData`` carries data between
    the components of the ABL framework, such as prediction, abductive
    reasoning, and training. It provides a unified data format across these
    stages, which keeps them compatible and flexible when handling diverse
    forms of data.

    The attributes in ``ListData`` are divided into two parts,
    the ``metainfo`` and the ``data`` respectively.

        - ``metainfo``: Usually used to store basic information about data
          examples, such as symbol number, image size, etc. The attributes can
          be accessed or modified by dict-like or object-like operations, such
          as ``.`` (for data access and modification), ``in``, ``del``,
          ``pop(str)``, ``get(str)``, ``metainfo_keys()``,
          ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for
          setting or changing key-value pairs in metainfo).

        - ``data``: raw data, labels, predictions, and abduced results are
          stored here. The attributes can be accessed or modified by dict-like
          or object-like operations, such as ``.``, ``in``, ``del``,
          ``pop(str)``, ``get(str)``, ``keys()``, ``values()``, ``items()``.
          Users can also apply tensor-like methods to all
          :obj:`torch.Tensor` in the ``data_fields``, such as ``.cuda()``,
          ``.cpu()``, ``.numpy()``, ``.to()``, ``to_tensor()``,
          ``.detach()``.

    ListData supports ``index`` and ``slice`` for data fields. The value of a
    data field can be either ``None`` or a ``list`` of base data structures
    such as ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and
    ``tuple``.

    This design is inspired by and extends the functionalities of the
    ``BaseDataElement`` class implemented in
    `MMEngine <https://github.com/open-mmlab/mmengine>`_.

    Examples:
        >>> from ablkit.data.structures import ListData
        >>> import numpy as np
        >>> import torch
        >>> data_examples = ListData()
        >>> data_examples.X = [list(torch.randn(2)) for _ in range(3)]
        >>> data_examples.Y = [1, 2, 3]
        >>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
        >>> len(data_examples)
        3
        >>> print(data_examples)  # doctest: +SKIP
        >>> print(data_examples[:1])  # doctest: +SKIP
        >>> print(data_examples.elements_num("X"))
        6
        >>> print(data_examples.flatten("gt_pseudo_label"))
        [1, 2, 3, 4, 5, 6]
        >>> print(data_examples.to_tuple("Y"))
        (1, 2, 3)
        >>> len(ListData())
        0
    """

    def __setattr__(self, name: str, value: list):
        """Set an attribute on this ``ListData``.

        ``__setitem__`` is an alias of this method, so ``data[name] = value``
        behaves exactly like ``data.name = value``.

        The two private bookkeeping attributes ``_metainfo_fields`` and
        ``_data_fields`` can be assigned only once. Any later assignment to
        either of them raises ``AttributeError``. Every other assignment is
        forwarded unchanged to ``BaseDataElement.__setattr__``.

        Note:
            ``value`` is currently *not* validated. It is not required to be
            a ``list``, nor to have the same length as this ``ListData``.

        Raises:
            AttributeError: If ``name`` is one of the private bookkeeping
                attributes and it has already been set.
        """
        if name in ("_metainfo_fields", "_data_fields"):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f"{name} has been used as a "
                    "private attribute, which is immutable."
                )
        else:
            super().__setattr__(name, value)

    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> "ListData":
        """
        Index into this ``ListData``.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Get the corresponding values according to item.

                - ``str``: returns the raw value of that attribute
                  (not wrapped in a ``ListData``).
                - ``int`` / ``slice``: every non-``None`` data field is
                  indexed with ``item`` directly. With an ``int`` each field
                  therefore holds a single element rather than a list.
                - ``list`` / ``numpy.ndarray`` / 1-D index or bool tensor:
                  selects the given positions from every data field.

        Returns:
            :obj:`ListData`: Corresponding values. The ``metainfo`` is
            shared with ``self``.
        """
        assert isinstance(item, IndexType.__args__)
        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent, int32 for
            # windows and int64 for linux. `torch.Tensor` requires the index
            # should be int64, therefore we simply convert it to int64 here.
            # More details in https://github.com/numpy/numpy/issues/9464
            item = item.astype(np.int64) if item.dtype == np.int32 else item
            item = torch.from_numpy(item)

        if isinstance(item, str):
            return getattr(self, item)

        new_data = self.__class__(metainfo=self.metainfo)

        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, "Only support to get the" " values along the first dimension."

            for k, v in self.items():
                if v is None:
                    new_data[k] = None
                elif isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(v, (str, list, tuple)) or (
                    hasattr(v, "__getitem__") and hasattr(v, "cat")
                ):
                    # Convert a boolean mask into explicit integer positions.
                    if isinstance(item, BoolTypeTensor.__args__):
                        indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()

                    # ``v[index::len(v)]`` picks the single element at
                    # ``index`` but keeps the container type (str/list/tuple
                    # or a ``cat``-able object), so the pieces can be
                    # concatenated back together below.
                    slice_list = []
                    if indexes:
                        for index in indexes:
                            slice_list.append(slice(index, None, len(v)))
                    else:
                        # No positions were selected: produce an empty
                        # container of the same type.
                        slice_list.append(slice(None, 0, None))
                    r_list = [v[s] for s in slice_list]

                    if isinstance(v, (str, list, tuple)):
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f"The type of `{k}` is `{type(v)}`, which has no "
                        "attribute of `cat`, so it does not "
                        "support slice with `bool`"
                    )
        else:
            # item is a slice or int
            for k, v in self.items():
                if v is None:
                    new_data[k] = None
                else:
                    new_data[k] = v[item]
        return new_data  # type:ignore

    def flatten(self, item: str) -> List:
        """
        Flatten the list of the attribute specified by ``item``.

        Parameters
        ----------
        item
            Name of the attribute to be flattened.

        Returns
        -------
        list
            The flattened list of the attribute specified by ``item``.
        """
        return flatten_list(self[item])

    def elements_num(self, item: str) -> int:
        """
        Return the number of elements in the attribute specified by ``item``.

        Parameters
        ----------
        item : str
            Name of the attribute for which the number of elements is to be
            determined.

        Returns
        -------
        int
            The number of elements in the flattened attribute specified by
            ``item``.
        """
        return len(self.flatten(item))

    def to_tuple(self, item: str) -> tuple:
        """
        Convert the attribute specified by ``item`` to a tuple.

        Parameters
        ----------
        item : str
            Name of the attribute to be converted.

        Returns
        -------
        tuple
            The attribute after conversion to a (hashable) tuple.
        """
        return to_hashable(self[item])

    def __len__(self) -> int:
        """Return the length of this ``ListData``.

        The length is the ``len()`` of the first data field whose value is
        not ``None``.

        Returns
        -------
        int
            That length, or ``0`` when the ``ListData`` has no data fields
            at all.

        Raises
        ------
        ValueError
            If data fields exist but every one of them is ``None``.
        """
        # An empty ListData must report 0 rather than raise. Python also
        # falls back to ``__len__`` for truthiness when no ``__bool__`` is
        # defined, so ``if data:`` would otherwise fail on an empty instance.
        if not self._data_fields:
            return 0
        for name in self._data_fields:
            value = getattr(self, name)
            if value is not None:
                return len(value)
        raise ValueError("All data fields are None.")