|
|
@@ -11,7 +11,7 @@ class FieldArray(object): |
|
|
|
""" |
|
|
|
|
|
|
|
:param str name: the name of the FieldArray |
|
|
|
:param list content: a list of int, float, or a list of list. |
|
|
|
:param list content: a list of int, float, str or np.ndarray, or a list of list of one. |
|
|
|
:param int padding_val: the integer for padding. Default: 0. |
|
|
|
:param bool is_target: If True, this FieldArray is used to compute loss. |
|
|
|
:param bool is_input: If True, this FieldArray is used to the model input. |
|
|
@@ -27,35 +27,46 @@ class FieldArray(object): |
|
|
|
self.padding_val = padding_val |
|
|
|
self.is_target = is_target |
|
|
|
self.is_input = is_input |
|
|
|
|
|
|
|
self.BASIC_TYPES = (int, float, str, np.ndarray) |
|
|
|
self.is_2d_list = False |
|
|
|
self.pytype = self._type_detection(content) |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _type_detection(content): |
|
|
|
def _type_detection(self, content): |
|
|
|
""" |
|
|
|
|
|
|
|
:param content: a list of int, float, str or np.ndarray, or a list of list of one. |
|
|
|
:return type: one of int, float, str, np.ndarray |
|
|
|
|
|
|
|
""" |
|
|
|
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): |
|
|
|
# 2-D list |
|
|
|
# TODO: refactor |
|
|
|
type_set = set([type(item) for item in content[0]]) |
|
|
|
else: |
|
|
|
# 1-D list |
|
|
|
# content is a 2-D list |
|
|
|
type_set = set([self._type_detection(x) for x in content]) |
|
|
|
if len(type_set) > 1: |
|
|
|
raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) |
|
|
|
self.is_2d_list = True |
|
|
|
return type_set.pop() |
|
|
|
|
|
|
|
elif isinstance(content, list): |
|
|
|
# content is a 1-D list |
|
|
|
if len(content) == 0: |
|
|
|
raise RuntimeError("Cannot create FieldArray with an empty list.") |
|
|
|
type_set = set([type(item) for item in content]) |
|
|
|
|
|
|
|
if len(type_set) == 1 and any(basic_type in type_set for basic_type in (str, int, float)): |
|
|
|
return type_set.pop() |
|
|
|
elif len(type_set) == 2 and float in type_set and int in type_set: |
|
|
|
# up-cast int to float |
|
|
|
for idx, _ in enumerate(content): |
|
|
|
content[idx] = float(content[idx]) |
|
|
|
return float |
|
|
|
if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: |
|
|
|
return type_set.pop() |
|
|
|
elif len(type_set) == 2 and float in type_set and int in type_set: |
|
|
|
# up-cast int to float |
|
|
|
return float |
|
|
|
else: |
|
|
|
raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set)) |
|
|
|
else: |
|
|
|
raise ValueError("Unsupported type conversion detected in FieldArray: {}".format(*type_set)) |
|
|
|
raise RuntimeError("Cannot create FieldArray with type {}".format(type(content))) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _map_to_np_type(basic_type): |
|
|
|
type_mapping = {int: np.int64, float: np.float64, str: np.str} |
|
|
|
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} |
|
|
|
return type_mapping[basic_type] |
|
|
|
|
|
|
|
def __repr__(self): |
|
|
@@ -64,29 +75,35 @@ class FieldArray(object): |
|
|
|
def append(self, val): |
|
|
|
"""Add a new item to the tail of FieldArray. |
|
|
|
|
|
|
|
:param val: int, float, str, or a list of them. |
|
|
|
:param val: int, float, str, or a list of one. |
|
|
|
""" |
|
|
|
val_type = type(val) |
|
|
|
if val_type is int and self.pytype is float: |
|
|
|
# up-cast the appended value |
|
|
|
val = float(val) |
|
|
|
elif val_type is float and self.pytype is int: |
|
|
|
# up-cast all other values in the content |
|
|
|
for idx, _ in enumerate(self.content): |
|
|
|
self.content[idx] = float(self.content[idx]) |
|
|
|
self.pytype = float |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
elif val_type is list: |
|
|
|
if val_type == list: # shape check |
|
|
|
if self.is_2d_list is False: |
|
|
|
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") |
|
|
|
if len(val) == 0: |
|
|
|
raise ValueError("Cannot append an empty list.") |
|
|
|
raise RuntimeError("Cannot append an empty list.") |
|
|
|
val_list_type = [type(_) for _ in val] # type check |
|
|
|
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: |
|
|
|
# up-cast int to float |
|
|
|
val_type = float |
|
|
|
elif len(val_list_type) == 1: |
|
|
|
val_type = val_list_type[0] |
|
|
|
else: |
|
|
|
if type(val[0]) != self.pytype: |
|
|
|
raise ValueError( |
|
|
|
"Cannot append a list of {}-type value into a {}-tpye FieldArray.". |
|
|
|
format(type(val[0]), self.pytype)) |
|
|
|
elif val_type != self.pytype: |
|
|
|
raise ValueError("Cannot append a {}-type value into a {}-tpye FieldArray.".format(val_type, self.pytype)) |
|
|
|
|
|
|
|
raise RuntimeError("Cannot append a list of {}".format(val_list_type)) |
|
|
|
else: |
|
|
|
if self.is_2d_list is True: |
|
|
|
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") |
|
|
|
if val_type == float and self.pytype == int: |
|
|
|
# up-cast |
|
|
|
self.pytype = float |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
elif val_type == int and self.pytype == float: |
|
|
|
pass |
|
|
|
elif val_type == self.pytype: |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype)) |
|
|
|
self.content.append(val) |
|
|
|
|
|
|
|
def __getitem__(self, indices): |
|
|
@@ -102,7 +119,6 @@ class FieldArray(object): |
|
|
|
:param indices: an int, or a list of int. |
|
|
|
:return: |
|
|
|
""" |
|
|
|
# TODO: 返回行为不一致,有隐患 |
|
|
|
if isinstance(indices, int): |
|
|
|
return self.content[indices] |
|
|
|
assert self.is_input is True or self.is_target is True |
|
|
|