@@ -21,7 +21,6 @@ __all__ = [
    "AutoPadder",
    "EngChar2DPadder",
    "CollectFn",
    "ConcatCollectFn",
    "Vocabulary",

@@ -97,4 +96,4 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
from .collect_fn import CollectFn, ConcatCollectFn
from .collect_fn import ConcatCollectFn
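With `CollectFn` dropped from the public API, any plain callable with the collect_fn signature can be registered instead. A minimal migration sketch (the field name `words` and the commented call are illustrative only):

.. code-block::

    # instead of subclassing the removed CollectFn, register a callable that
    # maps List[(index, instance)] to a pair of dicts (batch_x, batch_y)
    def my_collect_fn(ins_list):
        seq_len = [len(ins['words']) for _, ins in ins_list]  # 'words' is a made-up field
        return {'seq_len': seq_len}, {}

    # dataset.add_collect_fn(my_collect_fn)  # replaces add_collect_fn(SomeCollectFn(), ...)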
@@ -9,17 +9,16 @@ __all__ = [
]

import atexit
from numbers import Number
import abc
from numbers import Number

import numpy as np
import torch
import torch.utils.data
from collections import defaultdict

from ._logger import logger
from .dataset import DataSet
from .sampler import SequentialSampler
from .field import _get_ele_type_and_dim

_python_is_exit = False
@@ -33,6 +32,9 @@ atexit.register(_set_python_is_exit)

class DataSetGetter:
    """
    Passed to torch.utils.data.DataLoader to fetch data; the DataLoader supplies an int idx
    to retrieve a sample (which calls __getitem__() here).
    """
    def __init__(self, dataset: DataSet, as_numpy=False):
        self.dataset = dataset
        self.as_numpy = as_numpy
@@ -56,7 +58,6 @@ class DataSetGetter:
        :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
        :return:
        """
        # TODO support defining collate_fn in DataSet, since different fields sometimes need
        #  to be fused, e.g. in the BERT scenario
        indices = []
        sin_x, sin_y = defaultdict(list), defaultdict(list)
        for idx, ins in ins_list:
@@ -67,24 +68,6 @@ class DataSetGetter:
            if n in self.y_names:
                sin_y[n].append(v)
        def may_to_tensor(data):
            dtype, dim = _get_ele_type_and_dim(data)
            # print(dtype, type(dtype), str(dtype))
            if not self.as_numpy:
                try:
                    data, flag = _to_tensor(data, dtype)
                except TypeError as e:
                    logger.error(f"Field {n} cannot be converted to torch.tensor.")
                    raise e
            # if torch.is_tensor(data):
            #     str_dtype = str(dtype)
            #     if 'float' in str_dtype:
            #         data = data.float()
            #     elif 'int' in str_dtype:
            #         data = data.long()
            # print(data.dtype)
            return data
        def pad(batch_dict):
            result = {}
            for n, vlist in batch_dict.items():
@@ -98,25 +81,13 @@ class DataSetGetter:
        sin_x = pad(sin_x)
        sin_y = pad(sin_y)

        bx, by = self.dataset._collect_batch(ins_list)

        def convert_tensor(batch_dict):
            for n, v in batch_dict.items():
                batch_dict[n] = may_to_tensor(v)

        # collect_fn replaces single field
        sin_x.update(bx)
        sin_y.update(by)
        convert_tensor(sin_x)
        convert_tensor(sin_y)

        if not self.dataset.collector.is_empty():
            bx, by = self.dataset._collect_batch(ins_list)
            sin_x.update(bx)
            sin_y.update(by)

        return (indices, sin_x, sin_y)

    def set_idx_list(self, idx_list):
        if len(idx_list) != len(self.idx_list):
            raise ValueError
        self.idx_list = idx_list

    def __getattr__(self, item):
        if hasattr(self.dataset, item):
            return getattr(self.dataset, item)
@@ -125,6 +96,10 @@ class DataSetGetter:

class SamplerAdapter(torch.utils.data.Sampler):
    """
    Passed into torch.utils.data.DataLoader; the DataLoader calls the __iter__() method to
    fetch indices (one int at a time).
    """
    def __init__(self, sampler, dataset):
        super().__init__(dataset)
        self.sampler = sampler
@@ -138,6 +113,11 @@ class SamplerAdapter(torch.utils.data.Sampler):

class BatchIter:
    """
    The class Trainer uses to iterate over data. Subclass it and implement the
    get_num_batches(), get_batch_indices(), dataset(), num_batches() and __iter__() methods.
    """
    def __init__(self, dataset, batch_size=1, sampler=None,
                 num_workers=0, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None, collate_fn=None):
@@ -145,6 +125,8 @@ class BatchIter:
            self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
        else:
            self.sampler = sampler

        # the DataLoader's collate_fn receives a List[] whose elements are what dataset[index] returns
        if collate_fn is None:
            # collate_fn=None cannot be passed in pytorch <= 1.1
            self.dataiter = torch.utils.data.DataLoader(
@@ -160,17 +142,25 @@ class BatchIter:
                timeout=timeout, worker_init_fn=worker_init_fn)

        # count batches by the sampler's length, since with DistributedSampler each process
        # only sees part of the data
        self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
        self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
        self.batch_size = batch_size
        self.cur_batch_indices = None

    @property
    def num_batches(self):
        return self._num_batches

    @num_batches.setter
    def num_batches(self, value):
        self._num_batches = value

    def init_iter(self):
        pass

    @staticmethod
    def get_num_batches(num_samples, batch_size, drop_last):
        """
        Compute the number of batches.
        Compute the number of batches. Used for displaying progress in the front end.

        :param int num_samples:
        :param int batch_size:
@@ -184,7 +174,7 @@ class BatchIter:
    def get_batch_indices(self):
        """
        Get the indices of the batch that was output.
        Get the indices of the most recently output batch. Useful for tracing the data in the current batch.

        :return:
        """
@@ -195,8 +185,22 @@ class BatchIter:
    @property
    def dataset(self):
        """
        Get the dataset currently being iterated over.

        :return:
        """
        return self.dataiter.dataset

    @abc.abstractmethod
    def __iter__(self):
        """
        Performs the actual iteration over the data. Must yield two dicts: the contents of the
        first are treated as input, those of the second as target.

        :return:
        """
        raise NotImplementedError
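The abstract `__iter__` above fixes the subclass contract. A minimal illustrative subclass, assuming `self.dataiter` yields the `(indices, batch_x, batch_y)` triples produced by `DataSetGetter.collate_fn` (DataSetIter below does this plus device placement and other bookkeeping):

.. code-block::

    class SimpleBatchIter(BatchIter):
        def __iter__(self):
            self.init_iter()
            for indices, batch_x, batch_y in self.dataiter:
                self.cur_batch_indices = indices  # lets get_batch_indices() trace the batch
                yield batch_x, batch_y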
class DataSetIter(BatchIter):
    """
@@ -4,7 +4,8 @@ from builtins import sorted

import torch
import numpy as np
from .field import _get_ele_type_and_dim
from collections import defaultdict
from .utils import logger
from copy import deepcopy


def _check_type(batch_dict, fields):
@@ -36,127 +37,89 @@ def batching(samples, max_len=0, padding_val=0):

class Collector:
    """
    Helper class that manages collect_fn functions for DataSet.
    """
    def __init__(self):
        self.fns = {}
        self.input2fn = defaultdict(list)
        self.output2fn = defaultdict(list)
        self.fn2input = {}
        self.fn2output = {}
    def add_fn(self, fn, inputs, outputs, is_input, is_target):
        for name in outputs:
            if name in self.output2fn:
                raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))

        if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs():
            raise ValueError(
                "Incorrect num of inputs, should be {} not {}".format(
                    fn.num_inputs(), len(inputs)
                ))

        if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs():
            raise ValueError("Incorrect num of outputs, should be {} not {}".format(
                fn.num_outputs(), len(outputs)))

        self.fns[fn] = {'is_input': is_input, 'is_target': is_target}
        for i, field in enumerate(inputs):
            self.input2fn[field].append((fn, i))
        for i, name in enumerate(outputs):
            self.output2fn[name].append((fn, i))

    def _rebuild_fn2io(self):
        def transpose(name2fn):
            fn2names = defaultdict(list)
            for name, vlist in name2fn.items():
                for fn, i in vlist:
                    fn2names[fn].append((name, i))
            for fn, vlist in fn2names.items():
                vlist = sorted(vlist, key=lambda x: x[1])
                fn2names[fn] = [name for name, i in vlist]
            return fn2names

        self.fn2input = transpose(self.input2fn)
        self.fn2output = transpose(self.output2fn)

    def _clear_fn2io(self):
        self.fn2input.clear()
        self.fn2output.clear()
        self.collect_fns = {}

    def add_fn(self, fn, name=None):
        """
        Register a new collect_fn with the collector.

        :param callable fn:
        :param str,int name:
        :return:
        """
        if name in self.collect_fns:
            logger.warning(f"collect_fn:{name} will be overwritten.")
        if name is None:
            name = len(self.collect_fns)
        self.collect_fns[name] = fn

    def is_empty(self):
        """
        Return whether no collect_fn has been registered.

        :return:
        """
        return len(self.collect_fns) == 0

    def delete_fn(self, name=None):
        """
        Delete a collect_fn.

        :param str,int name: if None, delete the most recently added collect_fn
        :return:
        """
        if not self.is_empty():
            if name in self.collect_fns:
                self.collect_fns.pop(name)
            elif name is None:
                last_key = list(self.collect_fns.keys())[-1]  # dicts preserve insertion order, so [-1] is the most recent
                self.collect_fns.pop(last_key)
    def collect_batch(self, ins_list):
        if len(ins_list) == 0:
            return {}, {}

        if len(self.fn2output) == 0:
            self._rebuild_fn2io()

        bx = {}
        by = {}
        for fn, attr in self.fns.items():
            inputs = self.fn2input.get(fn, None)
            outputs = self.fn2output.get(fn, None)
            res = fn.collect(ins_list, inputs, outputs)
            if attr.get('is_input', False):
                bx.update(res)
            if attr.get('is_target', False):
                by.update(res)
        bx, by = {}, {}
        for name, fn in self.collect_fns.items():
            try:
                batch_x, batch_y = fn(ins_list)
            except BaseException as e:
                logger.error(f"Exception:`{e}` happened when calling collect_fn:`{name}`.")
                raise e
            bx.update(batch_x)
            by.update(batch_y)
        return bx, by
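To make the new contract concrete, a standalone sketch of the rewritten Collector (plain dicts stand in for Instance objects here, since both support `ins['field']` access):

.. code-block::

    collector = Collector()
    # auto-named fn (key 0): produces an input entry
    collector.add_fn(lambda ins_list: ({'n_tokens': [len(ins['x']) for _, ins in ins_list]}, {}))
    # named fn: produces a target entry
    collector.add_fn(lambda ins_list: ({}, {'first_idx': ins_list[0][0]}), name='trace')

    bx, by = collector.collect_batch([(0, {'x': [1, 2]}), (1, {'x': [3]})])
    # bx == {'n_tokens': [2, 1]}, by == {'first_idx': 0}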
    def rename_field(self, old_f, new_f):
        if new_f in self.input2fn:
            # name conflict
            raise ValueError
        if old_f not in self.input2fn:
            # the renamed field does not affect collectors
            return
        self.input2fn[new_f] = self.input2fn[old_f]
        self._clear_fn2io()

    def drop_field(self, f):
        if f in self.input2fn:
            raise ValueError

    def outputs(self):
        return self.output2fn.keys()

    def copy_from(self, col):
        assert isinstance(col, Collector)
        self.fns = col.fns.copy()
        self.input2fn = col.input2fn.copy()
        self.output2fn = col.output2fn.copy()
        self._clear_fn2io()
class CollectFn:
    def __init__(self):
        self.fields = []

    def collect(self, ins_list, inputs, outputs):
        raise NotImplementedError

        new_col = Collector()
        new_col.collect_fns = deepcopy(col.collect_fns)
        return new_col
    def num_inputs(self):
        return 0

    def num_outputs(self):
        return 0

    @staticmethod
    def get_batch_size(batch_dict):
        if len(batch_dict) == 0:
            return 0
        return len(next(iter(batch_dict.values())))
class ConcatCollectFn(CollectFn):
class ConcatCollectFn:
    """
    A concatenating Fn: joins the given fields in order and pads the result. All fields must have the same dim.
    A concatenating collect_fn: joins the given fields in order and pads the result.

    :param List[str] inputs: the fields whose data will be concatenated; currently only 1d fields are supported
    :param str output: the name of the concatenated field
    :param pad_val: the value used for padding
    :param max_len: the maximum length after concatenation
    :param is_input: whether to set the generated output as input
    :param is_target: whether to set the generated output as target
    """

    def __init__(self, pad_val=0, max_len=0):
    def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False):
        super().__init__()
        assert isinstance(inputs, list)
        self.inputs = inputs
        self.output = output
        self.pad_val = pad_val
        self.max_len = max_len
        self.is_input = is_input
        self.is_target = is_target
    @staticmethod
    def _to_numpy(seq):
@@ -165,21 +128,18 @@ class ConcatCollectFn(CollectFn):
        else:
            return np.array(seq)

    def collect(self, ins_list, inputs, outputs):
    def __call__(self, ins_list):
        samples = []
        for i, ins in ins_list:
            sample = []
            for i in inputs:
                sample.append(self._to_numpy(ins[i]))
            for input_name in self.inputs:
                sample.append(self._to_numpy(ins[input_name]))
            samples.append(np.concatenate(sample, axis=0))
        seq_len = [s.shape[0] for s in samples]
        batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
        o1, o2 = outputs
        return {o1: batch, o2: seq_len}

    def num_inputs(self):
        return 0

        b_x, b_y = {}, {}
        if self.is_input:
            b_x[self.output] = batch
        if self.is_target:
            b_y[self.output] = batch

    def num_outputs(self):
        # (concat_words, seq_len)
        return 2
        return b_x, b_y
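A usage sketch of the reworked ConcatCollectFn: the fields and the output name now go to the constructor, and the instance is called directly on the `(index, instance)` list (plain dicts stand in for instances again):

.. code-block::

    fn = ConcatCollectFn(inputs=['x1', 'x2'], output='x', pad_val=0)
    b_x, b_y = fn([(0, {'x1': [0, 1], 'x2': [3]}),
                   (1, {'x1': [2], 'x2': [2, 4, 5]})])
    # b_x['x'] holds the padded concatenation [[0, 1, 3, 0], [2, 2, 4, 5]];
    # b_y stays empty because is_target defaults to False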
@@ -281,6 +281,75 @@
    # you can also set the pad value
    dataset.set_pad_val('chars', -1)

3.3 Building a new field from several fields of a DataSet
----------------------------------------------------------

By default a DataSet only sees one field at a time while batching, but some training setups need more,
for example: (1) two fields have to be concatenated into one; (2) negative sampling has to happen inside
the batch. Such operations need access to several fields at once during batching. DataSet's
add_collect_fn() supports registering a custom collect_fn that involves multiple fields, as in the
example below, which concatenates two fields into one.
.. code-block::

    from fastNLP import DataSet, DataSetIter, SequentialSampler
    import torch

    data = DataSet({
        'x1': [[0, 1],
               [2]],
        'x2': [[3],
               [2, 4, 5]],
        'y': [0, 1]
    })
    data.set_target('y')

    # Every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2
    # are the indices of the instances in the dataset and instance1/instance2 are the data drawn for this
    # batch, containing all fields.
    def concat_collect_fn(ins_list):
        x1 = [ins['x1'] for ind, ins in ins_list]
        x2 = [ins['x2'] for ind, ins in ins_list]
        xs = []
        for i in range(len(ins_list)):
            xs.append(torch.LongTensor(x1[i] + x2[i]))
        # you must pad and convert to tensor yourself, but you need not move the data to the gpu
        arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
        b_x = {'x': arr}
        b_y = {}
        # the return value must be two dicts: the first dict's values are treated as input, the second
        # dict's as target. If a name clashes with an existing input or target, the returned value wins.
        return b_x, b_y

    data.add_collect_fn(concat_collect_fn)

    for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
        print("batch_x:", batch_x)
        print("batch_y:", batch_y)
        # batch_x: {'x': tensor([[0, 1, 3, 0],
        #                        [2, 2, 4, 5]])}
        # batch_y: {'y': array([0, 1])}
    # if batching needs extra parameters, implement the collect_fn as a class
    class ConCollectFn:
        def __init__(self, max_len=3):
            self.max_len = max_len

        def __call__(self, ins_list):  # implement the class's __call__
            x1 = [ins['x1'] for ind, ins in ins_list]
            x2 = [ins['x2'] for ind, ins in ins_list]
            xs = []
            for i in range(len(ins_list)):
                xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
            arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
            b_x = {'x': arr}
            b_y = {}
            return b_x, b_y

    data.delete_collect_fn()  # delete the previous collect_fn
    data.add_collect_fn(ConCollectFn(max_len=3))
    for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
        print("batch_x:", batch_x)
        print("batch_y:", batch_y)
        # batch_x: {'x': tensor([[0, 1, 3],
        #                        [2, 2, 4]])}
        # batch_y: {'y': array([0, 1])}

"""

__all__ = [
@@ -300,7 +369,6 @@ from .field import AutoPadder
from .field import FieldArray
from .field import SetInputOrTargetException
from .instance import Instance
from .utils import _get_func_signature
from .utils import pretty_table_printer
from .collect_fn import Collector
@@ -394,6 +462,7 @@ class DataSet(object):
            for field in self.field_arrays.values():
                data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder,
                                   is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
            data_set.collector = self.collector.copy_from(self.collector)
            return data_set
        elif isinstance(idx, str):
            if idx not in self:
@@ -407,6 +476,7 @@ class DataSet(object):
                dataset.append(instance)
            for field_name, field in self.field_arrays.items():
                dataset.field_arrays[field_name].to(field)
            dataset.collector = self.collector.copy_from(self.collector)
            return dataset
        else:
            raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -575,7 +645,6 @@ class DataSet(object):

        :param str field_name: the name of the field to delete.
        """
        self.field_arrays.pop(field_name)
        self.collector.drop_field(field_name)
        return self

    def copy_field(self, field_name, new_field_name):
@@ -648,7 +717,6 @@ class DataSet(object):
        if field_name in self.field_arrays:
            self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
            self.field_arrays[new_field_name].name = new_field_name
            self.collector.rename_field(field_name, new_field_name)
        else:
            raise KeyError("DataSet has no field named {}.".format(field_name))
        return self
@@ -1040,30 +1108,30 @@ class DataSet(object):
        assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
        return d

    def add_collect_fn(self, fn, inputs, outputs, is_input, is_target):
    def add_collect_fn(self, fn, name=None):
        """
        Add a CollectFn that uses multiple fields to produce data in the batch.
        Add a CollectFn. A collect_fn can generate extra data on the fly while a batch is being built
        (effective when DataSetIter is used as the iterator, which is the default). Several collect_fns may
        be added in turn; when keys collide, a later collect_fn's result overrides an earlier one's.

        :param CollectFn fn: defines how the data is produced
        :param list inputs: the fields used to produce the data, in order
        :param list outputs: the names of the produced data in the batch
        :param bool is_input: whether the output appears in the input batch; otherwise it appears in the target batch
        :param bool is_target:
        :param callable fn: a callable accepting List[(ind1, instance1), (ind2, instance2)] (all indices
            chosen for a batch together with their instances), where ind1/ind2 are the indices of the
            instances in the dataset and instance1/instance2 are the data drawn for this batch, containing
            all fields. It must return two dicts: the first dict's values are treated as input, the second
            dict's as target; at most one of the returned dicts may be empty. If a returned dict contains the
            name of a field set as input or target, the returned value overrides the dataset's field.
            fastNLP neither pads the collect_fn's return values nor converts them to tensors; do both inside
            the collect_fn (there is no need to move tensors to the gpu; if they are pytorch tensors, fastNLP
            moves them to the proper gpu automatically). Do not modify the data passed into the collect_fn,
            or unknown problems may follow.
        :param str,int name: the name of the collect_fn. If not given, an auto-incremented number is used as
            the key. The same name overrides the earlier collect_fn.
        """
        def check_fields(fields):
            for f in fields:
                if f not in self.field_arrays:
                    raise ValueError(f)

        def check_name(names):
            for name in names:
                if name in self.field_arrays:
                    logger.warning('name of collect_fn will cover the field name in dataset')

        assert callable(fn), "You must pass in a callable object."
        self.collector.add_fn(fn, name=name)

        check_fields(inputs)
        check_name(outputs)

    def delete_collect_fn(self, name=None):
        """
        Delete a collect_fn.

        self.collector.add_fn(fn, inputs, outputs, is_input, is_target)
        :param str,int name: if None, delete the most recently added collect_fn
        :return:
        """
        self.collector.delete_fn(name)

    def _collect_batch(self, ins_list):
        return self.collector.collect_batch(ins_list)
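The name-based registry behind add_collect_fn()/delete_collect_fn() in a short sketch (the field and fn names are illustrative):

.. code-block::

    ds = DataSet({'x': [[1, 2], [3]]})
    ds.add_collect_fn(lambda ins_list: ({'len': [len(ins['x']) for _, ins in ins_list]}, {}),
                      name='len')
    # same name: the second registration overrides the first (a warning is logged)
    ds.add_collect_fn(lambda ins_list: ({'len': [0] * len(ins_list)}, {}), name='len')
    ds.delete_collect_fn('len')     # delete by name; with no argument, the most recent fn is removed
    assert ds.collector.is_empty()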
@@ -191,24 +191,31 @@ class FieldArray:
    def get(self, indices, pad=True):
        """
        Return the content at the given indices
        Return the content at the given indices.

        :param int,List[int] indices: fetch the content corresponding to these indices.
        :param bool pad: whether to pad the result. Only effective when indices is a List[int]
        :return: the content at the given indices; may be a single value or a List
        :param bool pad: whether to pad the result. Only effective when (1) indices is a List[int];
            (2) the padder is not None; (3) the field is set as input or target
        :return: the content at the given indices; may be a single value or an ndarray
        """
        if isinstance(indices, int):
            return self.content[indices]
        if self.is_input is False and self.is_target is False:
            raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name))

        contents = [self.content[i] for i in indices]
        if self.padder is None or pad is False:
            return np.array(contents)
        else:
        elif self.is_input or self.is_target:
            return self.pad(contents)
        else:
            return np.array(contents)
    def pad(self, contents):
        """
        Pad the given list of contents with this FieldArray's padder; the contents must have been taken
        from this FieldArray.

        :param list contents:
        :return:
        """
        return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)

    def set_padder(self, padder):
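The revised get() semantics in a sketch, assuming FieldArray's `FieldArray(name, content, is_input=...)` constructor with the default AutoPadder:

.. code-block::

    fa = FieldArray('chars', [[1, 2, 3], [4]], is_input=True)
    fa.get(0)                  # single value: [1, 2, 3]
    fa.get([0, 1])             # padded ndarray: [[1, 2, 3], [4, 0, 0]]
    fa.get([0, 1], pad=False)  # the raw lists wrapped in an ndarray, no padding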
@@ -71,7 +71,7 @@ class Tester(object):
    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
        """
        :param ~fastNLP.DataSet data: the dataset to test on
        :param ~fastNLP.DataSet,~fastNLP.BatchIter data: the dataset to test on
        :param torch.nn.Module model: the model to use
        :param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: the metrics used during testing
        :param int batch_size: the batch_size used during evaluation.
@@ -375,7 +375,7 @@ class Trainer(object): | |||
callbacks=None, check_code_level=0, **kwargs): | |||
""" | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型或 :class:`~fastNLP.BatchIter`的子类 | |||
:param nn.modules model: 待训练的模型 | |||
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
:param int batch_size: 训练和验证的时候的batch大小。 | |||
@@ -624,9 +624,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
                if check_res.unused:
                    _tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}."
                if _tmp:
                    _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
                    _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. '
                else:
                    _tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.'
                    _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}. '
                if not dataset.collector.is_empty():
                    _tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. '
                suggestions.append(_tmp)

        if check_res.duplicated:
@@ -683,12 +685,11 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
            else:
                _miss_out_dataset.append(_miss)
        if _miss_in_dataset:
            suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
            suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ")
        if _miss_out_dataset:
            _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
            # if check_res.unused:
            #     _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
            #             f"rename the field in `unused field:`."
            _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. "
            if not dataset.collector.is_empty():
                _tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. '
            suggestions.append(_tmp)

        if check_res.unused:
@@ -158,20 +158,130 @@ class TestCase1(unittest.TestCase):
        dataset.set_input('1', '2')
        dataset.set_target('0', '3')

        fn = ConcatCollectFn()
        dataset.add_collect_fn(fn, inputs=['1', '2'],
                               outputs=['12', 'seq_len'],
                               is_input=True, is_target=False)
        fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
        dataset.add_collect_fn(fn, name='demo')
        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
        for batch_x, batch_y in batch:
            for i in range(batch_size):
                # print(i)
                self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
                self.assertEqual(
                    batch_x['seq_len'][i],
                    (batch_x['1'][i] != 0).sum() + (batch_x['2'][i] != 0).sum())
        dataset.delete_collect_fn(name='demo')

        # test the case where the source fields are not input
        dataset.set_input('1', '2', flag=False)
        fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
        dataset.add_collect_fn(fn, name='demo')
        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
        for batch_x, batch_y in batch:
            for i in range(batch_size):
                self.assertTrue('12' in batch_x)
        dataset.delete_collect_fn(name='demo')
        dataset.set_input('1', '2', flag=True)

        # test overriding another field
        fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
        dataset.add_collect_fn(fn, name='demo')
        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
        for batch_x, batch_y in batch:
            for i in range(batch_size):
                # print(i)
                self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
        dataset.delete_collect_fn(name='demo')

        # test the case where the source fields are neither input nor target
        dataset.set_input('1', '2', flag=False)
        fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
        dataset.add_collect_fn(fn, name='demo')
        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
        for batch_x, batch_y in batch:
            for i in range(batch_size):
                # print(i)
                self.assertTrue('3' in batch_x)
                self.assertTrue('3' in batch_y)
        dataset.delete_collect_fn(name='demo')

        # test adding an invalid fn
        with self.assertRaises(AssertionError):
            dataset.add_collect_fn(1)

        # test a collect_fn whose return value is a single dict
        def demo_collect_fn(ins_list):
            return {'3': 1}
        dataset.add_collect_fn(demo_collect_fn, name='demo')
        with self.assertRaises(BaseException):
            batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
            for batch_x, batch_y in batch:
                pass
        dataset.delete_collect_fn(name='demo')

        # test multiple collect_fns
        dataset.add_collect_fn(demo_collect_fn, name='demo')
        dataset.add_collect_fn(demo_collect_fn, name='demo')
        # test deletion
        dataset.delete_collect_fn()
        dataset.delete_collect_fn()
        self.assertTrue(dataset.collector.is_empty())
    def test_demo(self):
        import torch

        data = DataSet({
            'x1': [[0, 1],
                   [2]],
            'x2': [[3],
                   [2, 4, 5]
                   ],
            'y': [0, 1]
        })
        data.set_target('y')

        # Every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2
        # are the indices of the instances in the dataset and instance1/instance2 are the data drawn for this
        # batch, containing all fields.
        def concat_collect_fn(ins_list):
            x1 = [ins['x1'] for ind, ins in ins_list]
            x2 = [ins['x2'] for ind, ins in ins_list]
            xs = []
            for i in range(len(ins_list)):
                xs.append(torch.LongTensor(x1[i] + x2[i]))
            # you must pad and convert to tensor yourself, but you need not move the data to the gpu
            arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
            b_x = {'x': arr}
            b_y = {}
            # the return value must be two dicts: the first dict's values are treated as input, the second
            # dict's as target. If a name clashes with an existing input or target, the returned value wins.
            return b_x, b_y

        data.add_collect_fn(concat_collect_fn)

        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
            print("batch_x:", batch_x)
            print("batch_y:", batch_y)
            # batch_x: {'x': tensor([[0, 1, 3, 0],
            #                        [2, 2, 4, 5]])}
            # batch_y: {'y': array([0, 1])}

        # if batching needs extra parameters, implement the collect_fn as a class
        class ConCollectFn:
            def __init__(self, max_len=3):
                self.max_len = max_len

            def __call__(self, ins_list):
                x1 = [ins['x1'] for ind, ins in ins_list]
                x2 = [ins['x2'] for ind, ins in ins_list]
                xs = []
                for i in range(len(ins_list)):
                    xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
                arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
                b_x = {'x': arr}
                b_y = {}
                return b_x, b_y

        data.delete_collect_fn()  # delete the previous collect_fn
        data.add_collect_fn(ConCollectFn(max_len=3))
        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
            print("batch_x:", batch_x)
            print("batch_y:", batch_y)
            # batch_x: {'x': tensor([[0, 1, 3],
            #                        [2, 2, 4]])}
            # batch_y: {'y': array([0, 1])}

    def testTensorLoaderIter(self):
        class FakeData: