Browse Source

[ENH] create ListData

ab_data
Gao Enhao 1 year ago
parent
commit
ddf7b7a3e1
1 changed files with 18 additions and 0 deletions
  1. +18
    -0
      abl/structures/list_data.py

+ 18
- 0
abl/structures/list_data.py View File

@@ -6,6 +6,8 @@ from typing import Any, List, Union
import numpy as np
import torch

from ..utils import flatten as flatten_list
from ..utils import to_hashable
from .base_data_element import BaseDataElement

BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
@@ -293,6 +295,22 @@ class ListData(BaseDataElement):
new_data[k] = new_values
return new_data # type:ignore

def flatten(self, item: IndexType) -> List:
"""Flatten self[item].

Returns:
list: Flattened data fields.
"""
return flatten_list(self[item])
def elements_num(self, item: IndexType) -> int:
"""int: The number of elements in self[item]."""
return len(self.flatten(item))
def to_tuple(self, item: IndexType) -> tuple:
"""tuple: The data fields in self[item] converted to tuple."""
return to_hashable(self[item])
def __len__(self) -> int:
"""int: The length of ListData."""
if len(self._data_fields) > 0:


Loading…
Cancel
Save