|
|
@@ -7,11 +7,11 @@ class FieldArray(object): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): |
|
|
|
def __init__(self, name, content, padding_val=0, is_target=None, is_input=None): |
|
|
|
""" |
|
|
|
|
|
|
|
:param str name: the name of the FieldArray |
|
|
|
:param list content: a list of int, float, str or np.ndarray, or a list of list of one. |
|
|
|
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. |
|
|
|
:param int padding_val: the integer for padding. Default: 0. |
|
|
|
:param bool is_target: If True, this FieldArray is used to compute loss. |
|
|
|
:param bool is_input: If True, this FieldArray is used to the model input. |
|
|
@@ -20,18 +20,44 @@ class FieldArray(object): |
|
|
|
if isinstance(content, list): |
|
|
|
content = content |
|
|
|
elif isinstance(content, np.ndarray): |
|
|
|
content = content.tolist() |
|
|
|
content = content.tolist() # convert np.ndarray into 2-D list |
|
|
|
else: |
|
|
|
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) |
|
|
|
self.content = content |
|
|
|
self.padding_val = padding_val |
|
|
|
self.is_target = is_target |
|
|
|
self.is_input = is_input |
|
|
|
|
|
|
|
self._is_target = None |
|
|
|
self._is_input = None |
|
|
|
|
|
|
|
self.BASIC_TYPES = (int, float, str, np.ndarray) |
|
|
|
self.is_2d_list = False |
|
|
|
self.pytype = self._type_detection(content) |
|
|
|
self.pytype = None # int, float, str, or np.ndarray |
|
|
|
self.dtype = None # np.int64, np.float64, np.str |
|
|
|
|
|
|
|
if is_input is not None: |
|
|
|
self.is_input = is_input |
|
|
|
if is_target is not None: |
|
|
|
self.is_target = is_target |
|
|
|
|
|
|
|
@property |
|
|
|
def is_input(self): |
|
|
|
return self._is_input |
|
|
|
|
|
|
|
@is_input.setter |
|
|
|
def is_input(self, value): |
|
|
|
self.pytype = self._type_detection(self.content) |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
self._is_input = value |
|
|
|
|
|
|
|
@property |
|
|
|
def is_target(self): |
|
|
|
return self._is_target |
|
|
|
|
|
|
|
@is_target.setter |
|
|
|
def is_target(self, value): |
|
|
|
self.pytype = self._type_detection(self.content) |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
self._is_target = value |
|
|
|
|
|
|
|
def _type_detection(self, content): |
|
|
|
""" |
|
|
@@ -42,9 +68,13 @@ class FieldArray(object): |
|
|
|
""" |
|
|
|
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): |
|
|
|
# content is a 2-D list |
|
|
|
if not all(isinstance(_, list) for _ in content): # strict check 2-D list |
|
|
|
raise TypeError("Please provide 2-D list.") |
|
|
|
type_set = set([self._type_detection(x) for x in content]) |
|
|
|
if len(type_set) > 1: |
|
|
|
raise RuntimeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) |
|
|
|
if len(type_set) == 2 and int in type_set and float in type_set: |
|
|
|
type_set = {float} |
|
|
|
elif len(type_set) > 1: |
|
|
|
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) |
|
|
|
self.is_2d_list = True |
|
|
|
return type_set.pop() |
|
|
|
|
|
|
@@ -60,9 +90,9 @@ class FieldArray(object): |
|
|
|
# up-cast int to float |
|
|
|
return float |
|
|
|
else: |
|
|
|
raise RuntimeError("Cannot create FieldArray with type {}".format(*type_set)) |
|
|
|
raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) |
|
|
|
else: |
|
|
|
raise RuntimeError("Cannot create FieldArray with type {}".format(type(content))) |
|
|
|
raise TypeError("Cannot create FieldArray with type {}".format(type(content))) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _map_to_np_type(basic_type): |
|
|
@@ -77,33 +107,38 @@ class FieldArray(object): |
|
|
|
|
|
|
|
:param val: int, float, str, or a list of one. |
|
|
|
""" |
|
|
|
val_type = type(val) |
|
|
|
if val_type == list: # shape check |
|
|
|
if self.is_2d_list is False: |
|
|
|
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") |
|
|
|
if len(val) == 0: |
|
|
|
raise RuntimeError("Cannot append an empty list.") |
|
|
|
val_list_type = set([type(_) for _ in val]) # type check |
|
|
|
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: |
|
|
|
# up-cast int to float |
|
|
|
val_type = float |
|
|
|
elif len(val_list_type) == 1: |
|
|
|
val_type = val_list_type.pop() |
|
|
|
if self.is_target is True or self.is_input is True: |
|
|
|
# only check type when used as target or input |
|
|
|
|
|
|
|
val_type = type(val) |
|
|
|
if val_type == list: # shape check |
|
|
|
if self.is_2d_list is False: |
|
|
|
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") |
|
|
|
if len(val) == 0: |
|
|
|
raise RuntimeError("Cannot append an empty list.") |
|
|
|
val_list_type = set([type(_) for _ in val]) # type check |
|
|
|
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: |
|
|
|
# up-cast int to float |
|
|
|
val_type = float |
|
|
|
elif len(val_list_type) == 1: |
|
|
|
val_type = val_list_type.pop() |
|
|
|
else: |
|
|
|
raise TypeError("Cannot append a list of {}".format(val_list_type)) |
|
|
|
else: |
|
|
|
raise RuntimeError("Cannot append a list of {}".format(val_list_type)) |
|
|
|
else: |
|
|
|
if self.is_2d_list is True: |
|
|
|
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") |
|
|
|
if val_type == float and self.pytype == int: |
|
|
|
# up-cast |
|
|
|
self.pytype = float |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
elif val_type == int and self.pytype == float: |
|
|
|
pass |
|
|
|
elif val_type == self.pytype: |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise RuntimeError("Cannot append type {} into type {}".format(val_type, self.pytype)) |
|
|
|
if self.is_2d_list is True: |
|
|
|
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") |
|
|
|
|
|
|
|
if val_type == float and self.pytype == int: |
|
|
|
# up-cast |
|
|
|
self.pytype = float |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
elif val_type == int and self.pytype == float: |
|
|
|
pass |
|
|
|
elif val_type == self.pytype: |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) |
|
|
|
|
|
|
|
self.content.append(val) |
|
|
|
|
|
|
|
def __getitem__(self, indices): |
|
|
@@ -121,7 +156,8 @@ class FieldArray(object): |
|
|
|
""" |
|
|
|
if isinstance(indices, int): |
|
|
|
return self.content[indices] |
|
|
|
assert self.is_input is True or self.is_target is True |
|
|
|
if self.is_input is False and self.is_target is False: |
|
|
|
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) |
|
|
|
batch_size = len(indices) |
|
|
|
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 |
|
|
|
if not is_iterable(self.content[0]): |
|
|
|