Browse Source

!!!重要更新,DataSet理论上支持任意类型的数据了,但是因为改动非常大,所以可能会有bug

tags/v0.4.10
yh 6 years ago
parent
commit
e90bbbb3f1
5 changed files with 444 additions and 377 deletions
  1. +10
    -7
      fastNLP/core/batch.py
  2. +18
    -5
      fastNLP/core/dataset.py
  3. +271
    -301
      fastNLP/core/field.py
  4. +2
    -2
      fastNLP/io/embed_loader.py
  5. +143
    -62
      test/core/test_field.py

+ 10
- 7
fastNLP/core/batch.py View File

@@ -12,6 +12,7 @@ from queue import Empty, Full
import numpy as np import numpy as np
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from numbers import Number


from .sampler import RandomSampler from .sampler import RandomSampler


@@ -78,8 +79,10 @@ class Batch(object):
for field_name, field in self.dataset.get_all_fields().items(): for field_name, field in self.dataset.get_all_fields().items():
if field.is_target or field.is_input: if field.is_target or field.is_input:
batch = field.get(indices) batch = field.get(indices)
if not self.as_numpy and field.padder is not None:
batch = _to_tensor(batch, field.dtype)
if not self.as_numpy and \
field.dtype is not None and \
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
batch = _to_tensor(batch)
if field.is_target: if field.is_target:
batch_y[field_name] = batch batch_y[field_name] = batch
if field.is_input: if field.is_input:
@@ -174,12 +177,12 @@ class Batch(object):
# print('iter done') # print('iter done')




def _to_tensor(batch):
    """
    Convert a collated batch into a ``torch.Tensor`` on a best-effort basis.

    Floating-point numpy arrays become float32 tensors; every other array is
    passed to ``torch.as_tensor`` as is. Anything that cannot be converted
    (an object without a numpy ``dtype``, an array whose dtype torch does not
    support, ...) is returned unchanged.

    :param batch: the batch to convert, normally a ``numpy.ndarray``.
    :return: a ``torch.Tensor`` when the conversion succeeded, otherwise the
        original ``batch`` object.
    """
    try:
        if issubclass(batch.dtype.type, np.floating):
            return torch.as_tensor(batch).float()  # use float32 by default
        return torch.as_tensor(batch)  # reuse the memory, avoid a copy
    except Exception:
        # Conversion is deliberately best-effort, so ordinary failures fall
        # back to the raw batch. Unlike a bare ``except``, this lets
        # KeyboardInterrupt / SystemExit propagate.
        return batch

+ 18
- 5
fastNLP/core/dataset.py View File

@@ -285,7 +285,8 @@ from .field import AutoPadder
from .field import FieldArray from .field import FieldArray
from .instance import Instance from .instance import Instance
from .utils import _get_func_signature from .utils import _get_func_signature

from .field import AppendToTargetOrInputException
from .field import SetInputOrTargetException


class DataSet(object): class DataSet(object):
""" """
@@ -422,7 +423,7 @@ class DataSet(object):
if len(self.field_arrays) == 0: if len(self.field_arrays) == 0:
# DataSet has no field yet # DataSet has no field yet
for name, field in instance.fields.items(): for name, field in instance.fields.items():
field = field.tolist() if isinstance(field, np.ndarray) else field
# field = field.tolist() if isinstance(field, np.ndarray) else field
self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来
else: else:
if len(self.field_arrays) != len(instance.fields): if len(self.field_arrays) != len(instance.fields):
@@ -431,7 +432,11 @@ class DataSet(object):
.format(len(self.field_arrays), len(instance.fields))) .format(len(self.field_arrays), len(instance.fields)))
for name, field in instance.fields.items(): for name, field in instance.fields.items():
assert name in self.field_arrays assert name in self.field_arrays
self.field_arrays[name].append(field)
try:
self.field_arrays[name].append(field)
except AppendToTargetOrInputException as e:
print(f"Cannot append to field:{name}.")
raise e
def add_fieldarray(self, field_name, fieldarray): def add_fieldarray(self, field_name, fieldarray):
""" """
@@ -565,7 +570,11 @@ class DataSet(object):
assert isinstance(flag, bool), "Only bool type supported." assert isinstance(flag, bool), "Only bool type supported."
for name in field_names: for name in field_names:
if name in self.field_arrays: if name in self.field_arrays:
self.field_arrays[name].is_target = flag
try:
self.field_arrays[name].is_target = flag
except SetInputOrTargetException as e:
print(f"Cannot set field:{name} as target.")
raise e
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
@@ -581,7 +590,11 @@ class DataSet(object):
""" """
for name in field_names: for name in field_names:
if name in self.field_arrays: if name in self.field_arrays:
self.field_arrays[name].is_input = flag
try:
self.field_arrays[name].is_input = flag
except SetInputOrTargetException as e:
print(f"Cannot set field:{name} as input.")
raise e
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))


+ 271
- 301
fastNLP/core/field.py View File

@@ -1,251 +1,162 @@
"""
field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fastNLP.DataSet` 中一列的存储方式,
原理部分请参考 :doc:`fastNLP.core.dataset`

"""
__all__ = [
"FieldArray",
"Padder",
"AutoPadder",
"EngChar2DPadder"
]


