|
|
@@ -132,22 +132,21 @@ class ListData(BaseDataElement): |
|
|
|
super().__setattr__(name, value) |
|
|
|
else: |
|
|
|
raise AttributeError( |
|
|
|
f"{name} has been used as a " |
|
|
|
"private attribute, which is immutable." |
|
|
|
f"{name} has been used as a " "private attribute, which is immutable." |
|
|
|
) |
|
|
|
|
|
|
|
else: |
|
|
|
assert isinstance(value, list), "value must be of type `list`" |
|
|
|
# assert isinstance(value, list), "value must be of type `list`" |
|
|
|
|
|
|
|
if len(self) > 0: |
|
|
|
assert len(value) == len(self), ( |
|
|
|
"The length of " |
|
|
|
f"values {len(value)} is " |
|
|
|
"not consistent with " |
|
|
|
"the length of this " |
|
|
|
":obj:`ListData` " |
|
|
|
f"{len(self)}" |
|
|
|
) |
|
|
|
# if len(self) > 0: |
|
|
|
# assert len(value) == len(self), ( |
|
|
|
# "The length of " |
|
|
|
# f"values {len(value)} is " |
|
|
|
# "not consistent with " |
|
|
|
# "the length of this " |
|
|
|
# ":obj:`ListData` " |
|
|
|
# f"{len(self)}" |
|
|
|
# ) |
|
|
|
super().__setattr__(name, value) |
|
|
|
|
|
|
|
__setitem__ = __setattr__ |
|
|
@@ -176,32 +175,15 @@ class ListData(BaseDataElement): |
|
|
|
if isinstance(item, str): |
|
|
|
return getattr(self, item) |
|
|
|
|
|
|
|
if isinstance(item, int): |
|
|
|
if item >= len(self) or item < -len(self): # type:ignore |
|
|
|
raise IndexError(f"Index {item} out of range!") |
|
|
|
else: |
|
|
|
# keep the dimension |
|
|
|
item = slice(item, None, len(self)) |
|
|
|
|
|
|
|
new_data = self.__class__(metainfo=self.metainfo) |
|
|
|
|
|
|
|
if isinstance(item, torch.Tensor): |
|
|
|
assert item.dim() == 1, ( |
|
|
|
"Only support to get the" " values along the first dimension." |
|
|
|
) |
|
|
|
if isinstance(item, BoolTypeTensor.__args__): |
|
|
|
assert len(item) == len(self), ( |
|
|
|
"The shape of the " |
|
|
|
"input(BoolTensor) " |
|
|
|
f"{len(item)} " |
|
|
|
"does not match the shape " |
|
|
|
"of the indexed tensor " |
|
|
|
"in results_field " |
|
|
|
f"{len(self)} at " |
|
|
|
"first dimension." |
|
|
|
) |
|
|
|
assert item.dim() == 1, "Only support to get the" " values along the first dimension." |
|
|
|
|
|
|
|
for k, v in self.items(): |
|
|
|
if isinstance(v, torch.Tensor): |
|
|
|
if v is None: |
|
|
|
new_data[k] = None |
|
|
|
elif isinstance(v, torch.Tensor): |
|
|
|
new_data[k] = v[item] |
|
|
|
elif isinstance(v, np.ndarray): |
|
|
|
new_data[k] = v[item.cpu().numpy()] |
|
|
@@ -235,9 +217,12 @@ class ListData(BaseDataElement): |
|
|
|
) |
|
|
|
|
|
|
|
else: |
|
|
|
# item is a slice |
|
|
|
# item is a slice or int |
|
|
|
for k, v in self.items(): |
|
|
|
new_data[k] = v[item] |
|
|
|
if v is None: |
|
|
|
new_data[k] = None |
|
|
|
else: |
|
|
|
new_data[k] = v[item] |
|
|
|
return new_data # type:ignore |
|
|
|
|
|
|
|
@staticmethod |
|
|
@@ -289,8 +274,7 @@ class ListData(BaseDataElement): |
|
|
|
new_values = v0.cat(values) |
|
|
|
else: |
|
|
|
raise ValueError( |
|
|
|
f"The type of `{k}` is `{type(v0)}` which has no " |
|
|
|
"attribute of `cat`" |
|
|
|
f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`" |
|
|
|
) |
|
|
|
new_data[k] = new_values |
|
|
|
return new_data # type:ignore |
|
|
@@ -302,15 +286,15 @@ class ListData(BaseDataElement): |
|
|
|
list: Flattened data fields. |
|
|
|
""" |
|
|
|
return flatten_list(self[item]) |
|
|
|
|
|
|
|
|
|
|
|
def elements_num(self, item: IndexType) -> int: |
|
|
|
"""int: The number of elements in self[item].""" |
|
|
|
return len(self.flatten(item)) |
|
|
|
|
|
|
|
|
|
|
|
def to_tuple(self, item: IndexType) -> tuple: |
|
|
|
"""tuple: The data fields in self[item] converted to tuple.""" |
|
|
|
return to_hashable(self[item]) |
|
|
|
|
|
|
|
|
|
|
|
def __len__(self) -> int: |
|
|
|
"""int: The length of ListData.""" |
|
|
|
if len(self._data_fields) > 0: |
|
|
|