diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 109d4fe9..0ca920d4 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -12,6 +12,7 @@ from queue import Empty, Full import numpy as np import torch import torch.multiprocessing as mp +from numbers import Number from .sampler import RandomSampler @@ -78,8 +79,10 @@ class Batch(object): for field_name, field in self.dataset.get_all_fields().items(): if field.is_target or field.is_input: batch = field.get(indices) - if not self.as_numpy and field.padder is not None: - batch = _to_tensor(batch, field.dtype) + if not self.as_numpy and \ + field.dtype is not None and \ + issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor): + batch = _to_tensor(batch) if field.is_target: batch_y[field_name] = batch if field.is_input: @@ -174,12 +177,12 @@ class Batch(object): # print('iter done') -def _to_tensor(batch, dtype): +def _to_tensor(batch): try: - if dtype in (int, np.int8, np.int16, np.int32, np.int64): - batch = torch.LongTensor(batch) - if dtype in (float, np.float32, np.float64): - batch = torch.FloatTensor(batch) + if issubclass(batch.dtype.type, np.floating): + batch = torch.as_tensor(batch).float() # 默认使用float32 + else: + batch = torch.as_tensor(batch) # 复用内存地址,避免复制 except: pass return batch diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 9f24adf2..ab020ce4 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -285,7 +285,8 @@ from .field import AutoPadder from .field import FieldArray from .instance import Instance from .utils import _get_func_signature - +from .field import AppendToTargetOrInputException +from .field import SetInputOrTargetException class DataSet(object): """ @@ -422,7 +423,7 @@ class DataSet(object): if len(self.field_arrays) == 0: # DataSet has no field yet for name, field in instance.fields.items(): - field = field.tolist() if isinstance(field, np.ndarray) else field + # field = field.tolist() if isinstance(field, np.ndarray) else field self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 else: if len(self.field_arrays) != len(instance.fields): @@ -431,7 +432,11 @@ class DataSet(object): .format(len(self.field_arrays), len(instance.fields))) for name, field in instance.fields.items(): assert name in self.field_arrays - self.field_arrays[name].append(field) + try: + self.field_arrays[name].append(field) + except AppendToTargetOrInputException as e: + print(f"Cannot append to field:{name}.") + raise e def add_fieldarray(self, field_name, fieldarray): """ @@ -565,7 +570,11 @@ class DataSet(object): assert isinstance(flag, bool), "Only bool type supported." for name in field_names: if name in self.field_arrays: - self.field_arrays[name].is_target = flag + try: + self.field_arrays[name].is_target = flag + except SetInputOrTargetException as e: + print(f"Cannot set field:{name} as target.") + raise e else: raise KeyError("{} is not a valid field name.".format(name)) @@ -581,7 +590,11 @@ class DataSet(object): """ for name in field_names: if name in self.field_arrays: - self.field_arrays[name].is_input = flag + try: + self.field_arrays[name].is_input = flag + except SetInputOrTargetException as e: + print(f"Cannot set field:{name} as input.") + raise e else: raise KeyError("{} is not a valid field name.".format(name)) diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 9ef8d963..c47771df 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -1,251 +1,162 @@ -""" -field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fastNLP.DataSet` 中一列的存储方式, -原理部分请参考 :doc:`fastNLP.core.dataset` - -""" -__all__ = [ - "FieldArray", - "Padder", - "AutoPadder", - "EngChar2DPadder" -] -from copy import deepcopy +from numbers import Number +import torch import numpy as np +from typing import Any +from abc import abstractmethod +from copy import deepcopy - -class FieldArray(object): - """ - 别名::class:`fastNLP.FieldArray` :class:`fastNLP.core.field.FieldArray` - - FieldArray 是用于保存 :class:`~fastNLP.DataSet` 中一个field的类型。 - - :param str name: FieldArray的名称 - :param list,numpy.ndarray content: 列表的元素可以为list,int,float, - :param bool is_target: 这个field是否是一个target field。 - :param bool is_input: 这个field是否是一个input field。 - :param padder: :class:`~fastNLP.Padder` 类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 - fieldarray.set_pad_val()。默认为None,即使用 :class:`~fastNLP.AutoPadder` 。 - :param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, - 就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 - """ - - def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): +class SetInputOrTargetException(Exception): + def __init__(self, msg, index=None, field_name=None): + super().__init__(msg) + self.msg = msg + self.index = index # 标示在哪个数据遭遇到问题了 + self.field_name = field_name # 标示当前field的名称 + +class AppendToTargetOrInputException(Exception): + def __init__(self, msg, index=None, field_name=None): + super().__init__(msg) + self.msg = msg + self.index = index # 标示在哪个数据遭遇到问题了 + self.field_name = field_name # 标示当前field的名称 + +class FieldArray: + def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): + if len(content)==0: + raise RuntimeError("Empty fieldarray is not allowed.") + _content = content + try: + _content = list(_content) + except BaseException as e: + print(f"Cannot convert content(of type:{type(content)}) into list.") + raise e self.name = name - if isinstance(content, list): - # 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list - # 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] - for idx, item in enumerate(content): - # 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) - # 将[np.array] 转化为 list of list - # 也可以支持[array, array, array]的情况 - if isinstance(item, np.ndarray): - content[idx] = content[idx].tolist() - elif isinstance(content, np.ndarray): - content = content.tolist() # convert np.ndarray into 2-D list - else: - raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) - if len(content) == 0: - raise RuntimeError("Cannot initialize FieldArray with empty list.") - - self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 - self.content_dim = None # 表示content是多少维的list + self.content = _content + self._ignore_type = ignore_type + # 根据input的情况设置input,target等 + self._cell_ndim = None # 多少维度 + self.dtype = None # 最内层的element都是什么类型的 + self._is_input = False + self._is_target = False + + if is_input: + self.is_input = is_input + if is_target: + self.is_target = is_target + if padder is None: padder = AutoPadder(pad_val=0) else: - assert isinstance(padder, Padder), "padder must be of type Padder." + assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder." padder = deepcopy(padder) self.set_padder(padder) - self.ignore_type = ignore_type - - self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array - - self.pytype = None - self.dtype = None - self._is_input = None - self._is_target = None - - if is_input is not None or is_target is not None: - self.is_input = is_input - self.is_target = is_target - - def _set_dtype(self): - if self.ignore_type is False: - self.pytype = self._type_detection(self.content) - self.dtype = self._map_to_np_type(self.pytype) - + + @property + def ignore_type(self): + return self._ignore_type + + @ignore_type.setter + def ignore_type(self, value): + if value: + self._cell_ndim = None + self.dtype = None + @property def is_input(self): return self._is_input - + @is_input.setter def is_input(self, value): """ 当 field_array.is_input = True / False 时被调用 """ - if value is True: - self._set_dtype() + # 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False) + if value is True and \ + self._is_target is False and \ + self._ignore_type is False: + self._check_dtype_and_ndim() + if value is False and self._is_target is False: + self.dtype = None + self._cell_ndim = None self._is_input = value - + @property def is_target(self): return self._is_target - + @is_target.setter def is_target(self, value): """ 当 field_array.is_target = True / False 时被调用 """ - if value is True: - self._set_dtype() + if value is True and \ + self._is_input is False and \ + self._ignore_type is False: + self._check_dtype_and_ndim() + if value is False and self._is_input is False: + self.dtype = None + self._cell_ndim = None self._is_target = value - - def _type_detection(self, content): - """ - 当该field被设置为is_input或者is_target时被调用 + def _check_dtype_and_ndim(self): """ - if len(content) == 0: - raise RuntimeError("Empty list in Field {}.".format(self.name)) - - type_set = set([type(item) for item in content]) - - if list in type_set: - if len(type_set) > 1: - # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - # >1维list - inner_type_set = set() - for l in content: - [inner_type_set.add(type(obj)) for obj in l] - if list not in inner_type_set: - # 二维list - self.content_dim = 2 - return self._basic_type_detection(inner_type_set) - else: - if len(inner_type_set) == 1: - # >2维list - inner_inner_type_set = set() - for _2d_list in content: - for _1d_list in _2d_list: - [inner_inner_type_set.add(type(obj)) for obj in _1d_list] - if list in inner_inner_type_set: - raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") - # 3维list - self.content_dim = 3 - return self._basic_type_detection(inner_inner_type_set) - else: - # list 跟 非list 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) - else: - # 一维list - for content_type in type_set: - if content_type not in self.BASIC_TYPES: - raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( - self.name, self.BASIC_TYPES, content_type)) - self.content_dim = 1 - return self._basic_type_detection(type_set) - - def _basic_type_detection(self, type_set): + 检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 + 通过将直接报错. + + :return: """ - :param type_set: a set of Python types - :return: one of self.BASIC_TYPES + cell_0 = self.content[0] + index = 0 + try: + type_0, dim_0 = _get_ele_type_and_dim(cell_0) + for cell in self.content[1:]: + index += 1 + type_i, dim_i = _get_ele_type_and_dim(cell) + if type_i!=type_0: + raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." + ".".format(type_i, index, type_0)) + if dim_0!=dim_i: + raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " + "dimension:{}.".format(dim_i, index, dim_0)) + self._cell_ndim = dim_0 + self.dtype = type_0 + except SetInputOrTargetException as e: + e.index = index + raise e + + def append(self, val:Any): + """ + :param val: 把该val append到fieldarray。 + :return: """ - if len(type_set) == 1: - return type_set.pop() - elif len(type_set) == 2: - # 有多个basic type; 可能需要up-cast - if float in type_set and int in type_set: - # up-cast int to float - return float - else: - # str 跟 int 或者 float 混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) + if (self._is_target or self._is_input) and self._ignore_type is False: + type_, dim_ = _get_ele_type_and_dim(val) + if self.dtype!=type_: + raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " + f"previous values(type:{self.dtype}).") + if self._cell_ndim!=dim_: + raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with " + f"previous values(dim:{self._cell_ndim}).") + self.content.append(val) else: - # str, int, float混在一起 - raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - - def _1d_list_check(self, val): - """如果不是1D list就报错 - """ - type_set = set((type(obj) for obj in val)) - if any(obj not in self.BASIC_TYPES for obj in type_set): - raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - self._basic_type_detection(type_set) - # otherwise: _basic_type_detection will raise error - return True - - def _2d_list_check(self, val): - """如果不是2D list 就报错 - """ - type_set = set(type(obj) for obj in val) - if list(type_set) != [list]: - raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) - inner_type_set = set() - for l in val: - for obj in l: - inner_type_set.add(type(obj)) - self._basic_type_detection(inner_type_set) - return True - - @staticmethod - def _map_to_np_type(basic_type): - type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} - return type_mapping[basic_type] - - def __repr__(self): - return "FieldArray {}: {}".format(self.name, self.content.__repr__()) - - def append(self, val): - """将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 - 的内容是匹配的。 - - :param Any val: 需要append的值。 - """ - if self.ignore_type is False: - if isinstance(val, list): - pass - elif isinstance(val, tuple): # 确保最外层是list - val = list(val) - elif isinstance(val, np.ndarray): - val = val.tolist() - elif any((isinstance(val, t) for t in self.BASIC_TYPES)): - pass - else: - raise RuntimeError( - "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - - if self.is_input is True or self.is_target is True: - if type(val) == list: - if len(val) == 0: - raise ValueError("Cannot append an empty list.") - if self.content_dim == 2 and self._1d_list_check(val): - # 1维list检查 - pass - elif self.content_dim == 3 and self._2d_list_check(val): - # 2维list检查 - pass - else: - raise RuntimeError( - "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) - elif type(val) in self.BASIC_TYPES and self.content_dim == 1: - # scalar检查 - if type(val) == float and self.pytype == int: - self.pytype = float - self.dtype = self._map_to_np_type(self.pytype) - else: - raise RuntimeError( - "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - self.content.append(val) - + self.content.append(val) + def __getitem__(self, indices): return self.get(indices, pad=False) - + def __setitem__(self, idx, val): assert isinstance(idx, int) + if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型 + type_, dim_ = _get_ele_type_and_dim(val) + if self.dtype!=type_: + raise RuntimeError(f"Value(type:{type_}) are of different types with " + f"other values(type:{self.dtype}).") + if self._cell_ndim!=dim_: + raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with " + f"previous values(dim:{self._cell_ndim}).") self.content[idx] = val - + def get(self, indices, pad=True): """ 根据给定的indices返回内容 @@ -257,14 +168,14 @@ class FieldArray(object): if isinstance(indices, int): return self.content[indices] if self.is_input is False and self.is_target is False: - raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) - + raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) + contents = [self.content[i] for i in indices] if self.padder is None or pad is False: return np.array(contents) else: - return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) - + return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) + def set_padder(self, padder): """ 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 @@ -276,7 +187,7 @@ class FieldArray(object): self.padder = deepcopy(padder) else: self.padder = None - + def set_pad_val(self, pad_val): """ 修改padder的pad_val. @@ -286,7 +197,7 @@ class FieldArray(object): if self.padder is not None: self.padder.set_pad_val(pad_val) return self - + def __len__(self): """ Returns the size of FieldArray. @@ -294,7 +205,7 @@ class FieldArray(object): :return int length: """ return len(self.content) - + def to(self, other): """ 将other的属性复制给本FieldArray(other必须为FieldArray类型). @@ -303,22 +214,63 @@ class FieldArray(object): :param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性 :return: :class:`~fastNLP.FieldArray` """ - assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) - + assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other)) + + self.ignore_type = other.ignore_type self.is_input = other.is_input self.is_target = other.is_target self.padder = other.padder - self.ignore_type = other.ignore_type - + return self -def _is_iterable(content): +def _get_ele_type_and_dim(cell:Any, dim=0): + """ + 识别cell的类别与dimension的数量 + + numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html + :param cell: + :param dim: + :return: + """ + if isinstance(cell, (str, Number, np.bool_)): + return type(cell), dim + elif isinstance(cell, list): + dim += 1 + res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] + types = set([i for i,j in res]) + dims = set([j for i,j in res]) + if len(types)>1: + raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) + if len(dims)>1: + raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) + return types.pop(), dims.pop() + elif isinstance(cell, torch.Tensor): + return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0 + elif isinstance(cell, np.ndarray): + if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了 + return cell.dtype.type, cell.ndim + dim + # 否则需要继续往下iterate + dim += 1 + res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] + types = set([i for i,j in res]) + dims = set([j for i,j in res]) + if len(types)>1: + raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) + if len(dims)>1: + raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) + return types.pop(), dims.pop() + else: # 包含tuple, set, dict以及其它的类型 + raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") + + +def _is_iterable(value): + # 检查是否是iterable的, duck typing try: - _ = (e for e in content) - except TypeError: + iter(value) + return True + except BaseException as e: return False - return True class Padder: @@ -327,32 +279,35 @@ class Padder: 所有padder都需要继承这个类,并覆盖__call__方法。 用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 - + .. py:function:: __call__(self, contents, field_name, field_ele_dtype): 传入的是List内容。假设有以下的DataSet。 - + :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 deepcopy一份。 :param str, field_name: field的名称。 :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 :return: np.array([padded_element]) - + """ - + def __init__(self, pad_val=0, **kwargs): self.pad_val = pad_val - + def set_pad_val(self, pad_val): self.pad_val = pad_val - - def __call__(self, contents, field_name, field_ele_dtype): + + @abstractmethod + def __call__(self, contents, field_name, field_ele_dtype, dim:int): """ 传入的是List内容。假设有以下的DataSet。 :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 deepcopy一份。 :param str, field_name: field的名称。 - :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 + :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True, + 该这个值为None。 + :param dim: 这个field的维度。当ignore_type为True时,该值为None :return: np.array([padded_element]) Example:: @@ -394,50 +349,87 @@ class AutoPadder(Padder): 根据contents的数据自动判定是否需要做padding。 1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 - 型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行pad + 型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad + + 2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等 - 2 如果元素类型为(np.int64, np.float64), + 2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding - 2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding + 2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。 - 2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 - 即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad + 2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用 + :class: fastNLP.EngChar2DPadder. + + 2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片 + 的情况。 + + 3 其它情况不进行处理,返回一个np.array类型。 """ - def __init__(self, pad_val=0): - """ - :param pad_val: int, padding的位置使用该index - """ super().__init__(pad_val=pad_val) - - def _is_two_dimension(self, contents): - """ - 判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 - :param contents: - :return: - """ - value = contents[0] - if isinstance(value, (np.ndarray, list)): - value = value[0] - if isinstance(value, (np.ndarray, list)): - return False - return True - return False - - def __call__(self, contents, field_name, field_ele_dtype): - - if not _is_iterable(contents[0]): - array = np.array([content for content in contents], dtype=field_ele_dtype) - elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): - max_len = max([len(content) for content in contents]) - array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) - for i, content in enumerate(contents): - array[i][:len(content)] = content - elif field_ele_dtype is None: - array = np.array(contents) # 当ignore_type=True时,直接返回contents - else: # should only be str - array = np.array([content for content in contents]) - return array + + def __call__(self, contents, field_name, field_ele_dtype, dim): + if field_ele_dtype: + if dim>3: + return np.array(contents) + if isinstance(field_ele_dtype, np.dtype) or field_ele_dtype in (float, int, bool, str): + if isinstance(field_ele_dtype, np.number) or field_ele_dtype in (float, int, bool): + if dim==0: + array = np.array(contents, dtype=field_ele_dtype) + elif dim==1: + max_len = max(map(len, contents)) + array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + array[i, :len(content_i)] = content_i + elif dim==2: + max_len = max(map(len, contents)) + max_word_len = max([max([len(content_ii) for content_ii in content_i]) for + content_i in contents]) + array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + for j, content_ii in enumerate(content_i): + array[i, j, :len(content_ii)] = content_ii + else: + shape = np.shape(contents) + if len(shape)==4: # 说明各dimension是相同的大小 + array = np.array(contents, dtype=field_ele_dtype) + else: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + return array + return np.array(contents) + elif str(field_ele_dtype).startswith('torch'): + if dim==0: + tensor = torch.tensor(contents).to(field_ele_dtype) + elif dim==1: + max_len = max(map(len, contents)) + tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + tensor[i, :len(content_i)] = torch.tensor(content_i) + elif dim==2: + max_len = max(map(len, contents)) + max_word_len = max([max([len(content_ii) for content_ii in content_i]) for + content_i in contents]) + tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val, + dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + for j, content_ii in enumerate(content_i): + tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) + else: + shapes = set([np.shape(content_i) for content_i in contents]) + if len(shapes)>1: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + shape = shapes.pop() + if len(shape)==3: + tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) + else: + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + return tensor + else: + return np.array(contents) # 不进行任何操作 + else: + return np.array(contents) class EngChar2DPadder(Padder): @@ -463,7 +455,7 @@ class EngChar2DPadder(Padder): dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder """ - + def __init__(self, pad_val=0, pad_length=0): """ :param pad_val: int, pad的位置使用该index @@ -471,32 +463,10 @@ class EngChar2DPadder(Padder): 都pad或截取到该长度. """ super().__init__(pad_val=pad_val) - + self.pad_length = pad_length - - def _exactly_three_dims(self, contents, field_name): - """ - 检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character - :param contents: - :param field_name: str - :return: - """ - if not isinstance(contents, list): - raise TypeError("contents should be a list, not {}.".format(type(contents))) - value = contents[0] - try: - value = value[0] - except: - raise ValueError("Field:{} only has one dimension.".format(field_name)) - try: - value = value[0] - except: - raise ValueError("Field:{} only has two dimensions.".format(field_name)) - - if _is_iterable(value): - raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) - - def __call__(self, contents, field_name, field_ele_dtype): + + def __call__(self, contents, field_name, field_ele_dtype, dim): """ 期望输入类似于 [ @@ -510,11 +480,11 @@ class EngChar2DPadder(Padder): :param field_ele_dtype :return: """ - if field_ele_dtype not in (np.int64, np.float64): + if field_ele_dtype not in (np.int64, np.float64, int, float): raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( field_name, field_ele_dtype )) - self._exactly_three_dims(contents, field_name) + assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions." if self.pad_length < 1: max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) else: @@ -522,12 +492,12 @@ class EngChar2DPadder(Padder): max_sent_length = max(len(word_lst) for word_lst in contents) batch_size = len(contents) dtype = type(contents[0][0][0]) - + padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, dtype=dtype) for b_idx, word_lst in enumerate(contents): for c_idx, char_lst in enumerate(word_lst): chars = char_lst[:max_char_length] padded_array[b_idx, c_idx, :len(chars)] = chars - + return padded_array diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index bc37777e..4119d93f 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -107,9 +107,9 @@ class EmbedLoader(BaseLoader): :param bool normalize: 是否将每个vector归一化到norm为1 :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 方在于词表有空行或者词表出现了维度不一致。 - :return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 - :return numpy.ndarray: Vocabulary Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 + :return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 + """ vocab = Vocabulary(padding=padding, unknown=unknown) vec_dict = {} diff --git a/test/core/test_field.py b/test/core/test_field.py index 1f6580c1..e9053f37 100644 --- a/test/core/test_field.py +++ b/test/core/test_field.py @@ -1,8 +1,55 @@ import unittest import numpy as np +import torch from fastNLP import FieldArray +from fastNLP.core.field import _get_ele_type_and_dim +from fastNLP import AutoPadder + +class TestFieldArrayTyepDimDetect(unittest.TestCase): + """ + 检测FieldArray能否正确识别type与ndim + + """ + def test_case1(self): + # 1.1 常规类型测试 + for value in [1, True, 1.0, 'abc']: + type_ = type(value) + _type, _dim = _get_ele_type_and_dim(cell=value) + self.assertListEqual([_type, _dim], [type_, 0]) + # 1.2 mix类型报错 + with self.assertRaises(Exception): + value = [1, 2, 1.0] + self.assertRaises(_get_ele_type_and_dim(value)) + # 带有numpy的测试 + # 2.1 + value = np.array([1, 2, 3]) + type_ = value.dtype + dim_ = 1 + self.assertSequenceEqual(_get_ele_type_and_dim(cell=value), [type_, dim_]) + # 2.2 + value = np.array([[1, 2], [3, 4, 5]]) # char embedding的场景 + self.assertSequenceEqual([int, 2], _get_ele_type_and_dim(value)) + # 2.3 + value = np.zeros((3, 4)) + self.assertSequenceEqual([value.dtype, 2], _get_ele_type_and_dim(value)) + # 2.4 测试错误的dimension + with self.assertRaises(Exception): + value = np.array([[1, 2], [3, [1]]]) + _get_ele_type_and_dim(value) + # 2.5 测试混合类型 + with self.assertRaises(Exception): + value = np.array([[1, 2], [3.0]]) + _get_ele_type_and_dim(value) + + # 带有tensor的测试 + # 3.1 word embedding的场景 + value = torch.zeros(3, 10) + self.assertSequenceEqual([value.dtype, 2], _get_ele_type_and_dim(value)) + # 3.2 char embedding/image的场景 + value = torch.zeros(3, 32, 32) + self.assertSequenceEqual([value.dtype, 3], _get_ele_type_and_dim(value)) class TestFieldArrayInit(unittest.TestCase): @@ -31,12 +78,6 @@ class TestFieldArrayInit(unittest.TestCase): # 三维list fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) - def test_init_v7(self): - # list of array - fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) - self.assertEqual(fa.pytype, int) - self.assertEqual(fa.dtype, np.int) - def test_init_v4(self): # 一维list val = [1, 2, 3, 4] @@ -56,6 +97,11 @@ class TestFieldArrayInit(unittest.TestCase): fa.append(val) def test_init_v7(self): + # list of array + fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) + self.assertEqual(fa.dtype, np.array([1]).dtype) + + def test_init_v8(self): # 二维list val = np.array([[1, 2], [3, 4]]) fa = FieldArray("x", [val], is_input=True) @@ -79,33 +125,23 @@ class TestFieldArray(unittest.TestCase): self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) def test_type_conversion(self): - fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) - self.assertEqual(fa.pytype, float) - self.assertEqual(fa.dtype, np.float64) - fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) - fa.append(1.3333) - self.assertEqual(fa.pytype, float) - self.assertEqual(fa.dtype, np.float64) + self.assertEqual(fa.dtype, int) fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) - fa.append(10) - self.assertEqual(fa.pytype, float) - self.assertEqual(fa.dtype, np.float64) + fa.append(10.0) + self.assertEqual(fa.dtype, float) fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True) fa.append("e") - self.assertEqual(fa.dtype, np.str) - self.assertEqual(fa.pytype, str) + self.assertEqual(fa.dtype, str) def test_support_np_array(self): fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) self.assertEqual(fa.dtype, np.float64) - self.assertEqual(fa.pytype, float) fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) self.assertEqual(fa.dtype, np.float64) - self.assertEqual(fa.pytype, float) fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) # in this case, pytype is actually a float. We do not care about it. @@ -113,11 +149,10 @@ class TestFieldArray(unittest.TestCase): def test_nested_list(self): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True) - self.assertEqual(fa.pytype, float) - self.assertEqual(fa.dtype, np.float64) + self.assertEqual(fa.dtype, float) def test_getitem_v1(self): - fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True) self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) ans = fa[[0, 1]] self.assertTrue(isinstance(ans, np.ndarray)) @@ -150,7 +185,7 @@ class TestFieldArray(unittest.TestCase): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) fa.append(["str", 0, 0, 0, 1.89]) - fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True) fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) self.assertEqual(len(fa), 3) self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) @@ -163,33 +198,86 @@ class TestFieldArray(unittest.TestCase): fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True) -class TestPadder(unittest.TestCase): +class TestAutoPadder(unittest.TestCase): + def test00(self): + padder = AutoPadder() + # 没有类型时 + contents = [(1, 2), ('str', 'a')] + padder(contents, None, None, None) def test01(self): - """ - 测试AutoPadder能否正常工作 - :return: - """ - from fastNLP import AutoPadder + # 测试使用多维的bool, int, str, float的情况 + # str padder = AutoPadder() content = ['This is a str', 'this is another str'] - self.assertListEqual(content, padder(content, None, np.str).tolist()) + self.assertListEqual(content, padder(content, None, str, 0).tolist()) - content = [1, 2] - self.assertListEqual(content, padder(content, None, np.int64).tolist()) - - content = [[1,2], [3], [4]] - self.assertListEqual([[1,2], [3, 0], [4, 0]], - padder(content, None, np.int64).tolist()) + # 1维int + content = [[1, 2, 3], [4,], [5, 6, 7, 8]] + padded_content = [[1, 2, 3, 0], [4, 0, 0, 0], [5, 6, 7, 8]] + self.assertListEqual(padder(content, None, int, 1).tolist(), padded_content) + # 二维int + padded_content = [[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] content = [ - [[1, 2, 3], [4, 5], [7,8,9,10]], - [[1]] - ] - self.assertListEqual(content, - padder(content, None, np.int64).tolist()) + [[1, 2, 3], [4, 5], [7, 8, 9, 10]], + [[1]] + ] + self.assertListEqual(padder(content, None, int, 2).tolist(), padded_content) + + # 3维图片 + contents = [np.random.rand(3, 4, 4).tolist() for _ in range(5)] + self.assertTrue(padder(contents, None, float, 3).shape==(5, 3, 4, 4)) + + # 更高维度直接返回 + contents = [np.random.rand(24, 3, 4, 4).tolist() for _ in range(5)] + self.assertTrue(isinstance(padder(contents, None, float, 4), np.ndarray)) def test02(self): + padder = AutoPadder() + # 测试numpy的情况 + # 0维 + contents = np.arange(12) + self.assertListEqual(padder(contents, None, contents.dtype, 0).tolist(), contents.tolist()) + + # 1维 + contents = np.arange(12).reshape((3, 4)) + self.assertListEqual(padder(contents, None, contents.dtype, 1).tolist(), contents.tolist()) + + # 2维 + contents = np.ones((3, 10, 5)) + self.assertListEqual(padder(contents, None, contents.dtype, 2).tolist(), contents.tolist()) + + # 3维 + contents = [np.random.rand(3, 4, 4) for _ in range(5)] + l_contents = [content.tolist() for content in contents] + self.assertListEqual(padder(contents, None, contents[0].dtype, 3).tolist(), l_contents) + + def test03(self): + padder = AutoPadder() + # 测试tensor的情况 + # 0维 + contents = torch.arange(12) + r_contents = padder(contents, None, contents.dtype, 0) + self.assertSequenceEqual(r_contents.tolist(), contents.tolist()) + self.assertTrue(r_contents.dtype==contents.dtype) + + # 0维 + contents = [torch.tensor(1) for _ in range(10)] + self.assertSequenceEqual(padder(contents, None, torch.int64, 0).tolist(), contents) + + # 1维 + contents = torch.randn(3, 4) + padder(contents, None, torch.float64, 1) + + # 3维 + contents = [torch.randn(3, 4, 4) for _ in range(5)] + padder(contents, None, torch.float64, 3) + + + +class TestEngChar2DPadder(unittest.TestCase): + def test01(self): """ 测试EngChar2DPadder能不能正确使用 :return: @@ -198,38 +286,31 @@ class TestPadder(unittest.TestCase): padder = EngChar2DPadder(pad_length=0) contents = [1, 2] - # 不能是1维 - with self.assertRaises(ValueError): - padder(contents, None, np.int64) + # 不能是0维 + with self.assertRaises(Exception): + padder(contents, None, np.int64, 0) contents = [[1, 2]] - # 不能是2维 - with self.assertRaises(ValueError): - padder(contents, None, np.int64) - contents = [[[[1, 2]]]] + # 不能是1维 + with self.assertRaises(Exception): + padder(contents, None, np.int64, 1) + contents = [ + [[[[1, 2]]]] + ] # 不能是3维以上 - with self.assertRaises(ValueError): - padder(contents, None, np.int64) + with self.assertRaises(Exception): + padder(contents, None, np.int64, 3) contents = [ [[1, 2, 3], [4, 5], [7,8,9,10]], [[1]] ] self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], - padder(contents, None, np.int64).tolist()) + padder(contents, None, np.int64, 2).tolist()) padder = EngChar2DPadder(pad_length=5, pad_val=-100) self.assertListEqual( [[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], [[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], - padder(contents, None, np.int64).tolist() + padder(contents, None, np.int64, 2).tolist() ) - def test_None_dtype(self): - from fastNLP import AutoPadder - padder = AutoPadder() - content = [ - [[1, 2, 3], [4, 5], [7, 8, 9, 10]], - [[1]] - ] - ans = padder(content, None, None).tolist() - self.assertListEqual(content, ans)