from copy import deepcopy


from numbers import Number
import torch
import numpy as np import numpy as np
from typing import Any
from abc import abstractmethod
from copy import deepcopy



class FieldArray(object):
"""
别名::class:`fastNLP.FieldArray` :class:`fastNLP.core.field.FieldArray`

FieldArray 是用于保存 :class:`~fastNLP.DataSet` 中一个field的类型。
:param str name: FieldArray的名称
:param list,numpy.ndarray content: 列表的元素可以为list,int,float,
:param bool is_target: 这个field是否是一个target field。
:param bool is_input: 这个field是否是一个input field。
:param padder: :class:`~fastNLP.Padder` 类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过
fieldarray.set_pad_val()。默认为None,即使用 :class:`~fastNLP.AutoPadder` 。
:param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor,
就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。
"""
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False):
class SetInputOrTargetException(Exception):
    """
    Raised while checking a field's content for being used as input/target,
    e.g. when mixed element types or mixed dimensions are detected.

    :param msg: human readable description of the problem.
    :param index: index of the sample at which the problem was hit (optional).
    :param field_name: name of the field being processed (optional).
    """
    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # which sample ran into the problem
        self.field_name = field_name  # name of the current field

class AppendToTargetOrInputException(Exception):
    """
    Raised when a value appended to an input/target field does not match the
    element type or the dimension of the values already stored in it.

    :param msg: human readable description of the problem.
    :param index: index of the sample at which the problem was hit (optional).
    :param field_name: name of the field being processed (optional).
    """
    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # which sample ran into the problem
        self.field_name = field_name  # name of the current field

class FieldArray:
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False):
if len(content)==0:
raise RuntimeError("Empty fieldarray is not allowed.")
_content = content
try:
_content = list(_content)
except BaseException as e:
print(f"Cannot convert content(of type:{type(content)}) into list.")
raise e
self.name = name self.name = name
if isinstance(content, list):
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list
# 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list]
for idx, item in enumerate(content):
# 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field])
# 将[np.array] 转化为 list of list
# 也可以支持[array, array, array]的情况
if isinstance(item, np.ndarray):
content[idx] = content[idx].tolist()
elif isinstance(content, np.ndarray):
content = content.tolist() # convert np.ndarray into 2-D list
else:
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content)))
if len(content) == 0:
raise RuntimeError("Cannot initialize FieldArray with empty list.")
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐
self.content_dim = None # 表示content是多少维的list
self.content = _content
self._ignore_type = ignore_type
# 根据input的情况设置input,target等
self._cell_ndim = None # 多少维度
self.dtype = None # 最内层的element都是什么类型的
self._is_input = False
self._is_target = False

if is_input:
self.is_input = is_input
if is_target:
self.is_target = is_target

if padder is None: if padder is None:
padder = AutoPadder(pad_val=0) padder = AutoPadder(pad_val=0)
else: else:
assert isinstance(padder, Padder), "padder must be of type Padder."
assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder."
padder = deepcopy(padder) padder = deepcopy(padder)
self.set_padder(padder) self.set_padder(padder)
self.ignore_type = ignore_type
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array
self.pytype = None
self.dtype = None
self._is_input = None
self._is_target = None
if is_input is not None or is_target is not None:
self.is_input = is_input
self.is_target = is_target
def _set_dtype(self):
if self.ignore_type is False:
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)

@property
def ignore_type(self):
    """
    Whether type/dimension checking is switched off for this field.

    Assigning a truthy value also clears the detected ``dtype`` and
    ``_cell_ndim``, since they are meaningless once checks are disabled.
    """
    return self._ignore_type

@ignore_type.setter
def ignore_type(self, value):
    if value:
        # Forget what was detected earlier; nothing is checked any more.
        self._cell_ndim = None
        self.dtype = None
    # The flag itself must be stored, otherwise the assignment is a no-op
    # and the getter keeps returning the old value.
    self._ignore_type = value

@property @property
def is_input(self): def is_input(self):
return self._is_input return self._is_input

