@@ -0,0 +1,16 @@ | |||||
.gitignore | |||||
.DS_Store | |||||
.ipynb_checkpoints | |||||
*.pyc | |||||
__pycache__ | |||||
*.swp | |||||
.vscode/ | |||||
.idea/** | |||||
caches | |||||
# fitlog | |||||
.fitlog | |||||
logs/ | |||||
.fitconfig |
@@ -8,7 +8,7 @@ install: | |||||
- pip install pytest-cov | - pip install pytest-cov | ||||
# command to run tests | # command to run tests | ||||
script: | script: | ||||
- pytest --cov=./ | |||||
- pytest --cov=./ test/ | |||||
after_success: | after_success: | ||||
- bash <(curl -s https://codecov.io/bash) | - bash <(curl -s https://codecov.io/bash) |
@@ -92,7 +92,7 @@ http://docutils.sf.net/ 孤立的网址会自动生成链接 | |||||
各种连接 | 各种连接 | ||||
=========== | =========== | ||||
:doc:`/user/with_fitlog.rst` | |||||
:doc:`/user/with_fitlog` | |||||
:mod:`~fastNLP.core.batch` | :mod:`~fastNLP.core.batch` | ||||
@@ -12,6 +12,7 @@ from queue import Empty, Full | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import torch.multiprocessing as mp | import torch.multiprocessing as mp | ||||
from numbers import Number | |||||
from .sampler import RandomSampler | from .sampler import RandomSampler | ||||
@@ -78,8 +79,10 @@ class Batch(object): | |||||
for field_name, field in self.dataset.get_all_fields().items(): | for field_name, field in self.dataset.get_all_fields().items(): | ||||
if field.is_target or field.is_input: | if field.is_target or field.is_input: | ||||
batch = field.get(indices) | batch = field.get(indices) | ||||
if not self.as_numpy and field.padder is not None: | |||||
batch = _to_tensor(batch, field.dtype) | |||||
if not self.as_numpy and \ | |||||
field.dtype is not None and \ | |||||
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor): | |||||
batch = _to_tensor(batch) | |||||
if field.is_target: | if field.is_target: | ||||
batch_y[field_name] = batch | batch_y[field_name] = batch | ||||
if field.is_input: | if field.is_input: | ||||
@@ -174,12 +177,12 @@ class Batch(object): | |||||
# print('iter done') | # print('iter done') | ||||
def _to_tensor(batch, dtype): | |||||
def _to_tensor(batch): | |||||
try: | try: | ||||
if dtype in (int, np.int8, np.int16, np.int32, np.int64): | |||||
batch = torch.LongTensor(batch) | |||||
if dtype in (float, np.float32, np.float64): | |||||
batch = torch.FloatTensor(batch) | |||||
if issubclass(batch.dtype.type, np.floating): | |||||
batch = torch.as_tensor(batch).float() # 默认使用float32 | |||||
else: | |||||
batch = torch.as_tensor(batch) # 复用内存地址,避免复制 | |||||
except: | except: | ||||
pass | pass | ||||
return batch | return batch |
@@ -285,7 +285,8 @@ from .field import AutoPadder | |||||
from .field import FieldArray | from .field import FieldArray | ||||
from .instance import Instance | from .instance import Instance | ||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .field import AppendToTargetOrInputException | |||||
from .field import SetInputOrTargetException | |||||
class DataSet(object): | class DataSet(object): | ||||
""" | """ | ||||
@@ -422,7 +423,7 @@ class DataSet(object): | |||||
if len(self.field_arrays) == 0: | if len(self.field_arrays) == 0: | ||||
# DataSet has no field yet | # DataSet has no field yet | ||||
for name, field in instance.fields.items(): | for name, field in instance.fields.items(): | ||||
field = field.tolist() if isinstance(field, np.ndarray) else field | |||||
# field = field.tolist() if isinstance(field, np.ndarray) else field | |||||
self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 | self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 | ||||
else: | else: | ||||
if len(self.field_arrays) != len(instance.fields): | if len(self.field_arrays) != len(instance.fields): | ||||
@@ -431,7 +432,11 @@ class DataSet(object): | |||||
.format(len(self.field_arrays), len(instance.fields))) | .format(len(self.field_arrays), len(instance.fields))) | ||||
for name, field in instance.fields.items(): | for name, field in instance.fields.items(): | ||||
assert name in self.field_arrays | assert name in self.field_arrays | ||||
self.field_arrays[name].append(field) | |||||
try: | |||||
self.field_arrays[name].append(field) | |||||
except AppendToTargetOrInputException as e: | |||||
print(f"Cannot append to field:{name}.") | |||||
raise e | |||||
def add_fieldarray(self, field_name, fieldarray): | def add_fieldarray(self, field_name, fieldarray): | ||||
""" | """ | ||||
@@ -565,7 +570,11 @@ class DataSet(object): | |||||
assert isinstance(flag, bool), "Only bool type supported." | assert isinstance(flag, bool), "Only bool type supported." | ||||
for name in field_names: | for name in field_names: | ||||
if name in self.field_arrays: | if name in self.field_arrays: | ||||
self.field_arrays[name].is_target = flag | |||||
try: | |||||
self.field_arrays[name].is_target = flag | |||||
except SetInputOrTargetException as e: | |||||
print(f"Cannot set field:{name} as target.") | |||||
raise e | |||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
@@ -581,7 +590,11 @@ class DataSet(object): | |||||
""" | """ | ||||
for name in field_names: | for name in field_names: | ||||
if name in self.field_arrays: | if name in self.field_arrays: | ||||
self.field_arrays[name].is_input = flag | |||||
try: | |||||
self.field_arrays[name].is_input = flag | |||||
except SetInputOrTargetException as e: | |||||
print(f"Cannot set field:{name} as input.") | |||||
raise e | |||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
@@ -1,251 +1,162 @@ | |||||
""" | |||||
field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fastNLP.DataSet` 中一列的存储方式, | |||||
原理部分请参考 :doc:`fastNLP.core.dataset` | |||||
""" | |||||
__all__ = [ | |||||
"FieldArray", | |||||
"Padder", | |||||
"AutoPadder", | |||||
"EngChar2DPadder" | |||||
] | |||||
from copy import deepcopy | |||||
from numbers import Number | |||||
import torch | |||||
import numpy as np | import numpy as np | ||||
from typing import Any | |||||
from abc import abstractmethod | |||||
from copy import deepcopy | |||||
class FieldArray(object): | |||||
""" | |||||
别名::class:`fastNLP.FieldArray` :class:`fastNLP.core.field.FieldArray` | |||||
FieldArray 是用于保存 :class:`~fastNLP.DataSet` 中一个field的类型。 | |||||
:param str name: FieldArray的名称 | |||||
:param list,numpy.ndarray content: 列表的元素可以为list,int,float, | |||||
:param bool is_target: 这个field是否是一个target field。 | |||||
:param bool is_input: 这个field是否是一个input field。 | |||||
:param padder: :class:`~fastNLP.Padder` 类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 | |||||
fieldarray.set_pad_val()。默认为None,即使用 :class:`~fastNLP.AutoPadder` 。 | |||||
:param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, | |||||
就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 | |||||
""" | |||||
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): | |||||
class _FieldCheckException(Exception):
    """
    Common implementation for the field validation errors below.

    :param msg: human readable description of the problem.
    :param index: position of the sample that triggered the problem, if known.
    :param field_name: name of the field being processed, if known.
    """

    def __init__(self, msg, index=None, field_name=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # which sample ran into the problem
        self.field_name = field_name  # name of the field being processed


class SetInputOrTargetException(_FieldCheckException):
    """Raised when the content of a field fails the type/dimension check."""


class AppendToTargetOrInputException(_FieldCheckException):
    """Raised when an appended value does not match the field's type/dimension."""
class FieldArray: | |||||
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): | |||||
if len(content)==0: | |||||
raise RuntimeError("Empty fieldarray is not allowed.") | |||||
_content = content | |||||
try: | |||||
_content = list(_content) | |||||
except BaseException as e: | |||||
print(f"Cannot convert content(of type:{type(content)}) into list.") | |||||
raise e | |||||
self.name = name | self.name = name | ||||
if isinstance(content, list): | |||||
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list | |||||
# 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] | |||||
for idx, item in enumerate(content): | |||||
# 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) | |||||
# 将[np.array] 转化为 list of list | |||||
# 也可以支持[array, array, array]的情况 | |||||
if isinstance(item, np.ndarray): | |||||
content[idx] = content[idx].tolist() | |||||
elif isinstance(content, np.ndarray): | |||||
content = content.tolist() # convert np.ndarray into 2-D list | |||||
else: | |||||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | |||||
if len(content) == 0: | |||||
raise RuntimeError("Cannot initialize FieldArray with empty list.") | |||||
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 | |||||
self.content_dim = None # 表示content是多少维的list | |||||
self.content = _content | |||||
self._ignore_type = ignore_type | |||||
# 根据input的情况设置input,target等 | |||||
self._cell_ndim = None # 多少维度 | |||||
self.dtype = None # 最内层的element都是什么类型的 | |||||
self._is_input = False | |||||
self._is_target = False | |||||
if is_input: | |||||
self.is_input = is_input | |||||
if is_target: | |||||
self.is_target = is_target | |||||
if padder is None: | if padder is None: | ||||
padder = AutoPadder(pad_val=0) | padder = AutoPadder(pad_val=0) | ||||
else: | else: | ||||
assert isinstance(padder, Padder), "padder must be of type Padder." | |||||
assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder." | |||||
padder = deepcopy(padder) | padder = deepcopy(padder) | ||||
self.set_padder(padder) | self.set_padder(padder) | ||||
self.ignore_type = ignore_type | |||||
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array | |||||
self.pytype = None | |||||
self.dtype = None | |||||
self._is_input = None | |||||
self._is_target = None | |||||
if is_input is not None or is_target is not None: | |||||
self.is_input = is_input | |||||
self.is_target = is_target | |||||
def _set_dtype(self): | |||||
if self.ignore_type is False: | |||||
self.pytype = self._type_detection(self.content) | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
@property | |||||
def ignore_type(self): | |||||
return self._ignore_type | |||||
@ignore_type.setter | |||||
def ignore_type(self, value): | |||||
if value: | |||||
self._cell_ndim = None | |||||
self.dtype = None | |||||
@property | @property | ||||
def is_input(self): | def is_input(self): | ||||
return self._is_input | return self._is_input | ||||
@is_input.setter | @is_input.setter | ||||
def is_input(self, value): | def is_input(self, value): | ||||
""" | """ | ||||
当 field_array.is_input = True / False 时被调用 | 当 field_array.is_input = True / False 时被调用 | ||||
""" | """ | ||||
if value is True: | |||||
self._set_dtype() | |||||
# 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False) | |||||
if value is True and \ | |||||
self._is_target is False and \ | |||||
self._ignore_type is False: | |||||
self._check_dtype_and_ndim() | |||||
if value is False and self._is_target is False: | |||||
self.dtype = None | |||||
self._cell_ndim = None | |||||
self._is_input = value | self._is_input = value | ||||
@property | @property | ||||
def is_target(self): | def is_target(self): | ||||
return self._is_target | return self._is_target | ||||
@is_target.setter | @is_target.setter | ||||
def is_target(self, value): | def is_target(self, value): | ||||
""" | """ | ||||
当 field_array.is_target = True / False 时被调用 | 当 field_array.is_target = True / False 时被调用 | ||||
""" | """ | ||||
if value is True: | |||||
self._set_dtype() | |||||
if value is True and \ | |||||
self._is_input is False and \ | |||||
self._ignore_type is False: | |||||
self._check_dtype_and_ndim() | |||||
if value is False and self._is_input is False: | |||||
self.dtype = None | |||||
self._cell_ndim = None | |||||
self._is_target = value | self._is_target = value | ||||
def _type_detection(self, content): | |||||
""" | |||||
当该field被设置为is_input或者is_target时被调用 | |||||
def _check_dtype_and_ndim(self): | |||||
""" | """ | ||||
if len(content) == 0: | |||||
raise RuntimeError("Empty list in Field {}.".format(self.name)) | |||||
type_set = set([type(item) for item in content]) | |||||
if list in type_set: | |||||
if len(type_set) > 1: | |||||
# list 跟 非list 混在一起 | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
# >1维list | |||||
inner_type_set = set() | |||||
for l in content: | |||||
[inner_type_set.add(type(obj)) for obj in l] | |||||
if list not in inner_type_set: | |||||
# 二维list | |||||
self.content_dim = 2 | |||||
return self._basic_type_detection(inner_type_set) | |||||
else: | |||||
if len(inner_type_set) == 1: | |||||
# >2维list | |||||
inner_inner_type_set = set() | |||||
for _2d_list in content: | |||||
for _1d_list in _2d_list: | |||||
[inner_inner_type_set.add(type(obj)) for obj in _1d_list] | |||||
if list in inner_inner_type_set: | |||||
raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") | |||||
# 3维list | |||||
self.content_dim = 3 | |||||
return self._basic_type_detection(inner_inner_type_set) | |||||
else: | |||||
# list 跟 非list 混在一起 | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) | |||||
else: | |||||
# 一维list | |||||
for content_type in type_set: | |||||
if content_type not in self.BASIC_TYPES: | |||||
raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( | |||||
self.name, self.BASIC_TYPES, content_type)) | |||||
self.content_dim = 1 | |||||
return self._basic_type_detection(type_set) | |||||
def _basic_type_detection(self, type_set): | |||||
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 | |||||
通过将直接报错. | |||||
:return: | |||||
""" | """ | ||||
:param type_set: a set of Python types | |||||
:return: one of self.BASIC_TYPES | |||||
cell_0 = self.content[0] | |||||
index = 0 | |||||
try: | |||||
type_0, dim_0 = _get_ele_type_and_dim(cell_0) | |||||
for cell in self.content[1:]: | |||||
index += 1 | |||||
type_i, dim_i = _get_ele_type_and_dim(cell) | |||||
if type_i!=type_0: | |||||
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." | |||||
".".format(type_i, index, type_0)) | |||||
if dim_0!=dim_i: | |||||
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " | |||||
"dimension:{}.".format(dim_i, index, dim_0)) | |||||
self._cell_ndim = dim_0 | |||||
self.dtype = type_0 | |||||
except SetInputOrTargetException as e: | |||||
e.index = index | |||||
raise e | |||||
def append(self, val:Any): | |||||
""" | |||||
:param val: 把该val append到fieldarray。 | |||||
:return: | |||||
""" | """ | ||||
if len(type_set) == 1: | |||||
return type_set.pop() | |||||
elif len(type_set) == 2: | |||||
# 有多个basic type; 可能需要up-cast | |||||
if float in type_set and int in type_set: | |||||
# up-cast int to float | |||||
return float | |||||
else: | |||||
# str 跟 int 或者 float 混在一起 | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
if (self._is_target or self._is_input) and self._ignore_type is False: | |||||
type_, dim_ = _get_ele_type_and_dim(val) | |||||
if self.dtype!=type_: | |||||
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " | |||||
f"previous values(type:{self.dtype}).") | |||||
if self._cell_ndim!=dim_: | |||||
raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with " | |||||
f"previous values(dim:{self._cell_ndim}).") | |||||
self.content.append(val) | |||||
else: | else: | ||||
# str, int, float混在一起 | |||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
def _1d_list_check(self, val): | |||||
"""如果不是1D list就报错 | |||||
""" | |||||
type_set = set((type(obj) for obj in val)) | |||||
if any(obj not in self.BASIC_TYPES for obj in type_set): | |||||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||||
self._basic_type_detection(type_set) | |||||
# otherwise: _basic_type_detection will raise error | |||||
return True | |||||
def _2d_list_check(self, val): | |||||
"""如果不是2D list 就报错 | |||||
""" | |||||
type_set = set(type(obj) for obj in val) | |||||
if list(type_set) != [list]: | |||||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||||
inner_type_set = set() | |||||
for l in val: | |||||
for obj in l: | |||||
inner_type_set.add(type(obj)) | |||||
self._basic_type_detection(inner_type_set) | |||||
return True | |||||
@staticmethod | |||||
def _map_to_np_type(basic_type): | |||||
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} | |||||
return type_mapping[basic_type] | |||||
def __repr__(self): | |||||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||||
def append(self, val): | |||||
"""将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 | |||||
的内容是匹配的。 | |||||
:param Any val: 需要append的值。 | |||||
""" | |||||
if self.ignore_type is False: | |||||
if isinstance(val, list): | |||||
pass | |||||
elif isinstance(val, tuple): # 确保最外层是list | |||||
val = list(val) | |||||
elif isinstance(val, np.ndarray): | |||||
val = val.tolist() | |||||
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): | |||||
pass | |||||
else: | |||||
raise RuntimeError( | |||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||||
if self.is_input is True or self.is_target is True: | |||||
if type(val) == list: | |||||
if len(val) == 0: | |||||
raise ValueError("Cannot append an empty list.") | |||||
if self.content_dim == 2 and self._1d_list_check(val): | |||||
# 1维list检查 | |||||
pass | |||||
elif self.content_dim == 3 and self._2d_list_check(val): | |||||
# 2维list检查 | |||||
pass | |||||
else: | |||||
raise RuntimeError( | |||||
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) | |||||
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: | |||||
# scalar检查 | |||||
if type(val) == float and self.pytype == int: | |||||
self.pytype = float | |||||
self.dtype = self._map_to_np_type(self.pytype) | |||||
else: | |||||
raise RuntimeError( | |||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||||
self.content.append(val) | |||||
self.content.append(val) | |||||
def __getitem__(self, indices): | def __getitem__(self, indices): | ||||
return self.get(indices, pad=False) | return self.get(indices, pad=False) | ||||
def __setitem__(self, idx, val): | def __setitem__(self, idx, val): | ||||
assert isinstance(idx, int) | assert isinstance(idx, int) | ||||
if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型 | |||||
type_, dim_ = _get_ele_type_and_dim(val) | |||||
if self.dtype!=type_: | |||||
raise RuntimeError(f"Value(type:{type_}) are of different types with " | |||||
f"other values(type:{self.dtype}).") | |||||
if self._cell_ndim!=dim_: | |||||
raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with " | |||||
f"previous values(dim:{self._cell_ndim}).") | |||||
self.content[idx] = val | self.content[idx] = val | ||||
def get(self, indices, pad=True): | def get(self, indices, pad=True): | ||||
""" | """ | ||||
根据给定的indices返回内容 | 根据给定的indices返回内容 | ||||
@@ -257,14 +168,14 @@ class FieldArray(object): | |||||
if isinstance(indices, int): | if isinstance(indices, int): | ||||
return self.content[indices] | return self.content[indices] | ||||
if self.is_input is False and self.is_target is False: | if self.is_input is False and self.is_target is False: | ||||
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | |||||
raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) | |||||
contents = [self.content[i] for i in indices] | contents = [self.content[i] for i in indices] | ||||
if self.padder is None or pad is False: | if self.padder is None or pad is False: | ||||
return np.array(contents) | return np.array(contents) | ||||
else: | else: | ||||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) | |||||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | |||||
def set_padder(self, padder): | def set_padder(self, padder): | ||||
""" | """ | ||||
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | ||||
@@ -276,7 +187,7 @@ class FieldArray(object): | |||||
self.padder = deepcopy(padder) | self.padder = deepcopy(padder) | ||||
else: | else: | ||||
self.padder = None | self.padder = None | ||||
def set_pad_val(self, pad_val): | def set_pad_val(self, pad_val): | ||||
""" | """ | ||||
修改padder的pad_val. | 修改padder的pad_val. | ||||
@@ -286,7 +197,7 @@ class FieldArray(object): | |||||
if self.padder is not None: | if self.padder is not None: | ||||
self.padder.set_pad_val(pad_val) | self.padder.set_pad_val(pad_val) | ||||
return self | return self | ||||
def __len__(self): | def __len__(self): | ||||
""" | """ | ||||
Returns the size of FieldArray. | Returns the size of FieldArray. | ||||
@@ -294,7 +205,7 @@ class FieldArray(object): | |||||
:return int length: | :return int length: | ||||
""" | """ | ||||
return len(self.content) | return len(self.content) | ||||
def to(self, other): | def to(self, other): | ||||
""" | """ | ||||
将other的属性复制给本FieldArray(other必须为FieldArray类型). | 将other的属性复制给本FieldArray(other必须为FieldArray类型). | ||||
@@ -303,22 +214,63 @@ class FieldArray(object): | |||||
:param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性 | :param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性 | ||||
:return: :class:`~fastNLP.FieldArray` | :return: :class:`~fastNLP.FieldArray` | ||||
""" | """ | ||||
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) | |||||
assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other)) | |||||
self.ignore_type = other.ignore_type | |||||
self.is_input = other.is_input | self.is_input = other.is_input | ||||
self.is_target = other.is_target | self.is_target = other.is_target | ||||
self.padder = other.padder | self.padder = other.padder | ||||
self.ignore_type = other.ignore_type | |||||
return self | return self | ||||
def _is_iterable(content): | |||||
def _get_ele_type_and_dim(cell:Any, dim=0): | |||||
""" | |||||
识别cell的类别与dimension的数量 | |||||
numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html | |||||
:param cell: | |||||
:param dim: | |||||
:return: | |||||
""" | |||||
if isinstance(cell, (str, Number, np.bool_)): | |||||
return type(cell), dim | |||||
elif isinstance(cell, list): | |||||
dim += 1 | |||||
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||||
types = set([i for i,j in res]) | |||||
dims = set([j for i,j in res]) | |||||
if len(types)>1: | |||||
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||||
if len(dims)>1: | |||||
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||||
return types.pop(), dims.pop() | |||||
elif isinstance(cell, torch.Tensor): | |||||
return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0 | |||||
elif isinstance(cell, np.ndarray): | |||||
if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了 | |||||
return cell.dtype.type, cell.ndim + dim | |||||
# 否则需要继续往下iterate | |||||
dim += 1 | |||||
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||||
types = set([i for i,j in res]) | |||||
dims = set([j for i,j in res]) | |||||
if len(types)>1: | |||||
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||||
if len(dims)>1: | |||||
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||||
return types.pop(), dims.pop() | |||||
else: # 包含tuple, set, dict以及其它的类型 | |||||
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") | |||||
def _is_iterable(value): | |||||
# 检查是否是iterable的, duck typing | |||||
try: | try: | ||||
_ = (e for e in content) | |||||
except TypeError: | |||||
iter(value) | |||||
return True | |||||
except BaseException as e: | |||||
return False | return False | ||||
return True | |||||
class Padder: | class Padder: | ||||
@@ -327,32 +279,35 @@ class Padder: | |||||
所有padder都需要继承这个类,并覆盖__call__方法。 | 所有padder都需要继承这个类,并覆盖__call__方法。 | ||||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | 用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | ||||
.. py:function:: __call__(self, contents, field_name, field_ele_dtype): | .. py:function:: __call__(self, contents, field_name, field_ele_dtype): | ||||
传入的是List内容。假设有以下的DataSet。 | 传入的是List内容。假设有以下的DataSet。 | ||||
:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | ||||
deepcopy一份。 | deepcopy一份。 | ||||
:param str, field_name: field的名称。 | :param str, field_name: field的名称。 | ||||
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 | :param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 | ||||
:return: np.array([padded_element]) | :return: np.array([padded_element]) | ||||
""" | """ | ||||
def __init__(self, pad_val=0, **kwargs): | def __init__(self, pad_val=0, **kwargs): | ||||
self.pad_val = pad_val | self.pad_val = pad_val | ||||
def set_pad_val(self, pad_val): | def set_pad_val(self, pad_val): | ||||
self.pad_val = pad_val | self.pad_val = pad_val | ||||
def __call__(self, contents, field_name, field_ele_dtype): | |||||
@abstractmethod | |||||
def __call__(self, contents, field_name, field_ele_dtype, dim:int): | |||||
""" | """ | ||||
传入的是List内容。假设有以下的DataSet。 | 传入的是List内容。假设有以下的DataSet。 | ||||
:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | ||||
deepcopy一份。 | deepcopy一份。 | ||||
:param str, field_name: field的名称。 | :param str, field_name: field的名称。 | ||||
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 | |||||
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True, | |||||
该这个值为None。 | |||||
:param dim: 这个field的维度。当ignore_type为True时,该值为None | |||||
:return: np.array([padded_element]) | :return: np.array([padded_element]) | ||||
Example:: | Example:: | ||||
@@ -394,50 +349,87 @@ class AutoPadder(Padder): | |||||
根据contents的数据自动判定是否需要做padding。 | 根据contents的数据自动判定是否需要做padding。 | ||||
1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | 1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | ||||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行pad | |||||
型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad | |||||
2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等 | |||||
2 如果元素类型为(np.int64, np.float64), | |||||
2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding | |||||
2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding | |||||
2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。 | |||||
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||||
即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad | |||||
2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用 | |||||
:class: fastNLP.EngChar2DPadder. | |||||
2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片 | |||||
的情况。 | |||||
3 其它情况不进行处理,返回一个np.array类型。 | |||||
""" | """ | ||||
def __init__(self, pad_val=0): | def __init__(self, pad_val=0): | ||||
""" | |||||
:param pad_val: int, padding的位置使用该index | |||||
""" | |||||
super().__init__(pad_val=pad_val) | super().__init__(pad_val=pad_val) | ||||
def _is_two_dimension(self, contents): | |||||
""" | |||||
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 | |||||
:param contents: | |||||
:return: | |||||
""" | |||||
value = contents[0] | |||||
if isinstance(value, (np.ndarray, list)): | |||||
value = value[0] | |||||
if isinstance(value, (np.ndarray, list)): | |||||
return False | |||||
return True | |||||
return False | |||||
def __call__(self, contents, field_name, field_ele_dtype): | |||||
if not _is_iterable(contents[0]): | |||||
array = np.array([content for content in contents], dtype=field_ele_dtype) | |||||
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | |||||
max_len = max([len(content) for content in contents]) | |||||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | |||||
for i, content in enumerate(contents): | |||||
array[i][:len(content)] = content | |||||
elif field_ele_dtype is None: | |||||
array = np.array(contents) # 当ignore_type=True时,直接返回contents | |||||
else: # should only be str | |||||
array = np.array([content for content in contents]) | |||||
return array | |||||
def __call__(self, contents, field_name, field_ele_dtype, dim): | |||||
if field_ele_dtype: | |||||
if dim>3: | |||||
return np.array(contents) | |||||
if isinstance(field_ele_dtype, np.dtype) or field_ele_dtype in (float, int, bool, str): | |||||
if isinstance(field_ele_dtype, np.number) or field_ele_dtype in (float, int, bool): | |||||
if dim==0: | |||||
array = np.array(contents, dtype=field_ele_dtype) | |||||
elif dim==1: | |||||
max_len = max(map(len, contents)) | |||||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | |||||
for i, content_i in enumerate(contents): | |||||
array[i, :len(content_i)] = content_i | |||||
elif dim==2: | |||||
max_len = max(map(len, contents)) | |||||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||||
content_i in contents]) | |||||
array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) | |||||
for i, content_i in enumerate(contents): | |||||
for j, content_ii in enumerate(content_i): | |||||
array[i, j, :len(content_ii)] = content_ii | |||||
else: | |||||
shape = np.shape(contents) | |||||
if len(shape)==4: # 说明各dimension是相同的大小 | |||||
array = np.array(contents, dtype=field_ele_dtype) | |||||
else: | |||||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||||
return array | |||||
return np.array(contents) | |||||
elif str(field_ele_dtype).startswith('torch'): | |||||
if dim==0: | |||||
tensor = torch.tensor(contents).to(field_ele_dtype) | |||||
elif dim==1: | |||||
max_len = max(map(len, contents)) | |||||
tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) | |||||
for i, content_i in enumerate(contents): | |||||
tensor[i, :len(content_i)] = torch.tensor(content_i) | |||||
elif dim==2: | |||||
max_len = max(map(len, contents)) | |||||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||||
content_i in contents]) | |||||
tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val, | |||||
dtype=field_ele_dtype) | |||||
for i, content_i in enumerate(contents): | |||||
for j, content_ii in enumerate(content_i): | |||||
tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) | |||||
else: | |||||
shapes = set([np.shape(content_i) for content_i in contents]) | |||||
if len(shapes)>1: | |||||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||||
shape = shapes.pop() | |||||
if len(shape)==3: | |||||
tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype) | |||||
for i, content_i in enumerate(contents): | |||||
tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) | |||||
else: | |||||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||||
return tensor | |||||
else: | |||||
return np.array(contents) # 不进行任何操作 | |||||
else: | |||||
return np.array(contents) | |||||
class EngChar2DPadder(Padder): | class EngChar2DPadder(Padder): | ||||
@@ -463,7 +455,7 @@ class EngChar2DPadder(Padder): | |||||
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder | dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder | ||||
""" | """ | ||||
def __init__(self, pad_val=0, pad_length=0): | def __init__(self, pad_val=0, pad_length=0): | ||||
""" | """ | ||||
:param pad_val: int, pad的位置使用该index | :param pad_val: int, pad的位置使用该index | ||||
@@ -471,32 +463,10 @@ class EngChar2DPadder(Padder): | |||||
都pad或截取到该长度. | 都pad或截取到该长度. | ||||
""" | """ | ||||
super().__init__(pad_val=pad_val) | super().__init__(pad_val=pad_val) | ||||
self.pad_length = pad_length | self.pad_length = pad_length | ||||
def _exactly_three_dims(self, contents, field_name): | |||||
""" | |||||
检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character | |||||
:param contents: | |||||
:param field_name: str | |||||
:return: | |||||
""" | |||||
if not isinstance(contents, list): | |||||
raise TypeError("contents should be a list, not {}.".format(type(contents))) | |||||
value = contents[0] | |||||
try: | |||||
value = value[0] | |||||
except: | |||||
raise ValueError("Field:{} only has one dimension.".format(field_name)) | |||||
try: | |||||
value = value[0] | |||||
except: | |||||
raise ValueError("Field:{} only has two dimensions.".format(field_name)) | |||||
if _is_iterable(value): | |||||
raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) | |||||
def __call__(self, contents, field_name, field_ele_dtype): | |||||
def __call__(self, contents, field_name, field_ele_dtype, dim): | |||||
""" | """ | ||||
期望输入类似于 | 期望输入类似于 | ||||
[ | [ | ||||
@@ -510,11 +480,11 @@ class EngChar2DPadder(Padder): | |||||
:param field_ele_dtype | :param field_ele_dtype | ||||
:return: | :return: | ||||
""" | """ | ||||
if field_ele_dtype not in (np.int64, np.float64): | |||||
if field_ele_dtype not in (np.int64, np.float64, int, float): | |||||
raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( | raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( | ||||
field_name, field_ele_dtype | field_name, field_ele_dtype | ||||
)) | )) | ||||
self._exactly_three_dims(contents, field_name) | |||||
assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions." | |||||
if self.pad_length < 1: | if self.pad_length < 1: | ||||
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) | max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) | ||||
else: | else: | ||||
@@ -522,12 +492,12 @@ class EngChar2DPadder(Padder): | |||||
max_sent_length = max(len(word_lst) for word_lst in contents) | max_sent_length = max(len(word_lst) for word_lst in contents) | ||||
batch_size = len(contents) | batch_size = len(contents) | ||||
dtype = type(contents[0][0][0]) | dtype = type(contents[0][0][0]) | ||||
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, | padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, | ||||
dtype=dtype) | dtype=dtype) | ||||
for b_idx, word_lst in enumerate(contents): | for b_idx, word_lst in enumerate(contents): | ||||
for c_idx, char_lst in enumerate(word_lst): | for c_idx, char_lst in enumerate(word_lst): | ||||
chars = char_lst[:max_char_length] | chars = char_lst[:max_char_length] | ||||
padded_array[b_idx, c_idx, :len(chars)] = chars | padded_array[b_idx, c_idx, :len(chars)] = chars | ||||
return padded_array | return padded_array |
@@ -440,7 +440,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
class SpanFPreRecMetric(MetricBase): | class SpanFPreRecMetric(MetricBase): | ||||
""" | |||||
r""" | |||||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | 别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | ||||
在序列标注问题中,以span的方式计算F, pre, rec. | 在序列标注问题中,以span的方式计算F, pre, rec. | ||||
@@ -478,7 +478,7 @@ class SpanFPreRecMetric(MetricBase): | |||||
label的f1, pre, rec | label的f1, pre, rec | ||||
:param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': | :param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': | ||||
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | ||||
:param float beta: f_beta分数,:math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`. | |||||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | 常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | ||||
""" | """ | ||||
@@ -701,16 +701,16 @@ def _pred_topk(y_prob, k=1): | |||||
class SQuADMetric(MetricBase): | class SQuADMetric(MetricBase): | ||||
""" | |||||
r""" | |||||
别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | 别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | ||||
SQuAD数据集metric | SQuAD数据集metric | ||||
:param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` | |||||
:param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` | |||||
:param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` | |||||
:param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` | |||||
:param float beta: f_beta分数,:math:`f_beta = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`. | |||||
:param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | |||||
:param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | |||||
:param target1: 参数映射表中 `target1` 的映射关系,None表示映射关系为 `target1` -> `target1` | |||||
:param target2: 参数映射表中 `target2` 的映射关系,None表示映射关系为 `target2` -> `target2` | |||||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | 常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | ||||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | :param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | ||||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | ||||
@@ -532,7 +532,7 @@ class Trainer(object): | |||||
self._train() | self._train() | ||||
self.callback_manager.on_train_end() | self.callback_manager.on_train_end() | ||||
except Exception as e: | |||||
except BaseException as e: | |||||
self.callback_manager.on_exception(e) | self.callback_manager.on_exception(e) | ||||
if on_exception == 'auto': | if on_exception == 'auto': | ||||
if not isinstance(e, (CallbackException, KeyboardInterrupt)): | if not isinstance(e, (CallbackException, KeyboardInterrupt)): | ||||
@@ -28,6 +28,8 @@ from ..core.instance import Instance | |||||
from .file_reader import _read_csv, _read_json, _read_conll | from .file_reader import _read_csv, _read_json, _read_conll | ||||
from .base_loader import DataSetLoader | from .base_loader import DataSetLoader | ||||
from .data_loader.sst import SSTLoader | from .data_loader.sst import SSTLoader | ||||
from ..core.const import Const | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
@@ -257,9 +259,9 @@ class SNLILoader(JsonLoader): | |||||
def __init__(self): | def __init__(self): | ||||
fields = { | fields = { | ||||
'sentence1_parse': 'words1', | |||||
'sentence2_parse': 'words2', | |||||
'gold_label': 'target', | |||||
'sentence1_parse': Const.INPUTS(0), | |||||
'sentence2_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | } | ||||
super(SNLILoader, self).__init__(fields=fields) | super(SNLILoader, self).__init__(fields=fields) | ||||
@@ -271,10 +273,10 @@ class SNLILoader(JsonLoader): | |||||
return t.leaves() | return t.leaves() | ||||
ds.apply(lambda ins: parse_tree( | ds.apply(lambda ins: parse_tree( | ||||
ins['words1']), new_field_name='words1') | |||||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: parse_tree( | ds.apply(lambda ins: parse_tree( | ||||
ins['words2']), new_field_name='words2') | |||||
ds.drop(lambda x: x['target'] == '-') | |||||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds | return ds | ||||
@@ -107,9 +107,9 @@ class EmbedLoader(BaseLoader): | |||||
:param bool normalize: 是否将每个vector归一化到norm为1 | :param bool normalize: 是否将每个vector归一化到norm为1 | ||||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 | :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 | ||||
方在于词表有空行或者词表出现了维度不一致。 | 方在于词表有空行或者词表出现了维度不一致。 | ||||
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||||
:return numpy.ndarray: Vocabulary Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 | |||||
:return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 | |||||
是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 | 是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 | ||||
""" | """ | ||||
vocab = Vocabulary(padding=padding, unknown=unknown) | vocab = Vocabulary(padding=padding, unknown=unknown) | ||||
vec_dict = {} | vec_dict = {} | ||||
@@ -2,43 +2,28 @@ | |||||
这里复现了在fastNLP中实现的模型,旨在达到与论文中相符的性能。 | 这里复现了在fastNLP中实现的模型,旨在达到与论文中相符的性能。 | ||||
复现的模型有: | 复现的模型有: | ||||
- Star-Transformer | |||||
- [Star-Transformer](Star_transformer/) | |||||
- ... | - ... | ||||
# 任务复现 | |||||
## Text Classification (文本分类) | |||||
- still in progress | |||||
## Matching (自然语言推理/句子匹配) | |||||
- still in progress | |||||
## Sequence Labeling (序列标注) | |||||
- still in progress | |||||
## Coreference resolution (指代消解) | |||||
- still in progress | |||||
## Summarization (摘要) | |||||
- still in progress | |||||
## Star-Transformer | |||||
[reference](https://arxiv.org/abs/1902.09113) | |||||
### Performance (still in progress) | |||||
|任务| 数据集 | SOTA | 模型表现 | | |||||
|------|------| ------| ------| | |||||
|Pos Tagging|CTB 9.0|-|ACC 92.31| | |||||
|Pos Tagging|CONLL 2012|-|ACC 96.51| | |||||
|Named Entity Recognition|CONLL 2012|-|F1 85.66| | |||||
|Text Classification|SST|-|49.18| | |||||
|Natural Language Inference|SNLI|-|83.76| | |||||
### Usage | |||||
``` python | |||||
# for sequence labeling(ner, pos tagging, etc) | |||||
from fastNLP.models.star_transformer import STSeqLabel | |||||
model = STSeqLabel( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
# for sequence classification | |||||
from fastNLP.models.star_transformer import STSeqCls | |||||
model = STSeqCls( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
# for natural language inference | |||||
from fastNLP.models.star_transformer import STNLICls | |||||
model = STNLICls( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
``` | |||||
## ... | ## ... |
@@ -0,0 +1,34 @@ | |||||
# Star-Transformer | |||||
paper: [Star-Transformer](https://arxiv.org/abs/1902.09113) | |||||
## Performance (still in progress) | |||||
|任务| 数据集 | SOTA | 模型表现 | | |||||
|------|------| ------| ------| | |||||
|Pos Tagging|CTB 9.0|-|ACC 92.31| | |||||
|Pos Tagging|CONLL 2012|-|ACC 96.51| | |||||
|Named Entity Recognition|CONLL 2012|-|F1 85.66| | |||||
|Text Classification|SST|-|49.18| | |||||
|Natural Language Inference|SNLI|-|83.76| | |||||
## Usage | |||||
``` python | |||||
# for sequence labeling(ner, pos tagging, etc) | |||||
from fastNLP.models.star_transformer import STSeqLabel | |||||
model = STSeqLabel( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
# for sequence classification | |||||
from fastNLP.models.star_transformer import STSeqCls | |||||
model = STSeqCls( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
# for natural language inference | |||||
from fastNLP.models.star_transformer import STNLICls | |||||
model = STNLICls( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
``` |
@@ -0,0 +1,6 @@ | |||||
from fastNLP.io.dataset_loader import SNLILoader | |||||
# TODO: still in progress | |||||
@@ -0,0 +1,41 @@ | |||||
import torch | |||||
import torch.nn as nn | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.models import BaseModel | |||||
from fastNLP.modules.encoder.bert import BertModel | |||||
class BertForNLI(BaseModel):
    """
    BERT encoder with a single linear classification layer on top of the pooled output.

    TODO: still in progress
    """

    def __init__(self, class_num=3, bert_dir=None):
        """
        :param int class_num: output size of the classification layer.
        :param bert_dir: when not None, passed to ``BertModel.from_pretrained``;
            otherwise ``BertModel()`` is built with its default arguments.
        """
        super(BertForNLI, self).__init__()
        self.bert = BertModel() if bert_dir is None else BertModel.from_pretrained(bert_dir)
        # The classifier's input width is taken from the size of the pooler's bias vector.
        pooler_bias = self.bert.pooler.dense._parameters['bias']
        self.classifier = nn.Linear(pooler_bias.size(-1), class_num)

    def forward(self, words, seq_len1, seq_len2, target=None):
        """
        Encode the inputs with BERT and classify the pooled output.

        :param torch.Tensor words: [batch_size, seq_len] input_ids
        :param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids
        :param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask
        :param torch.Tensor target: [batch]
        :return: dict holding the logits under ``Const.OUTPUT`` and, only when
            ``target`` is given, the cross-entropy loss under ``Const.LOSS``.
        """
        _, pooled_output = self.bert(words, seq_len1, seq_len2)
        logits = self.classifier(pooled_output)
        result = {Const.OUTPUT: logits}
        if target is not None:
            criterion = torch.nn.CrossEntropyLoss()
            result[Const.LOSS] = criterion(logits, target)
        return result

    def predict(self, words, seq_len1, seq_len2, target=None):
        """
        Run ``forward`` without a target, so only ``Const.OUTPUT`` is returned.

        ``target`` is accepted but not forwarded.
        """
        return self.forward(words, seq_len1, seq_len2)
@@ -0,0 +1,97 @@ | |||||
import os | |||||
import torch | |||||
from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric | |||||
from reproduction.matching.data.SNLIDataLoader import SNLILoader | |||||
from legacy.component.bert_tokenizer import BertTokenizer | |||||
from reproduction.matching.model.bert import BertForNLI | |||||
def preprocess_data(data: DataSet, bert_dir, max_len=512):
    """
    Preprocess a raw data set into the fields needed by the BERT model.

    The following fields are written (all via ``data.apply``):

    * ``Const.INPUTS(0)`` / ``Const.INPUTS(1)``: re-tokenized with the BERT tokenizer.
    * ``Const.INPUT``: ``[CLS] sent1 [SEP] sent2 [SEP]`` mapped to vocabulary indices.
    * ``Const.INPUT_LENS(0)``: 0 for ``[CLS] sent1 [SEP]``, 1 for ``sent2 [SEP]``.
    * ``Const.INPUT_LENS(1)``: a list of ones with the same length as ``Const.INPUT_LENS(0)``.
    * ``Const.TARGET``: label string mapped to an index.

    :param data: data set whose ``Const.INPUTS(0)``, ``Const.INPUTS(1)`` and
        ``Const.TARGET`` fields are already filled.
    :param bert_dir: directory containing the BERT ``vocab.txt``.
    :param int max_len: the three sequence fields are cut to this many leading
        positions. This is plain prefix truncation, so on over-long pairs the
        trailing ``[SEP]`` is dropped as well.
    :return: the same ``data`` object, with input and target fields set.
    """
    vocab_path = os.path.join(bert_dir, 'vocab.txt')
    tokenizer = BertTokenizer.from_pretrained(vocab_path)

    vocab = Vocabulary(padding=None, unknown=None)
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which can mis-read or reject non-ASCII word pieces.
    with open(vocab_path, encoding='utf-8') as f:
        vocab_list = [line.strip() for line in f]
    vocab.add_word_lst(vocab_list)
    vocab.build_vocab()
    # Set only after the vocabulary is built, as in the original flow.
    vocab.padding = '[PAD]'
    vocab.unknown = '[UNK]'

    for i in range(2):
        field_name = Const.INPUTS(i)
        data.apply(lambda x: tokenizer.tokenize(" ".join(x[field_name])),
                   new_field_name=field_name)
    data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
               new_field_name=Const.INPUT)
    data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
               new_field_name=Const.INPUT_LENS(0))
    data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))

    # Truncate first, then map the remaining tokens to indices.
    data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
    data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
    data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
    data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))

    target_vocab = Vocabulary(padding=None, unknown=None)
    target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
    target_vocab.build_vocab()
    data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)

    # The target is registered as an input too, so it is also passed to the model.
    data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
    data.set_target(Const.TARGET)

    return data
bert_dirs = 'path/to/bert/dir'

# Load the three raw SNLI splits.
train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')

print('successfully load data sets!')

# Convert every split into the fields the BERT model consumes.
train_data = preprocess_data(train_data, bert_dirs)
dev_data = preprocess_data(dev_data, bert_dirs)
test_data = preprocess_data(test_data, bert_dirs)

model = BertForNLI(bert_dir=bert_dirs)

# Batch size and device list both scale with the number of visible CUDA devices
# (12 samples per device).
trainer = Trainer(
    model=model,
    train_data=train_data,
    dev_data=dev_data,
    optimizer=Adam(lr=2e-5, model_params=model.parameters()),
    metrics=AccuracyMetric(),
    metric_key='acc',
    n_epochs=4,
    batch_size=torch.cuda.device_count() * 12,
    device=list(range(torch.cuda.device_count())),
    print_every=-1,
    check_code_level=-1
)
trainer.train(load_best_model=True)

# Evaluate the trained model on the held-out test split.
tester = Tester(
    model=model,
    data=test_data,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * 12,
    device=list(range(torch.cuda.device_count())),
)
tester.test()
@@ -0,0 +1,10 @@ | |||||
import unittest | |||||
from ..data import SNLIDataLoader | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
class TestCWSDataLoader(unittest.TestCase):
    """Placeholder test case for the SNLI data loader (class name kept unchanged)."""

    def test_case1(self):
        # ``SNLIDataLoader`` is bound by ``from ..data import SNLIDataLoader``, i.e. it is
        # the module, and calling a module raises TypeError. Instantiate the loader
        # class that the module exposes instead.
        # NOTE(review): assumes ``..data`` does not rebind ``SNLIDataLoader`` to a class — confirm.
        snli_loader = SNLIDataLoader.SNLILoader()
        # TODO: still in progress
@@ -1,7 +1,7 @@ | |||||
import unittest | import unittest | ||||
from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader | |||||
from ..data.CWSDataLoader import SigHanLoader | |||||
from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
@@ -0,0 +1 @@ | |||||
# TODO |
@@ -0,0 +1 @@ | |||||
# TODO |
@@ -0,0 +1 @@ | |||||
# TODO |
@@ -12,6 +12,7 @@ from fastNLP import AccuracyMetric | |||||
from fastNLP import SGD | from fastNLP import SGD | ||||
from fastNLP import Trainer | from fastNLP import Trainer | ||||
from fastNLP.models.base_model import NaiveClassifier | from fastNLP.models.base_model import NaiveClassifier | ||||
from fastNLP.core.callback import EarlyStopError | |||||
def prepare_env(): | def prepare_env(): | ||||
@@ -1,8 +1,55 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | import numpy as np | ||||
import torch | |||||
from fastNLP import FieldArray | from fastNLP import FieldArray | ||||
from fastNLP.core.field import _get_ele_type_and_dim | |||||
from fastNLP import AutoPadder | |||||
class TestFieldArrayTyepDimDetect(unittest.TestCase):
    """
    Check that the element type and the number of dimensions of a cell are detected correctly.
    """

    def test_case1(self):
        # 1.1 plain Python scalars: the type is the value's own type, the dim is 0
        for value in [1, True, 1.0, 'abc']:
            type_ = type(value)
            _type, _dim = _get_ele_type_and_dim(cell=value)
            self.assertListEqual([_type, _dim], [type_, 0])
        # 1.2 mixed element types must raise
        with self.assertRaises(Exception):
            value = [1, 2, 1.0]
            # Call the function directly; wrapping the call's result in a second
            # assertRaises was a misuse of the API.
            _get_ele_type_and_dim(value)
        # Tests with numpy input
        # 2.1
        value = np.array([1, 2, 3])
        type_ = value.dtype
        dim_ = 1
        self.assertSequenceEqual(_get_ele_type_and_dim(cell=value), [type_, dim_])
        # 2.2 ragged rows (the char-embedding scenario). dtype=object is explicit because
        # recent NumPy refuses to build a ragged array implicitly.
        value = np.array([[1, 2], [3, 4, 5]], dtype=object)
        self.assertSequenceEqual([int, 2], _get_ele_type_and_dim(value))
        # 2.3
        value = np.zeros((3, 4))
        self.assertSequenceEqual([value.dtype, 2], _get_ele_type_and_dim(value))
        # 2.4 inconsistent dimensions. The array is built outside the guard so that the
        # exception has to come from the function under test, not from np.array.
        value = np.array([[1, 2], [3, [1]]], dtype=object)
        with self.assertRaises(Exception):
            _get_ele_type_and_dim(value)
        # 2.5 mixed element types
        value = np.array([[1, 2], [3.0]], dtype=object)
        with self.assertRaises(Exception):
            _get_ele_type_and_dim(value)

        # Tests with torch tensors
        # 3.1 the word-embedding scenario
        value = torch.zeros(3, 10)
        self.assertSequenceEqual([value.dtype, 2], _get_ele_type_and_dim(value))
        # 3.2 the char-embedding / image scenario
        value = torch.zeros(3, 32, 32)
        self.assertSequenceEqual([value.dtype, 3], _get_ele_type_and_dim(value))
class TestFieldArrayInit(unittest.TestCase): | class TestFieldArrayInit(unittest.TestCase): | ||||
@@ -31,12 +78,6 @@ class TestFieldArrayInit(unittest.TestCase): | |||||
# 三维list | # 三维list | ||||
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) | fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) | ||||
def test_init_v7(self): | |||||
# list of array | |||||
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) | |||||
self.assertEqual(fa.pytype, int) | |||||
self.assertEqual(fa.dtype, np.int) | |||||
def test_init_v4(self): | def test_init_v4(self): | ||||
# 一维list | # 一维list | ||||
val = [1, 2, 3, 4] | val = [1, 2, 3, 4] | ||||
@@ -56,6 +97,11 @@ class TestFieldArrayInit(unittest.TestCase): | |||||
fa.append(val) | fa.append(val) | ||||
def test_init_v7(self): | def test_init_v7(self): | ||||
# list of array | |||||
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) | |||||
self.assertEqual(fa.dtype, np.array([1]).dtype) | |||||
def test_init_v8(self): | |||||
# 二维list | # 二维list | ||||
val = np.array([[1, 2], [3, 4]]) | val = np.array([[1, 2], [3, 4]]) | ||||
fa = FieldArray("x", [val], is_input=True) | fa = FieldArray("x", [val], is_input=True) | ||||
@@ -79,33 +125,23 @@ class TestFieldArray(unittest.TestCase): | |||||
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | ||||
def test_type_conversion(self): | def test_type_conversion(self): | ||||
fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | ||||
fa.append(1.3333) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
self.assertEqual(fa.dtype, int) | |||||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) | fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) | ||||
fa.append(10) | |||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
fa.append(10.0) | |||||
self.assertEqual(fa.dtype, float) | |||||
fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True) | fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True) | ||||
fa.append("e") | fa.append("e") | ||||
self.assertEqual(fa.dtype, np.str) | |||||
self.assertEqual(fa.pytype, str) | |||||
self.assertEqual(fa.dtype, str) | |||||
def test_support_np_array(self): | def test_support_np_array(self): | ||||
fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) | fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) | ||||
self.assertEqual(fa.dtype, np.float64) | self.assertEqual(fa.dtype, np.float64) | ||||
self.assertEqual(fa.pytype, float) | |||||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | ||||
self.assertEqual(fa.dtype, np.float64) | self.assertEqual(fa.dtype, np.float64) | ||||
self.assertEqual(fa.pytype, float) | |||||
fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | ||||
# in this case, pytype is actually a float. We do not care about it. | # in this case, pytype is actually a float. We do not care about it. | ||||
@@ -113,11 +149,10 @@ class TestFieldArray(unittest.TestCase): | |||||
def test_nested_list(self): | def test_nested_list(self): | ||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True) | fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True) | ||||
self.assertEqual(fa.pytype, float) | |||||
self.assertEqual(fa.dtype, np.float64) | |||||
self.assertEqual(fa.dtype, float) | |||||
def test_getitem_v1(self): | def test_getitem_v1(self): | ||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True) | |||||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | ||||
ans = fa[[0, 1]] | ans = fa[[0, 1]] | ||||
self.assertTrue(isinstance(ans, np.ndarray)) | self.assertTrue(isinstance(ans, np.ndarray)) | ||||
@@ -150,7 +185,7 @@ class TestFieldArray(unittest.TestCase): | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | ||||
fa.append(["str", 0, 0, 0, 1.89]) | fa.append(["str", 0, 0, 0, 1.89]) | ||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True) | |||||
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | ||||
self.assertEqual(len(fa), 3) | self.assertEqual(len(fa), 3) | ||||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | ||||
@@ -163,33 +198,86 @@ class TestFieldArray(unittest.TestCase): | |||||
fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True) | fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True) | ||||
class TestPadder(unittest.TestCase): | |||||
class TestAutoPadder(unittest.TestCase): | |||||
def test00(self): | |||||
padder = AutoPadder() | |||||
# 没有类型时 | |||||
contents = [(1, 2), ('str', 'a')] | |||||
padder(contents, None, None, None) | |||||
def test01(self): | def test01(self): | ||||
""" | |||||
测试AutoPadder能否正常工作 | |||||
:return: | |||||
""" | |||||
from fastNLP import AutoPadder | |||||
# 测试使用多维的bool, int, str, float的情况 | |||||
# str | |||||
padder = AutoPadder() | padder = AutoPadder() | ||||
content = ['This is a str', 'this is another str'] | content = ['This is a str', 'this is another str'] | ||||
self.assertListEqual(content, padder(content, None, np.str).tolist()) | |||||
self.assertListEqual(content, padder(content, None, str, 0).tolist()) | |||||
content = [1, 2] | |||||
self.assertListEqual(content, padder(content, None, np.int64).tolist()) | |||||
content = [[1,2], [3], [4]] | |||||
self.assertListEqual([[1,2], [3, 0], [4, 0]], | |||||
padder(content, None, np.int64).tolist()) | |||||
# 1维int | |||||
content = [[1, 2, 3], [4,], [5, 6, 7, 8]] | |||||
padded_content = [[1, 2, 3, 0], [4, 0, 0, 0], [5, 6, 7, 8]] | |||||
self.assertListEqual(padder(content, None, int, 1).tolist(), padded_content) | |||||
# 二维int | |||||
padded_content = [[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] | |||||
content = [ | content = [ | ||||
[[1, 2, 3], [4, 5], [7,8,9,10]], | |||||
[[1]] | |||||
] | |||||
self.assertListEqual(content, | |||||
padder(content, None, np.int64).tolist()) | |||||
[[1, 2, 3], [4, 5], [7, 8, 9, 10]], | |||||
[[1]] | |||||
] | |||||
self.assertListEqual(padder(content, None, int, 2).tolist(), padded_content) | |||||
# 3维图片 | |||||
contents = [np.random.rand(3, 4, 4).tolist() for _ in range(5)] | |||||
self.assertTrue(padder(contents, None, float, 3).shape==(5, 3, 4, 4)) | |||||
# 更高维度直接返回 | |||||
contents = [np.random.rand(24, 3, 4, 4).tolist() for _ in range(5)] | |||||
self.assertTrue(isinstance(padder(contents, None, float, 4), np.ndarray)) | |||||
def test02(self): | def test02(self): | ||||
padder = AutoPadder() | |||||
# 测试numpy的情况 | |||||
# 0维 | |||||
contents = np.arange(12) | |||||
self.assertListEqual(padder(contents, None, contents.dtype, 0).tolist(), contents.tolist()) | |||||
# 1维 | |||||
contents = np.arange(12).reshape((3, 4)) | |||||
self.assertListEqual(padder(contents, None, contents.dtype, 1).tolist(), contents.tolist()) | |||||
# 2维 | |||||
contents = np.ones((3, 10, 5)) | |||||
self.assertListEqual(padder(contents, None, contents.dtype, 2).tolist(), contents.tolist()) | |||||
# 3维 | |||||
contents = [np.random.rand(3, 4, 4) for _ in range(5)] | |||||
l_contents = [content.tolist() for content in contents] | |||||
self.assertListEqual(padder(contents, None, contents[0].dtype, 3).tolist(), l_contents) | |||||
def test03(self): | |||||
padder = AutoPadder() | |||||
# 测试tensor的情况 | |||||
# 0维 | |||||
contents = torch.arange(12) | |||||
r_contents = padder(contents, None, contents.dtype, 0) | |||||
self.assertSequenceEqual(r_contents.tolist(), contents.tolist()) | |||||
self.assertTrue(r_contents.dtype==contents.dtype) | |||||
# 0维 | |||||
contents = [torch.tensor(1) for _ in range(10)] | |||||
self.assertSequenceEqual(padder(contents, None, torch.int64, 0).tolist(), contents) | |||||
# 1维 | |||||
contents = torch.randn(3, 4) | |||||
padder(contents, None, torch.float64, 1) | |||||
# 3维 | |||||
contents = [torch.randn(3, 4, 4) for _ in range(5)] | |||||
padder(contents, None, torch.float64, 3) | |||||
class TestEngChar2DPadder(unittest.TestCase): | |||||
def test01(self): | |||||
""" | """ | ||||
测试EngChar2DPadder能不能正确使用 | 测试EngChar2DPadder能不能正确使用 | ||||
:return: | :return: | ||||
@@ -198,38 +286,31 @@ class TestPadder(unittest.TestCase): | |||||
padder = EngChar2DPadder(pad_length=0) | padder = EngChar2DPadder(pad_length=0) | ||||
contents = [1, 2] | contents = [1, 2] | ||||
# 不能是1维 | |||||
with self.assertRaises(ValueError): | |||||
padder(contents, None, np.int64) | |||||
# 不能是0维 | |||||
with self.assertRaises(Exception): | |||||
padder(contents, None, np.int64, 0) | |||||
contents = [[1, 2]] | contents = [[1, 2]] | ||||
# 不能是2维 | |||||
with self.assertRaises(ValueError): | |||||
padder(contents, None, np.int64) | |||||
contents = [[[[1, 2]]]] | |||||
# 不能是1维 | |||||
with self.assertRaises(Exception): | |||||
padder(contents, None, np.int64, 1) | |||||
contents = [ | |||||
[[[[1, 2]]]] | |||||
] | |||||
# 不能是3维以上 | # 不能是3维以上 | ||||
with self.assertRaises(ValueError): | |||||
padder(contents, None, np.int64) | |||||
with self.assertRaises(Exception): | |||||
padder(contents, None, np.int64, 3) | |||||
contents = [ | contents = [ | ||||
[[1, 2, 3], [4, 5], [7,8,9,10]], | [[1, 2, 3], [4, 5], [7,8,9,10]], | ||||
[[1]] | [[1]] | ||||
] | ] | ||||
self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], | self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], | ||||
padder(contents, None, np.int64).tolist()) | |||||
padder(contents, None, np.int64, 2).tolist()) | |||||
padder = EngChar2DPadder(pad_length=5, pad_val=-100) | padder = EngChar2DPadder(pad_length=5, pad_val=-100) | ||||
self.assertListEqual( | self.assertListEqual( | ||||
[[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], | [[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], | ||||
[[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], | [[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], | ||||
padder(contents, None, np.int64).tolist() | |||||
padder(contents, None, np.int64, 2).tolist() | |||||
) | ) | ||||
def test_None_dtype(self):
    """Check AutoPadder with ``None`` passed as the dtype argument.

    The value returned by the padder, converted with ``tolist()``, must
    equal the ragged nested list that was passed in.
    """
    from fastNLP import AutoPadder

    ragged = [
        [[1, 2, 3], [4, 5], [7, 8, 9, 10]],
        [[1]],
    ]
    auto_padder = AutoPadder()
    padded = auto_padder(ragged, None, None)
    self.assertListEqual(ragged, padded.tolist())
@@ -18,7 +18,7 @@ class Model(nn.Module): | |||||
self.param = nn.Parameter(torch.zeros(0)) | self.param = nn.Parameter(torch.zeros(0)) | ||||
class TestMoveModelDeivce(unittest.TestCase): | |||||
class TestMoveModelDevice(unittest.TestCase): | |||||
def test_case1(self): | def test_case1(self): | ||||
# 测试str | # 测试str | ||||
model = Model() | model = Model() | ||||
@@ -1,7 +1,7 @@ | |||||
import unittest | import unittest | ||||
import os | |||||
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader | from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader | ||||
from fastNLP.io.dataset_loader import SSTLoader | |||||
class TestDatasetLoader(unittest.TestCase): | class TestDatasetLoader(unittest.TestCase): | ||||
@@ -28,3 +28,34 @@ class TestDatasetLoader(unittest.TestCase): | |||||
def test_JsonLoader(self):
    """Load the sample SNLI jsonl fixture and check that it yields three items."""
    sample_path = 'test/data_for_tests/sample_snli.jsonl'
    loaded = JsonLoader().load(sample_path)
    assert len(loaded) == 3
def test_SST(self):
    """Run ``SSTLoader.process`` on two small bracketed-tree sample files.

    A training sample and a test sample are written to the working
    directory, processed with ``src_vocab_op=dict(min_freq=2)``, and the
    result is checked to expose exactly two vocabularies and two datasets.
    The sample files are deleted afterwards, whether or not the test passes.
    """
    # Raw string: the sample text contains ``\/``, which is not a valid
    # escape sequence in a normal string literal. The text is unchanged.
    train_data = r"""(3 (2 (2 The) (2 Rock)) (4 (3 (2 is) (4 (2 destined) (2 (2 (2 (2 (2 to) (2 (2 be) (2 (2 the) (2 (2 21st) (2 (2 (2 Century) (2 's)) (2 (3 new) (2 (2 ``) (2 Conan)))))))) (2 '')) (2 and)) (3 (2 that) (3 (2 he) (3 (2 's) (3 (2 going) (3 (2 to) (4 (3 (2 make) (3 (3 (2 a) (3 splash)) (2 (2 even) (3 greater)))) (2 (2 than) (2 (2 (2 (2 (1 (2 Arnold) (2 Schwarzenegger)) (2 ,)) (2 (2 Jean-Claud) (2 (2 Van) (2 Damme)))) (2 or)) (2 (2 Steven) (2 Segal))))))))))))) (2 .)))
(4 (4 (4 (2 The) (4 (3 gorgeously) (3 (2 elaborate) (2 continuation)))) (2 (2 (2 of) (2 ``)) (2 (2 The) (2 (2 (2 Lord) (2 (2 of) (2 (2 the) (2 Rings)))) (2 (2 '') (2 trilogy)))))) (2 (3 (2 (2 is) (2 (2 so) (2 huge))) (2 (2 that) (3 (2 (2 (2 a) (2 column)) (2 (2 of) (2 words))) (2 (2 (2 (2 can) (1 not)) (3 adequately)) (2 (2 describe) (2 (3 (2 (2 co-writer\/director) (2 (2 Peter) (3 (2 Jackson) (2 's)))) (3 (2 expanded) (2 vision))) (2 (2 of) (2 (2 (2 J.R.R.) (2 (2 Tolkien) (2 's))) (2 Middle-earth))))))))) (2 .)))
(3 (3 (2 (2 (2 (2 (2 Singer\/composer) (2 (2 Bryan) (2 Adams))) (2 (2 contributes) (2 (2 (2 a) (2 slew)) (2 (2 of) (2 songs))))) (2 (2 --) (2 (2 (2 (2 a) (2 (2 few) (3 potential))) (2 (2 (2 hits) (2 ,)) (2 (2 (2 a) (2 few)) (1 (1 (2 more) (1 (2 simply) (2 intrusive))) (2 (2 to) (2 (2 the) (2 story))))))) (2 --)))) (2 but)) (3 (4 (2 the) (3 (2 whole) (2 package))) (2 (3 certainly) (3 (2 captures) (2 (1 (2 the) (2 (2 (2 intended) (2 (2 ,) (2 (2 er) (2 ,)))) (3 spirit))) (2 (2 of) (2 (2 the) (2 piece)))))))) (2 .))
(2 (2 (2 You) (2 (2 'd) (2 (2 think) (2 (2 by) (2 now))))) (2 (2 America) (2 (2 (2 would) (1 (2 have) (2 (2 (2 had) (1 (2 enough) (2 (2 of) (2 (2 plucky) (2 (2 British) (1 eccentrics)))))) (4 (2 with) (4 (3 hearts) (3 (2 of) (3 gold))))))) (2 .))))
"""
    test_data = """(3 (2 Yet) (3 (2 (2 the) (2 act)) (3 (4 (3 (2 is) (3 (2 still) (4 charming))) (2 here)) (2 .))))
(4 (2 (2 Whether) (2 (2 (2 (2 or) (1 not)) (3 (2 you) (2 (2 're) (3 (3 enlightened) (2 (2 by) (2 (2 any) (2 (2 of) (2 (2 Derrida) (2 's))))))))) (2 (2 lectures) (2 (2 on) (2 (2 ``) (2 (2 (2 (2 (2 (2 the) (2 other)) (2 '')) (2 and)) (2 ``)) (2 (2 the) (2 self)))))))) (3 (2 ,) (3 (2 '') (3 (2 Derrida) (3 (3 (2 is) (4 (2 an) (4 (4 (2 undeniably) (3 (4 (3 fascinating) (2 and)) (4 playful))) (2 fellow)))) (2 .))))))
(4 (3 (2 (2 Just) (2 (2 the) (2 labour))) (3 (2 involved) (3 (2 in) (4 (2 creating) (3 (3 (2 the) (3 (3 layered) (2 richness))) (3 (2 of) (3 (2 (2 the) (2 imagery)) (2 (2 in) (3 (2 (2 this) (2 chiaroscuro)) (2 (2 of) (2 (2 (2 madness) (2 and)) (2 light)))))))))))) (3 (3 (2 is) (4 astonishing)) (2 .)))
(3 (3 (2 Part) (3 (2 of) (4 (2 (2 the) (3 charm)) (2 (2 of) (2 (2 Satin) (2 Rouge)))))) (3 (3 (2 is) (3 (2 that) (3 (2 it) (2 (1 (2 avoids) (2 (2 the) (1 obvious))) (3 (2 with) (3 (3 (3 humour) (2 and)) (2 lightness))))))) (2 .)))
(4 (2 (2 a) (2 (2 screenplay) (2 more))) (3 (4 ingeniously) (2 (2 constructed) (2 (2 (2 (2 than) (2 ``)) (2 Memento)) (2 '')))))
(3 (2 ``) (3 (2 (2 Extreme) (2 Ops)) (3 (2 '') (4 (4 (3 exceeds) (2 expectations)) (2 .)))))
"""
    train, test = 'train--', 'test--'
    try:
        with open(train, 'w', encoding='utf-8') as f:
            f.write(train_data)
        with open(test, 'w', encoding='utf-8') as f:
            f.write(test_data)

        loader = SSTLoader()
        # The same string serves as both the dataset name (dict key and
        # train_ds entry) and the path of the file on disk.
        info = loader.process(
            {train: train, test: test},
            train_ds=[train],
            src_vocab_op=dict(min_freq=2)
        )
        assert len(list(info.vocabs.items())) == 2
        assert len(list(info.datasets.items())) == 2
        print(info.vocabs)
        print(info.datasets)
    finally:
        # Clean up even when loading or an assertion fails, so a broken run
        # does not leave the sample files behind in the working directory.
        for path in (train, test):
            if os.path.exists(path):
                os.remove(path)