|
|
@@ -18,106 +18,47 @@ IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndar |
|
|
|
# Modified from |
|
|
|
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa |
|
|
|
class ListData(BaseDataElement): |
|
|
|
"""Data structure for instance-level annotations or predictions. |
|
|
|
""" |
|
|
|
Data structure for example-level data. |
|
|
|
|
|
|
|
Subclass of :class:`BaseDataElement`. All value in `data_fields` |
|
|
|
should have the same length. This design refer to |
|
|
|
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 |
|
|
|
ListData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value |
|
|
|
in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, |
|
|
|
and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. |
|
|
|
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py |
|
|
|
|
|
|
|
ListData supports `index` and `slice` for data field. The type of value in data field can be either `None` or `list` of base data structures such as `torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> # custom data structure |
|
|
|
>>> class TmpObject: |
|
|
|
... def __init__(self, tmp) -> None: |
|
|
|
... assert isinstance(tmp, list) |
|
|
|
... self.tmp = tmp |
|
|
|
... def __len__(self): |
|
|
|
... return len(self.tmp) |
|
|
|
... def __getitem__(self, item): |
|
|
|
... if isinstance(item, int): |
|
|
|
... if item >= len(self) or item < -len(self): # type:ignore |
|
|
|
... raise IndexError(f'Index {item} out of range!') |
|
|
|
... else: |
|
|
|
... # keep the dimension |
|
|
|
... item = slice(item, None, len(self)) |
|
|
|
... return TmpObject(self.tmp[item]) |
|
|
|
... @staticmethod |
|
|
|
... def cat(tmp_objs): |
|
|
|
... assert all(isinstance(results, TmpObject) for results in tmp_objs) |
|
|
|
... if len(tmp_objs) == 1: |
|
|
|
... return tmp_objs[0] |
|
|
|
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] |
|
|
|
... tmp_list = list(itertools.chain(*tmp_list)) |
|
|
|
... new_data = TmpObject(tmp_list) |
|
|
|
... return new_data |
|
|
|
... def __repr__(self): |
|
|
|
... return str(self.tmp) |
|
|
|
>>> from mmengine.structures import ListData |
|
|
|
>>> from abl.structures import ListData |
|
|
|
>>> import numpy as np |
|
|
|
>>> import torch |
|
|
|
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) |
|
|
|
>>> instance_data = ListData(metainfo=img_meta) |
|
|
|
>>> 'img_shape' in instance_data |
|
|
|
True |
|
|
|
>>> instance_data.det_labels = torch.LongTensor([2, 3]) |
|
|
|
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) |
|
|
|
>>> instance_data.bboxes = torch.rand((2, 4)) |
|
|
|
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) |
|
|
|
>>> len(instance_data) |
|
|
|
2 |
|
|
|
>>> print(instance_data) |
|
|
|
<ListData( |
|
|
|
META INFORMATION |
|
|
|
img_shape: (800, 1196, 3) |
|
|
|
pad_shape: (800, 1216, 3) |
|
|
|
DATA FIELDS |
|
|
|
det_labels: tensor([2, 3]) |
|
|
|
det_scores: tensor([0.8000, 0.7000]) |
|
|
|
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], |
|
|
|
[0.8101, 0.3105, 0.5123, 0.6263]]) |
|
|
|
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]] |
|
|
|
) at 0x7fb492de6280> |
|
|
|
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices] |
|
|
|
>>> sorted_results.det_scores |
|
|
|
tensor([0.7000, 0.8000]) |
|
|
|
>>> print(instance_data[instance_data.det_scores > 0.75]) |
|
|
|
<ListData( |
|
|
|
META INFORMATION |
|
|
|
img_shape: (800, 1196, 3) |
|
|
|
pad_shape: (800, 1216, 3) |
|
|
|
DATA FIELDS |
|
|
|
det_labels: tensor([2]) |
|
|
|
det_scores: tensor([0.8000]) |
|
|
|
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]]) |
|
|
|
polygons: [[1, 2, 3, 4]] |
|
|
|
) at 0x7f64ecf0ec40> |
|
|
|
>>> print(instance_data[instance_data.det_scores > 1]) |
|
|
|
>>> data_examples = ListData() |
|
|
|
>>> data_examples.X = [list(torch.randn(2)) for _ in range(3)] |
|
|
|
>>> data_examples.Y = [1, 2, 3] |
|
|
|
>>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]] |
|
|
|
>>> len(data_examples) |
|
|
|
3 |
|
|
|
>>> print(data_examples) |
|
|
|
<ListData( |
|
|
|
META INFORMATION |
|
|
|
img_shape: (800, 1196, 3) |
|
|
|
pad_shape: (800, 1216, 3) |
|
|
|
DATA FIELDS |
|
|
|
det_labels: tensor([], dtype=torch.int64) |
|
|
|
det_scores: tensor([]) |
|
|
|
bboxes: tensor([], size=(0, 4)) |
|
|
|
polygons: [] |
|
|
|
) at 0x7f660a6a7f70> |
|
|
|
>>> print(instance_data.cat([instance_data, instance_data])) |
|
|
|
Y: [1, 2, 3] |
|
|
|
gt_pseudo_label: [[1, 2], [3, 4], [5, 6]] |
|
|
|
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] |
|
|
|
) at 0x7f3bbf1991c0> |
|
|
|
>>> print(data_examples[:1]) |
|
|
|
<ListData( |
|
|
|
META INFORMATION |
|
|
|
img_shape: (800, 1196, 3) |
|
|
|
pad_shape: (800, 1216, 3) |
|
|
|
DATA FIELDS |
|
|
|
det_labels: tensor([2, 3, 2, 3]) |
|
|
|
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000]) |
|
|
|
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], |
|
|
|
[0.8101, 0.3105, 0.5123, 0.6263], |
|
|
|
[0.4997, 0.7707, 0.0595, 0.4188], |
|
|
|
[0.8101, 0.3105, 0.5123, 0.6263]]) |
|
|
|
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]] |
|
|
|
) at 0x7f203542feb0> |
|
|
|
Y: [1] |
|
|
|
gt_pseudo_label: [[1, 2]] |
|
|
|
X: [[tensor(1.1949), tensor(-0.9378)]] |
|
|
|
) at 0x7f3bbf1a3580> |
|
|
|
>>> print(data_examples.elements_num("X")) |
|
|
|
6 |
|
|
|
>>> print(data_examples.flatten("gt_pseudo_label")) |
|
|
|
[1, 2, 3, 4, 5, 6] |
|
|
|
>>> print(data_examples.to_tuple("Y")) |
|
|
|
(1, 2, 3) |
|
|
|
""" |
|
|
|
|
|
|
|
def __setattr__(self, name: str, value: list): |
|
|
@@ -224,74 +165,52 @@ class ListData(BaseDataElement): |
|
|
|
new_data[k] = v[item] |
|
|
|
return new_data # type:ignore |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def cat(instances_list: List["ListData"]) -> "ListData": |
|
|
|
"""Concat the instances of all :obj:`ListData` in the list. |
|
|
|
def flatten(self, item: str) -> List: |
|
|
|
""" |
|
|
|
Flatten the list of the attribute specified by ``item``. |
|
|
|
|
|
|
|
Note: To ensure that cat returns as expected, make sure that |
|
|
|
all elements in the list must have exactly the same keys. |
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
item |
|
|
|
Name of the attribute to be flattened. |
|
|
|
|
|
|
|
Args: |
|
|
|
instances_list (list[:obj:`ListData`]): A list |
|
|
|
of :obj:`ListData`. |
|
|
|
Returns |
|
|
|
------- |
|
|
|
list |
|
|
|
The flattened list of the attribute specified by ``item``. |
|
|
|
""" |
|
|
|
return flatten_list(self[item]) |
|
|
|
|
|
|
|
Returns: |
|
|
|
:obj:`ListData` |
|
|
|
def elements_num(self, item: str) -> int: |
|
|
|
""" |
|
|
|
assert all(isinstance(results, ListData) for results in instances_list) |
|
|
|
assert len(instances_list) > 0 |
|
|
|
if len(instances_list) == 1: |
|
|
|
return instances_list[0] |
|
|
|
|
|
|
|
# metainfo and data_fields must be exactly the |
|
|
|
# same for each element to avoid exceptions. |
|
|
|
field_keys_list = [instances.all_keys() for instances in instances_list] |
|
|
|
assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len( |
|
|
|
set(itertools.chain(*field_keys_list)) |
|
|
|
) == len(field_keys_list[0]), ( |
|
|
|
"There are different keys in " |
|
|
|
"`instances_list`, which may " |
|
|
|
"cause the cat operation " |
|
|
|
"to fail. Please make sure all " |
|
|
|
"elements in `instances_list` " |
|
|
|
"have the exact same key." |
|
|
|
) |
|
|
|
|
|
|
|
new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo) |
|
|
|
for k in instances_list[0].keys(): |
|
|
|
values = [results[k] for results in instances_list] |
|
|
|
v0 = values[0] |
|
|
|
if isinstance(v0, torch.Tensor): |
|
|
|
new_values = torch.cat(values, dim=0) |
|
|
|
elif isinstance(v0, np.ndarray): |
|
|
|
new_values = np.concatenate(values, axis=0) |
|
|
|
elif isinstance(v0, (str, list, tuple)): |
|
|
|
new_values = v0[:] |
|
|
|
for v in values[1:]: |
|
|
|
new_values += v |
|
|
|
elif hasattr(v0, "cat"): |
|
|
|
new_values = v0.cat(values) |
|
|
|
else: |
|
|
|
raise ValueError( |
|
|
|
f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`" |
|
|
|
) |
|
|
|
new_data[k] = new_values |
|
|
|
return new_data # type:ignore |
|
|
|
Return the number of elements in the attribute specified by ``item``. |
|
|
|
|
|
|
|
def flatten(self, item: IndexType) -> List: |
|
|
|
"""Flatten self[item]. |
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
item : str |
|
|
|
Name of the attribute for which the number of elements is to be determined. |
|
|
|
|
|
|
|
Returns: |
|
|
|
list: Flattened data fields. |
|
|
|
Returns |
|
|
|
------- |
|
|
|
int |
|
|
|
The number of elements in the attribute specified by ``item``. |
|
|
|
""" |
|
|
|
return flatten_list(self[item]) |
|
|
|
|
|
|
|
def elements_num(self, item: IndexType) -> int: |
|
|
|
"""int: The number of elements in self[item].""" |
|
|
|
return len(self.flatten(item)) |
|
|
|
|
|
|
|
def to_tuple(self, item: IndexType) -> tuple: |
|
|
|
"""tuple: The data fields in self[item] converted to tuple.""" |
|
|
|
def to_tuple(self, item: str) -> tuple: |
|
|
|
""" |
|
|
|
Convert the attribute specified by ``item`` to a tuple. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
item : str |
|
|
|
Name of the attribute to be converted. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
tuple |
|
|
|
The attribute after conversion to a tuple. |
|
|
|
""" |
|
|
|
return to_hashable(self[item]) |
|
|
|
|
|
|
|
def __len__(self) -> int: |
|
|
|