@is_input.setter @is_input.setter
def is_input(self, value): def is_input(self, value):
""" """
当 field_array.is_input = True / False 时被调用 当 field_array.is_input = True / False 时被调用
""" """
if value is True:
self._set_dtype()
# 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False)
if value is True and \
self._is_target is False and \
self._ignore_type is False:
self._check_dtype_and_ndim()
if value is False and self._is_target is False:
self.dtype = None
self._cell_ndim = None
self._is_input = value self._is_input = value
@property @property
def is_target(self): def is_target(self):
return self._is_target return self._is_target
@is_target.setter @is_target.setter
def is_target(self, value): def is_target(self, value):
""" """
当 field_array.is_target = True / False 时被调用 当 field_array.is_target = True / False 时被调用
""" """
if value is True:
self._set_dtype()
if value is True and \
self._is_input is False and \
self._ignore_type is False:
self._check_dtype_and_ndim()
if value is False and self._is_input is False:
self.dtype = None
self._cell_ndim = None
self._is_target = value self._is_target = value
def _type_detection(self, content):
"""
当该field被设置为is_input或者is_target时被调用


def _check_dtype_and_ndim(self):
""" """
if len(content) == 0:
raise RuntimeError("Empty list in Field {}.".format(self.name))
type_set = set([type(item) for item in content])
if list in type_set:
if len(type_set) > 1:
# list 跟 非list 混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
# >1维list
inner_type_set = set()
for l in content:
[inner_type_set.add(type(obj)) for obj in l]
if list not in inner_type_set:
# 二维list
self.content_dim = 2
return self._basic_type_detection(inner_type_set)
else:
if len(inner_type_set) == 1:
# >2维list
inner_inner_type_set = set()
for _2d_list in content:
for _1d_list in _2d_list:
[inner_inner_type_set.add(type(obj)) for obj in _1d_list]
if list in inner_inner_type_set:
raise RuntimeError("FieldArray cannot handle 4-D or more-D list.")
# 3维list
self.content_dim = 3
return self._basic_type_detection(inner_inner_type_set)
else:
# list 跟 非list 混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set)))
else:
# 一维list
for content_type in type_set:
if content_type not in self.BASIC_TYPES:
raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format(
self.name, self.BASIC_TYPES, content_type))
self.content_dim = 1
return self._basic_type_detection(type_set)
def _basic_type_detection(self, type_set):
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与dtype属性;没有
通过将直接报错.

:return:
""" """
:param type_set: a set of Python types
:return: one of self.BASIC_TYPES
cell_0 = self.content[0]
index = 0
try:
type_0, dim_0 = _get_ele_type_and_dim(cell_0)
for cell in self.content[1:]:
index += 1
type_i, dim_i = _get_ele_type_and_dim(cell)
if type_i!=type_0:
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}."
".".format(type_i, index, type_0))
if dim_0!=dim_i:
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with "
"dimension:{}.".format(dim_i, index, dim_0))
self._cell_ndim = dim_0
self.dtype = type_0
except SetInputOrTargetException as e:
e.index = index
raise e

def append(self, val:Any):
"""
:param val: 把该val append到fieldarray。
:return:
""" """
if len(type_set) == 1:
return type_set.pop()
elif len(type_set) == 2:
# 有多个basic type; 可能需要up-cast
if float in type_set and int in type_set:
# up-cast int to float
return float
else:
# str 跟 int 或者 float 混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
if (self._is_target or self._is_input) and self._ignore_type is False:
type_, dim_ = _get_ele_type_and_dim(val)
if self.dtype!=type_:
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with "
f"previous values(type:{self.dtype}).")
if self._cell_ndim!=dim_:
raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with "
f"previous values(dim:{self._cell_ndim}).")
self.content.append(val)
else: else:
# str, int, float混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
def _1d_list_check(self, val):
"""如果不是1D list就报错
"""
type_set = set((type(obj) for obj in val))
if any(obj not in self.BASIC_TYPES for obj in type_set):
raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
self._basic_type_detection(type_set)
# otherwise: _basic_type_detection will raise error
return True
def _2d_list_check(self, val):
"""如果不是2D list 就报错
"""
type_set = set(type(obj) for obj in val)
if list(type_set) != [list]:
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set))
inner_type_set = set()
for l in val:
for obj in l:
inner_type_set.add(type(obj))
self._basic_type_detection(inner_type_set)
return True
@staticmethod
def _map_to_np_type(basic_type):
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray}
return type_mapping[basic_type]
def __repr__(self):
return "FieldArray {}: {}".format(self.name, self.content.__repr__())
def append(self, val):
"""将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有
的内容是匹配的。

:param Any val: 需要append的值。
"""
if self.ignore_type is False:
if isinstance(val, list):
pass
elif isinstance(val, tuple): # 确保最外层是list
val = list(val)
elif isinstance(val, np.ndarray):
val = val.tolist()
elif any((isinstance(val, t) for t in self.BASIC_TYPES)):
pass
else:
raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))
if self.is_input is True or self.is_target is True:
if type(val) == list:
if len(val) == 0:
raise ValueError("Cannot append an empty list.")
if self.content_dim == 2 and self._1d_list_check(val):
# 1维list检查
pass
elif self.content_dim == 3 and self._2d_list_check(val):
# 2维list检查
pass
else:
raise RuntimeError(
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val))
elif type(val) in self.BASIC_TYPES and self.content_dim == 1:
# scalar检查
if type(val) == float and self.pytype == int:
self.pytype = float
self.dtype = self._map_to_np_type(self.pytype)
else:
raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))
self.content.append(val)
self.content.append(val)

def __getitem__(self, indices): def __getitem__(self, indices):
return self.get(indices, pad=False) return self.get(indices, pad=False)

def __setitem__(self, idx, val): def __setitem__(self, idx, val):
assert isinstance(idx, int) assert isinstance(idx, int)
if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型
type_, dim_ = _get_ele_type_and_dim(val)
if self.dtype!=type_:
raise RuntimeError(f"Value(type:{type_}) are of different types with "
f"other values(type:{self.dtype}).")
if self._cell_ndim!=dim_:
raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with "
f"previous values(dim:{self._cell_ndim}).")
self.content[idx] = val self.content[idx] = val

