Browse Source

[DOC] modify doc of abl.structures

pull/1/head
Gao Enhao 1 year ago
parent
commit
ea0a5a3668
4 changed files with 74 additions and 152 deletions
  1. +2
    -1
      abl/structures/base_data_element.py
  2. +65
    -146
      abl/structures/list_data.py
  3. +2
    -2
      docs/API/abl.structures.rst
  4. +5
    -3
      docs/Intro/Datasets.rst

+ 2
- 1
abl/structures/base_data_element.py View File

@@ -5,7 +5,8 @@ from typing import Any, Iterator, Optional, Tuple, Type, Union
import numpy as np
import torch


# Modified from
# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py
class BaseDataElement:
"""A base data interface that supports Tensor-like and dict-like
operations.


+ 65
- 146
abl/structures/list_data.py View File

@@ -18,106 +18,47 @@ IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndar
# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
class ListData(BaseDataElement):
"""Data structure for instance-level annotations or predictions.
"""
Data structure for example-level data.

Subclass of :class:`BaseDataElement`. All value in `data_fields`
should have the same length. This design refer to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
ListData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value
in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py

ListData supports `index` and `slice` for data field. The type of value in data field can be either `None` or `list` of base data structures such as `torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`.

Examples:
>>> # custom data structure
>>> class TmpObject:
... def __init__(self, tmp) -> None:
... assert isinstance(tmp, list)
... self.tmp = tmp
... def __len__(self):
... return len(self.tmp)
... def __getitem__(self, item):
... if isinstance(item, int):
... if item >= len(self) or item < -len(self): # type:ignore
... raise IndexError(f'Index {item} out of range!')
... else:
... # keep the dimension
... item = slice(item, None, len(self))
... return TmpObject(self.tmp[item])
... @staticmethod
... def cat(tmp_objs):
... assert all(isinstance(results, TmpObject) for results in tmp_objs)
... if len(tmp_objs) == 1:
... return tmp_objs[0]
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
... tmp_list = list(itertools.chain(*tmp_list))
... new_data = TmpObject(tmp_list)
... return new_data
... def __repr__(self):
... return str(self.tmp)
>>> from mmengine.structures import ListData
>>> from abl.structures import ListData
>>> import numpy as np
>>> import torch
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = ListData(metainfo=img_meta)
>>> 'img_shape' in instance_data
True
>>> instance_data.det_labels = torch.LongTensor([2, 3])
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
>>> instance_data.bboxes = torch.rand((2, 4))
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> len(instance_data)
2
>>> print(instance_data)
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3])
det_scores: tensor([0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores
tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75])
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2])
det_scores: tensor([0.8000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
polygons: [[1, 2, 3, 4]]
) at 0x7f64ecf0ec40>
>>> print(instance_data[instance_data.det_scores > 1])
>>> data_examples = ListData()
>>> data_examples.X = [list(torch.randn(2)) for _ in range(3)]
>>> data_examples.Y = [1, 2, 3]
>>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
>>> len(data_examples)
3
>>> print(data_examples)
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([], dtype=torch.int64)
det_scores: tensor([])
bboxes: tensor([], size=(0, 4))
polygons: []
) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data]))
Y: [1, 2, 3]
gt_pseudo_label: [[1, 2], [3, 4], [5, 6]]
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]]
) at 0x7f3bbf1991c0>
>>> print(data_examples[:1])
<ListData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3, 2, 3])
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263],
[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7f203542feb0>
Y: [1]
gt_pseudo_label: [[1, 2]]
X: [[tensor(1.1949), tensor(-0.9378)]]
) at 0x7f3bbf1a3580>
>>> print(data_examples.elements_num("X"))
6
>>> print(data_examples.flatten("gt_pseudo_label"))
[1, 2, 3, 4, 5, 6]
>>> print(data_examples.to_tuple("Y"))
(1, 2, 3)
"""

def __setattr__(self, name: str, value: list):
@@ -224,74 +165,52 @@ class ListData(BaseDataElement):
new_data[k] = v[item]
return new_data # type:ignore

@staticmethod
def cat(instances_list: List["ListData"]) -> "ListData":
"""Concat the instances of all :obj:`ListData` in the list.
def flatten(self, item: str) -> List:
"""
Flatten the list of the attribute specified by ``item``.

Note: To ensure that cat returns as expected, make sure that
all elements in the list must have exactly the same keys.
Parameters
----------
item
Name of the attribute to be flattened.

