|
|
@@ -6,6 +6,8 @@ from typing import Any, List, Union |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
|
|
|
|
from ..utils import flatten as flatten_list |
|
|
|
from ..utils import to_hashable |
|
|
|
from .base_data_element import BaseDataElement |
|
|
|
|
|
|
|
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor] |
|
|
@@ -293,6 +295,22 @@ class ListData(BaseDataElement): |
|
|
|
new_data[k] = new_values |
|
|
|
return new_data # type:ignore |
|
|
|
|
|
|
|
def flatten(self, item: IndexType) -> List: |
|
|
|
"""Flatten self[item]. |
|
|
|
|
|
|
|
Returns: |
|
|
|
list: Flattened data fields. |
|
|
|
""" |
|
|
|
return flatten_list(self[item]) |
|
|
|
|
|
|
|
def elements_num(self, item: IndexType) -> int: |
|
|
|
"""int: The number of elements in self[item].""" |
|
|
|
return len(self.flatten(item)) |
|
|
|
|
|
|
|
def to_tuple(self, item: IndexType) -> tuple: |
|
|
|
"""tuple: The data fields in self[item] converted to tuple.""" |
|
|
|
return to_hashable(self[item]) |
|
|
|
|
|
|
|
def __len__(self) -> int: |
|
|
|
"""int: The length of ListData.""" |
|
|
|
if len(self._data_fields) > 0: |
|
|
|