def get(self, indices, pad=True): def get(self, indices, pad=True):
""" """
根据给定的indices返回内容 根据给定的indices返回内容
@@ -257,14 +168,14 @@ class FieldArray(object):
if isinstance(indices, int): if isinstance(indices, int):
return self.content[indices] return self.content[indices]
if self.is_input is False and self.is_target is False: if self.is_input is False and self.is_target is False:
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name))
raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name))
contents = [self.content[i] for i in indices] contents = [self.content[i] for i in indices]
if self.padder is None or pad is False: if self.padder is None or pad is False:
return np.array(contents) return np.array(contents)
else: else:
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype)
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
def set_padder(self, padder): def set_padder(self, padder):
""" """
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。
@@ -276,7 +187,7 @@ class FieldArray(object):
self.padder = deepcopy(padder) self.padder = deepcopy(padder)
else: else:
self.padder = None self.padder = None
def set_pad_val(self, pad_val): def set_pad_val(self, pad_val):
""" """
修改padder的pad_val. 修改padder的pad_val.
@@ -286,7 +197,7 @@ class FieldArray(object):
if self.padder is not None: if self.padder is not None:
self.padder.set_pad_val(pad_val) self.padder.set_pad_val(pad_val)
return self return self
def __len__(self): def __len__(self):
""" """
Returns the size of FieldArray. Returns the size of FieldArray.
@@ -294,7 +205,7 @@ class FieldArray(object):
:return int length: :return int length:
""" """
return len(self.content) return len(self.content)
def to(self, other): def to(self, other):
""" """
将other的属性复制给本FieldArray(other必须为FieldArray类型). 将other的属性复制给本FieldArray(other必须为FieldArray类型).
@@ -303,22 +214,63 @@ class FieldArray(object):
:param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性 :param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性
:return: :class:`~fastNLP.FieldArray` :return: :class:`~fastNLP.FieldArray`
""" """
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))
assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other))

self.ignore_type = other.ignore_type
self.is_input = other.is_input self.is_input = other.is_input
self.is_target = other.is_target self.is_target = other.is_target
self.padder = other.padder self.padder = other.padder
self.ignore_type = other.ignore_type

return self return self




def _is_iterable(content):
def _merge_ele_types_and_dims(cells, dim):
    """Inspect every child of a container and require one shared (type, dim)."""
    results = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cells]
    if not results:
        # Without this guard the ``pop()`` below fails with a bare KeyError.
        raise SetInputOrTargetException("Empty value encountered.")
    types = {type_i for type_i, _ in results}
    dims = {dim_i for _, dim_i in results}
    if len(types) > 1:
        raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
    if len(dims) > 1:
        raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
    return types.pop(), dims.pop()


def _get_ele_type_and_dim(cell: Any, dim=0):
    """
    Detect the innermost element type of ``cell`` and its number of dimensions.

    numpy scalar types: https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html

    :param cell: a scalar (``str``, number, ``np.bool_``), a (nested) ``list``,
        a ``torch.Tensor`` or a ``numpy.ndarray``.
    :param dim: nesting depth accumulated so far; callers normally leave it at 0.
    :return: ``(element_type, ndim)``. For tensors the element type is the
        ``torch.dtype``; for well-formed arrays it is ``cell.dtype.type``.
    :raises SetInputOrTargetException: if the cell is empty, mixes element types
        or nesting depths, or is of an unsupported type (tuple, set, dict, ...).
    """
    if isinstance(cell, (str, Number, np.bool_)):
        return type(cell), dim
    if isinstance(cell, torch.Tensor):
        # A 0-d tensor (e.g. the result of torch.mean) contributes 0 dimensions.
        return cell.dtype, cell.dim() + dim
    if isinstance(cell, np.ndarray) and cell.dtype != np.dtype('O'):
        # A non-object array is already well-formed; no need to look inside.
        return cell.dtype.type, cell.ndim + dim
    if isinstance(cell, (list, np.ndarray)):
        # Lists and object arrays have to be walked element by element.
        return _merge_ele_types_and_dims(cell, dim + 1)
    # tuple, set, dict and any other type are not supported
    raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")


def _is_iterable(value):
    """
    Duck-typing check: report whether ``iter(value)`` succeeds.

    :param value: any object.
    :return: ``True`` if an iterator can be obtained from ``value``, else ``False``.
    """
    try:
        iter(value)
    except Exception:
        # Ordinary failures mean "not iterable". KeyboardInterrupt and
        # SystemExit are not Exception subclasses, so they still propagate.
        return False
    return True




class Padder: class Padder:
@@ -327,32 +279,35 @@ class Padder:


所有padder都需要继承这个类,并覆盖__call__方法。 所有padder都需要继承这个类,并覆盖__call__方法。
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。
.. py:function:: __call__(self, contents, field_name, field_ele_dtype): .. py:function:: __call__(self, contents, field_name, field_ele_dtype):
传入的是List内容。假设有以下的DataSet。 传入的是List内容。假设有以下的DataSet。
:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
deepcopy一份。 deepcopy一份。
:param str, field_name: field的名称。 :param str, field_name: field的名称。
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。
:return: np.array([padded_element]) :return: np.array([padded_element])
""" """
def __init__(self, pad_val=0, **kwargs): def __init__(self, pad_val=0, **kwargs):
self.pad_val = pad_val self.pad_val = pad_val
def set_pad_val(self, pad_val): def set_pad_val(self, pad_val):
self.pad_val = pad_val self.pad_val = pad_val
def __call__(self, contents, field_name, field_ele_dtype):

@abstractmethod
def __call__(self, contents, field_name, field_ele_dtype, dim:int):
""" """
传入的是List内容。假设有以下的DataSet。 传入的是List内容。假设有以下的DataSet。


:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
deepcopy一份。 deepcopy一份。
:param str, field_name: field的名称。 :param str, field_name: field的名称。
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,
该这个值为None。
:param dim: 这个field的维度。当ignore_type为True时,该值为None
:return: np.array([padded_element]) :return: np.array([padded_element])


Example:: Example::
@@ -394,50 +349,87 @@ class AutoPadder(Padder):
根据contents的数据自动判定是否需要做padding。 根据contents的数据自动判定是否需要做padding。


1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行pad
型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad

2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等


2 如果元素类型为(np.int64, np.float64),
2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding


2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding
2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。


2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。
即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad
2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用
:class:`fastNLP.EngChar2DPadder` 。

2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片
的情况。

3 其它情况不进行处理,返回一个np.array类型。
""" """
def __init__(self, pad_val=0): def __init__(self, pad_val=0):
"""
:param pad_val: int, padding的位置使用该index
"""
super().__init__(pad_val=pad_val) super().__init__(pad_val=pad_val)
def _is_two_dimension(self, contents):
"""
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度
:param contents:
:return:
"""
value = contents[0]
if isinstance(value, (np.ndarray, list)):
value = value[0]
if isinstance(value, (np.ndarray, list)):
return False
return True
return False
def __call__(self, contents, field_name, field_ele_dtype):
if not _is_iterable(contents[0]):
array = np.array([content for content in contents], dtype=field_ele_dtype)
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents):
max_len = max([len(content) for content in contents])
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
for i, content in enumerate(contents):
array[i][:len(content)] = content
elif field_ele_dtype is None:
array = np.array(contents) # 当ignore_type=True时,直接返回contents
else: # should only be str
array = np.array([content for content in contents])
return array

def __call__(self, contents, field_name, field_ele_dtype, dim):
    """
    Collate ``contents`` into one batch, padding with ``self.pad_val`` when needed.

    :param list contents: the cells of the batch, one entry per sample.
    :param str field_name: name of the field, only used in error messages.
    :param field_ele_dtype: innermost element type of the field: a Python or
        numpy scalar type, a ``numpy.dtype``, a ``torch.dtype``, or ``None``
        when type checking is disabled for the field.
    :param dim: number of dimensions of a single cell (``None`` when type
        checking is disabled).
    :return: a ``numpy.ndarray``, or a ``torch.Tensor`` for torch element types.
        Numeric cells with dim 0 are stacked, dim 1 and dim 2 are padded to the
        longest sample, dim 3 must already share one shape. Everything else is
        returned as ``np.array(contents)`` without padding.
    :raises RuntimeError: if a 3-dimensional field holds samples of different shapes.
    """
    if not field_ele_dtype or dim > 3:
        return np.array(contents)

    if isinstance(field_ele_dtype, np.dtype):
        # Work with the scalar class so that the issubclass test below applies.
        field_ele_dtype = field_ele_dtype.type

    # field_ele_dtype is a class (int, np.int64, ...), so it has to be tested
    # with issubclass; isinstance(np.int64, np.number) is always False.
    if isinstance(field_ele_dtype, type) and issubclass(field_ele_dtype, (Number, np.number)):
        if dim == 0:
            return np.array(contents, dtype=field_ele_dtype)
        if dim == 1:
            max_len = max(map(len, contents))
            array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
            for i, content_i in enumerate(contents):
                array[i, :len(content_i)] = content_i
            return array
        if dim == 2:
            max_len = max(map(len, contents))
            max_word_len = max(max(len(content_ii) for content_ii in content_i)
                               for content_i in contents)
            array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
            for i, content_i in enumerate(contents):
                for j, content_ii in enumerate(content_i):
                    array[i, j, :len(content_ii)] = content_ii
            return array
        # dim == 3: nothing is padded, all samples must have the same shape.
        try:
            array = np.array(contents, dtype=field_ele_dtype)
        except ValueError:
            array = None  # ragged input cannot be turned into a numeric array
        if array is None or array.ndim != 4:
            raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
        return array

    if isinstance(field_ele_dtype, torch.dtype):
        if dim == 0:
            return torch.tensor(contents).to(field_ele_dtype)
        if dim == 1:
            max_len = max(map(len, contents))
            tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype)
            for i, content_i in enumerate(contents):
                tensor[i, :len(content_i)] = torch.tensor(content_i)
            return tensor
        if dim == 2:
            max_len = max(map(len, contents))
            max_word_len = max(max(len(content_ii) for content_ii in content_i)
                               for content_i in contents)
            tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val,
                                dtype=field_ele_dtype)
            for i, content_i in enumerate(contents):
                for j, content_ii in enumerate(content_i):
                    tensor[i, j, :len(content_ii)] = torch.tensor(content_ii)
            return tensor
        # dim == 3: nothing is padded, all samples must have the same shape.
        shapes = set(np.shape(content_i) for content_i in contents)
        if len(shapes) > 1:
            raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
        shape = shapes.pop()
        if len(shape) != 3:
            raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
        tensor = torch.full([len(contents)] + list(shape), fill_value=self.pad_val, dtype=field_ele_dtype)
        for i, content_i in enumerate(contents):
            tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype)
        return tensor

    # str and every other element type: no padding at all
    return np.array(contents)