Args:
instances_list (list[:obj:`ListData`]): A list
of :obj:`ListData`.
Returns
-------
list
The flattened list of the attribute specified by ``item``.
"""
return flatten_list(self[item])

Returns:
:obj:`ListData`
def elements_num(self, item: str) -> int:
"""
assert all(isinstance(results, ListData) for results in instances_list)
assert len(instances_list) > 0
if len(instances_list) == 1:
return instances_list[0]

# metainfo and data_fields must be exactly the
# same for each element to avoid exceptions.
field_keys_list = [instances.all_keys() for instances in instances_list]
assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len(
set(itertools.chain(*field_keys_list))
) == len(field_keys_list[0]), (
"There are different keys in "
"`instances_list`, which may "
"cause the cat operation "
"to fail. Please make sure all "
"elements in `instances_list` "
"have the exact same key."
)

new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo)
for k in instances_list[0].keys():
values = [results[k] for results in instances_list]
v0 = values[0]
if isinstance(v0, torch.Tensor):
new_values = torch.cat(values, dim=0)
elif isinstance(v0, np.ndarray):
new_values = np.concatenate(values, axis=0)
elif isinstance(v0, (str, list, tuple)):
new_values = v0[:]
for v in values[1:]:
new_values += v
elif hasattr(v0, "cat"):
new_values = v0.cat(values)
else:
raise ValueError(
f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`"
)
new_data[k] = new_values
return new_data # type:ignore
Return the number of elements in the attribute specified by ``item``.

def flatten(self, item: IndexType) -> List:
"""Flatten self[item].
Parameters
----------
item : str
Name of the attribute for which the number of elements is to be determined.

Returns:
list: Flattened data fields.
Returns
-------
int
The number of elements in the attribute specified by ``item``.
"""
return flatten_list(self[item])

def elements_num(self, item: IndexType) -> int:
"""int: The number of elements in self[item]."""
return len(self.flatten(item))

def to_tuple(self, item: IndexType) -> tuple:
"""tuple: The data fields in self[item] converted to tuple."""
def to_tuple(self, item: str) -> tuple:
"""
Convert the attribute specified by ``item`` to a tuple.

Parameters
----------
item : str
Name of the attribute to be converted.

Returns
-------
tuple
The attribute after conversion to a tuple.
"""
return to_hashable(self[item])

def __len__(self) -> int:


+ 2
- 2
docs/API/abl.structures.rst View File

@@ -1,7 +1,7 @@
abl.structures
==================

.. automodule:: abl.structures
.. autoclass:: abl.structures.ListData
:members:
:undoc-members:
:show-inheritance:
:show-inheritance:

+ 5
- 3
docs/Intro/Datasets.rst View File

@@ -21,7 +21,7 @@ In this section, we will look at the datasets and data structures in ABL-Package
Dataset
-------

ABL-Package assumes user data to be structured as a tuple, comprising the following three components:
ABL-Package assumes user data to be either structured as a tuple or a ``ListData`` which is the underlying data structure utilized in the whole package and will be introduced in the next section. Regardless of the chosen format, the data should encompass the following three essential components:

- ``X``: List[List[Any]]
@@ -53,9 +53,11 @@ As an illustration, in the MNIST Addition example, the data used for training ar
Data Structure
--------------

In Abductive Learning, there are various types of data in the training and testing process, such as raw data, pseudo-label, index of the pseudo-label, abduced pseudo-label, etc. To enhance the stability and versatility, ABL-Package uses `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate various data during the implementation of the model.
Besides the user-provided dataset, various forms of data are utilized and dynamicly generate throughout the training and testing process of Abductive Learning framework. Examples include raw data, predicted pseudo-label, abduced pseudo-label, pseudo-label indices, and so on. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate different forms of data that will be used in the total learning process.

One of the most commonly used abstract data interface is ``ListData``. Besides orginizing data into tuple, we can also prepare data to be in the form of this data interface.
``BaseDataElement`` is the base class for all abstract data interfaces. Inherited from ``BaseDataElement``, ``ListData`` is the most commonly used abstract data interface in ABL-Package. As the fundamental data structure, ``ListData`` implements commonly used data manipulation methods and is responsible for transferring data between various components of ABL, ensuring that stages such as prediction, training, and abductive reasoning can utilize ``ListData`` as a unified input format.

Before proceeding to other stages, user-provided datasets are firstly converted into ``ListData``. For flexibility, ABL-Package also allows user to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.structures.html>`_.

.. code-block:: python



Loading…
Cancel
Save