|
|
@@ -96,10 +96,11 @@ class FieldArray(object): |
|
|
|
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray. |
|
|
|
:param bool is_target: If True, this FieldArray is used to compute loss. |
|
|
|
:param bool is_input: If True, this FieldArray is used to the model input. |
|
|
|
:param padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding) |
|
|
|
:param PadderBase padder: PadderBase类型。大多数情况下都不需要设置该值,除非需要在多个维度上进行padding(比如英文中对character进行padding) |
|
|
|
:param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray. (default: False) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)): |
|
|
|
def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0), ignore_type=False): |
|
|
|
"""DataSet在初始化时会有两类方法对FieldArray操作: |
|
|
|
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: |
|
|
|
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) |
|
|
@@ -114,6 +115,7 @@ class FieldArray(object): |
|
|
|
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) |
|
|
|
|
|
|
|
类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 |
|
|
|
ignore_type用来控制是否进行类型检查,如果为True,则不检查。 |
|
|
|
|
|
|
|
""" |
|
|
|
self.name = name |
|
|
@@ -136,6 +138,7 @@ class FieldArray(object): |
|
|
|
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 |
|
|
|
self.content_dim = None # 表示content是多少维的list |
|
|
|
self.set_padder(padder) |
|
|
|
self.ignore_type = ignore_type |
|
|
|
|
|
|
|
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array |
|
|
|
|
|
|
@@ -149,8 +152,9 @@ class FieldArray(object): |
|
|
|
self.is_target = is_target |
|
|
|
|
|
|
|
def _set_dtype(self): |
|
|
|
self.pytype = self._type_detection(self.content) |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
if self.ignore_type is False: |
|
|
|
self.pytype = self._type_detection(self.content) |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
|
|
|
|
@property |
|
|
|
def is_input(self): |
|
|
@@ -278,39 +282,40 @@ class FieldArray(object): |
|
|
|
|
|
|
|
:param val: int, float, str, or a list of one. |
|
|
|
""" |
|
|
|
if isinstance(val, list): |
|
|
|
pass |
|
|
|
elif isinstance(val, tuple): # 确保最外层是list |
|
|
|
val = list(val) |
|
|
|
elif isinstance(val, np.ndarray): |
|
|
|
val = val.tolist() |
|
|
|
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) |
|
|
|
|
|
|
|
if self.is_input is True or self.is_target is True: |
|
|
|
if type(val) == list: |
|
|
|
if len(val) == 0: |
|
|
|
raise ValueError("Cannot append an empty list.") |
|
|
|
if self.content_dim == 2 and self._1d_list_check(val): |
|
|
|
# 1维list检查 |
|
|
|
pass |
|
|
|
elif self.content_dim == 3 and self._2d_list_check(val): |
|
|
|
# 2维list检查 |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) |
|
|
|
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: |
|
|
|
# scalar检查 |
|
|
|
if type(val) == float and self.pytype == int: |
|
|
|
self.pytype = float |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
if self.ignore_type is False: |
|
|
|
if isinstance(val, list): |
|
|
|
pass |
|
|
|
elif isinstance(val, tuple): # 确保最外层是list |
|
|
|
val = list(val) |
|
|
|
elif isinstance(val, np.ndarray): |
|
|
|
val = val.tolist() |
|
|
|
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) |
|
|
|
|
|
|
|
if self.is_input is True or self.is_target is True: |
|
|
|
if type(val) == list: |
|
|
|
if len(val) == 0: |
|
|
|
raise ValueError("Cannot append an empty list.") |
|
|
|
if self.content_dim == 2 and self._1d_list_check(val): |
|
|
|
# 1维list检查 |
|
|
|
pass |
|
|
|
elif self.content_dim == 3 and self._2d_list_check(val): |
|
|
|
# 2维list检查 |
|
|
|
pass |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) |
|
|
|
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: |
|
|
|
# scalar检查 |
|
|
|
if type(val) == float and self.pytype == int: |
|
|
|
self.pytype = float |
|
|
|
self.dtype = self._map_to_np_type(self.pytype) |
|
|
|
else: |
|
|
|
raise RuntimeError( |
|
|
|
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) |
|
|
|
self.content.append(val) |
|
|
|
|
|
|
|
def __getitem__(self, indices): |
|
|
|