class EngChar2DPadder(Padder): class EngChar2DPadder(Padder):
@@ -463,7 +455,7 @@ class EngChar2DPadder(Padder):
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder


""" """
def __init__(self, pad_val=0, pad_length=0): def __init__(self, pad_val=0, pad_length=0):
""" """
:param pad_val: int, pad的位置使用该index :param pad_val: int, pad的位置使用该index
@@ -471,32 +463,10 @@ class EngChar2DPadder(Padder):
都pad或截取到该长度. 都pad或截取到该长度.
""" """
super().__init__(pad_val=pad_val) super().__init__(pad_val=pad_val)
self.pad_length = pad_length self.pad_length = pad_length
def _exactly_three_dims(self, contents, field_name):
"""
检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character
:param contents:
:param field_name: str
:return:
"""
if not isinstance(contents, list):
raise TypeError("contents should be a list, not {}.".format(type(contents)))
value = contents[0]
try:
value = value[0]
except:
raise ValueError("Field:{} only has one dimension.".format(field_name))
try:
value = value[0]
except:
raise ValueError("Field:{} only has two dimensions.".format(field_name))
if _is_iterable(value):
raise ValueError("Field:{} has more than 3 dimension.".format(field_name))
def __call__(self, contents, field_name, field_ele_dtype):

def __call__(self, contents, field_name, field_ele_dtype, dim):
""" """
期望输入类似于 期望输入类似于
[ [
@@ -510,11 +480,11 @@ class EngChar2DPadder(Padder):
:param field_ele_dtype :param field_ele_dtype
:return: :return:
""" """
if field_ele_dtype not in (np.int64, np.float64):
if field_ele_dtype not in (np.int64, np.float64, int, float):
raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format(
field_name, field_ele_dtype field_name, field_ele_dtype
)) ))
self._exactly_three_dims(contents, field_name)
assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions."
if self.pad_length < 1: if self.pad_length < 1:
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
else: else:
@@ -522,12 +492,12 @@ class EngChar2DPadder(Padder):
max_sent_length = max(len(word_lst) for word_lst in contents) max_sent_length = max(len(word_lst) for word_lst in contents)
batch_size = len(contents) batch_size = len(contents)
dtype = type(contents[0][0][0]) dtype = type(contents[0][0][0])
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
dtype=dtype) dtype=dtype)
for b_idx, word_lst in enumerate(contents): for b_idx, word_lst in enumerate(contents):
for c_idx, char_lst in enumerate(word_lst): for c_idx, char_lst in enumerate(word_lst):
chars = char_lst[:max_char_length] chars = char_lst[:max_char_length]
padded_array[b_idx, c_idx, :len(chars)] = chars padded_array[b_idx, c_idx, :len(chars)] = chars
return padded_array return padded_array

+ 2
- 2
fastNLP/io/embed_loader.py View File

@@ -107,9 +107,9 @@ class EmbedLoader(BaseLoader):
:param bool normalize: 是否将每个vector归一化到norm为1 :param bool normalize: 是否将每个vector归一化到norm为1
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地
方在于词表有空行或者词表出现了维度不一致。 方在于词表有空行或者词表出现了维度不一致。
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。
:return numpy.ndarray: Vocabulary Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与
:return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与
是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。

""" """
vocab = Vocabulary(padding=padding, unknown=unknown) vocab = Vocabulary(padding=padding, unknown=unknown)
vec_dict = {} vec_dict = {}


+ 143
- 62
test/core/test_field.py View File

@@ -1,8 +1,55 @@
import unittest import unittest


import numpy as np import numpy as np
import torch


from fastNLP import FieldArray from fastNLP import FieldArray
from fastNLP.core.field import _get_ele_type_and_dim
from fastNLP import AutoPadder

class TestFieldArrayTyepDimDetect(unittest.TestCase):
    """
    Check that the element type and the number of dimensions of a cell are
    detected correctly.

    """
    def test_case1(self):
        # Imported locally so the failure cases can be pinned to the exact
        # exception type instead of a catch-all Exception.
        from fastNLP.core.field import SetInputOrTargetException

        # 1.1 plain Python scalars
        for value in [1, True, 1.0, 'abc']:
            type_ = type(value)
            _type, _dim = _get_ele_type_and_dim(cell=value)
            self.assertListEqual([_type, _dim], [type_, 0])
        # 1.2 mixed element types must be rejected
        with self.assertRaises(SetInputOrTargetException):
            _get_ele_type_and_dim([1, 2, 1.0])

        # numpy inputs
        # 2.1
        value = np.array([1, 2, 3])
        type_ = value.dtype
        dim_ = 1
        self.assertSequenceEqual(_get_ele_type_and_dim(cell=value), [type_, dim_])
        # 2.2 ragged rows (the char embedding scenario); dtype=object is given
        # explicitly because recent NumPy no longer builds ragged arrays implicitly
        value = np.array([[1, 2], [3, 4, 5]], dtype=object)
        self.assertSequenceEqual([int, 2], _get_ele_type_and_dim(value))
        # 2.3
        value = np.zeros((3, 4))
        self.assertSequenceEqual([value.dtype, 2], _get_ele_type_and_dim(value))
        # 2.4 inconsistent dimensions
        value = np.array([[1, 2], [3, [1]]], dtype=object)
        with self.assertRaises(SetInputOrTargetException):
            _get_ele_type_and_dim(value)
        # 2.5 mixed element types
        value = np.array([[1, 2], [3.0]], dtype=object)
        with self.assertRaises(SetInputOrTargetException):
            _get_ele_type_and_dim(value)

        # torch tensor inputs
        # 3.1 the word embedding scenario
        value = torch.zeros(3, 10)
        self.assertSequenceEqual([value.dtype, 2], _get_ele_type_and_dim(value))
        # 3.2 the char embedding / image scenario
        value = torch.zeros(3, 32, 32)
        self.assertSequenceEqual([value.dtype, 3], _get_ele_type_and_dim(value))




class TestFieldArrayInit(unittest.TestCase): class TestFieldArrayInit(unittest.TestCase):
@@ -31,12 +78,6 @@ class TestFieldArrayInit(unittest.TestCase):
# 三维list # 三维list
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True)


def test_init_v7(self):
# list of array
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True)
self.assertEqual(fa.pytype, int)
self.assertEqual(fa.dtype, np.int)

def test_init_v4(self): def test_init_v4(self):
# 一维list # 一维list
val = [1, 2, 3, 4] val = [1, 2, 3, 4]
@@ -56,6 +97,11 @@ class TestFieldArrayInit(unittest.TestCase):
fa.append(val) fa.append(val)


def test_init_v7(self): def test_init_v7(self):
# list of array
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True)
self.assertEqual(fa.dtype, np.array([1]).dtype)

def test_init_v8(self):
# 二维list # 二维list
val = np.array([[1, 2], [3, 4]]) val = np.array([[1, 2], [3, 4]])
fa = FieldArray("x", [val], is_input=True) fa = FieldArray("x", [val], is_input=True)
@@ -79,33 +125,23 @@ class TestFieldArray(unittest.TestCase):
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])


def test_type_conversion(self): def test_type_conversion(self):
fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)

fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
fa.append(1.3333)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)
self.assertEqual(fa.dtype, int)


fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
fa.append(10)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)
fa.append(10.0)
self.assertEqual(fa.dtype, float)


fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True) fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True)
fa.append("e") fa.append("e")
self.assertEqual(fa.dtype, np.str)
self.assertEqual(fa.pytype, str)
self.assertEqual(fa.dtype, str)


def test_support_np_array(self): def test_support_np_array(self):
fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True)
self.assertEqual(fa.dtype, np.float64) self.assertEqual(fa.dtype, np.float64)
self.assertEqual(fa.pytype, float)


fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5]))
self.assertEqual(fa.dtype, np.float64) self.assertEqual(fa.dtype, np.float64)
self.assertEqual(fa.pytype, float)


fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True)
# in this case, pytype is actually a float. We do not care about it. # in this case, pytype is actually a float. We do not care about it.
@@ -113,11 +149,10 @@ class TestFieldArray(unittest.TestCase):


def test_nested_list(self): def test_nested_list(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True) fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)
self.assertEqual(fa.dtype, float)


def test_getitem_v1(self): def test_getitem_v1(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True)
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
ans = fa[[0, 1]] ans = fa[[0, 1]]
self.assertTrue(isinstance(ans, np.ndarray)) self.assertTrue(isinstance(ans, np.ndarray))
@@ -150,7 +185,7 @@ class TestFieldArray(unittest.TestCase):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
fa.append(["str", 0, 0, 0, 1.89]) fa.append(["str", 0, 0, 0, 1.89])


fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True)
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
self.assertEqual(len(fa), 3) self.assertEqual(len(fa), 3)
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
@@ -163,33 +198,86 @@ class TestFieldArray(unittest.TestCase):
fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True) fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True)




class TestPadder(unittest.TestCase):
class TestAutoPadder(unittest.TestCase):
def test00(self):
    # No type information: every argument after the contents is None and the
    # cells mix ints with strs. The test only checks that the call completes.
    mixed_cells = [(1, 2), ('str', 'a')]
    auto_padder = AutoPadder()
    auto_padder(mixed_cells, None, None, None)


def test01(self): def test01(self):
"""
测试AutoPadder能否正常工作
:return:
"""
from fastNLP import AutoPadder
# 测试使用多维的bool, int, str, float的情况
# str
padder = AutoPadder() padder = AutoPadder()
content = ['This is a str', 'this is another str'] content = ['This is a str', 'this is another str']
self.assertListEqual(content, padder(content, None, np.str).tolist())
self.assertListEqual(content, padder(content, None, str, 0).tolist())


content = [1, 2]
self.assertListEqual(content, padder(content, None, np.int64).tolist())

content = [[1,2], [3], [4]]
self.assertListEqual([[1,2], [3, 0], [4, 0]],
padder(content, None, np.int64).tolist())
# 1维int
content = [[1, 2, 3], [4,], [5, 6, 7, 8]]
padded_content = [[1, 2, 3, 0], [4, 0, 0, 0], [5, 6, 7, 8]]
self.assertListEqual(padder(content, None, int, 1).tolist(), padded_content)


# 二维int
padded_content = [[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
content = [ content = [
[[1, 2, 3], [4, 5], [7,8,9,10]],
[[1]]
]
self.assertListEqual(content,
padder(content, None, np.int64).tolist())
[[1, 2, 3], [4, 5], [7, 8, 9, 10]],
[[1]]
]
self.assertListEqual(padder(content, None, int, 2).tolist(), padded_content)

# 3维图片
contents = [np.random.rand(3, 4, 4).tolist() for _ in range(5)]
self.assertTrue(padder(contents, None, float, 3).shape==(5, 3, 4, 4))

# 更高维度直接返回
contents = [np.random.rand(24, 3, 4, 4).tolist() for _ in range(5)]
self.assertTrue(isinstance(padder(contents, None, float, 4), np.ndarray))


def test02(self): def test02(self):
padder = AutoPadder()
# 测试numpy的情况
# 0维
contents = np.arange(12)
self.assertListEqual(padder(contents, None, contents.dtype, 0).tolist(), contents.tolist())

# 1维
contents = np.arange(12).reshape((3, 4))
self.assertListEqual(padder(contents, None, contents.dtype, 1).tolist(), contents.tolist())

# 2维
contents = np.ones((3, 10, 5))
self.assertListEqual(padder(contents, None, contents.dtype, 2).tolist(), contents.tolist())

# 3维
contents = [np.random.rand(3, 4, 4) for _ in range(5)]
l_contents = [content.tolist() for content in contents]
self.assertListEqual(padder(contents, None, contents[0].dtype, 3).tolist(), l_contents)

def test03(self):
    """Run AutoPadder over torch tensor inputs."""
    auto_padder = AutoPadder()

    # dim 0: a single 1-D tensor; values and dtype must come back unchanged
    flat = torch.arange(12)
    padded = auto_padder(flat, None, flat.dtype, 0)
    self.assertSequenceEqual(padded.tolist(), flat.tolist())
    self.assertTrue(padded.dtype == flat.dtype)

    # dim 0: a list of zero-dimensional tensors
    scalars = [torch.tensor(1) for _ in range(10)]
    self.assertSequenceEqual(
        auto_padder(scalars, None, torch.int64, 0).tolist(),
        scalars,
    )

    # dim 1: only checks that the call completes
    # NOTE(review): torch.randn presumably yields float32 here while
    # torch.float64 is passed as the dtype; confirm this is intended.
    matrix = torch.randn(3, 4)
    auto_padder(matrix, None, torch.float64, 1)

    # dim 3: only checks that the call completes
    images = [torch.randn(3, 4, 4) for _ in range(5)]
    auto_padder(images, None, torch.float64, 3)



class TestEngChar2DPadder(unittest.TestCase):
def test01(self):
""" """
测试EngChar2DPadder能不能正确使用 测试EngChar2DPadder能不能正确使用
:return: :return:
@@ -198,38 +286,31 @@ class TestPadder(unittest.TestCase):
padder = EngChar2DPadder(pad_length=0) padder = EngChar2DPadder(pad_length=0)


contents = [1, 2] contents = [1, 2]
# 不能是1
with self.assertRaises(ValueError):
padder(contents, None, np.int64)
# 不能是0
with self.assertRaises(Exception):
padder(contents, None, np.int64, 0)
contents = [[1, 2]] contents = [[1, 2]]
# 不能是2维
with self.assertRaises(ValueError):
padder(contents, None, np.int64)
contents = [[[[1, 2]]]]
# 不能是1维
with self.assertRaises(Exception):
padder(contents, None, np.int64, 1)
contents = [
[[[[1, 2]]]]
]
# 不能是3维以上 # 不能是3维以上
with self.assertRaises(ValueError):
padder(contents, None, np.int64)
with self.assertRaises(Exception):
padder(contents, None, np.int64, 3)


contents = [ contents = [
[[1, 2, 3], [4, 5], [7,8,9,10]], [[1, 2, 3], [4, 5], [7,8,9,10]],
[[1]] [[1]]
] ]
self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]],
padder(contents, None, np.int64).tolist())
padder(contents, None, np.int64, 2).tolist())


padder = EngChar2DPadder(pad_length=5, pad_val=-100) padder = EngChar2DPadder(pad_length=5, pad_val=-100)
self.assertListEqual( self.assertListEqual(
[[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], [[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]],
[[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], [[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]],
padder(contents, None, np.int64).tolist()
padder(contents, None, np.int64, 2).tolist()
) )


def test_None_dtype(self):
from fastNLP import AutoPadder
padder = AutoPadder()
content = [
[[1, 2, 3], [4, 5], [7, 8, 9, 10]],
[[1]]
]
ans = padder(content, None, None).tolist()
self.assertListEqual(content, ans)

Loading…
Cancel
Save