@@ -21,7 +21,6 @@ __all__ = [
     "AutoPadder",
     "EngChar2DPadder",
-    "CollectFn",
     "ConcatCollectFn",
     "Vocabulary",
@@ -97,4 +96,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
-from .collect_fn import CollectFn, ConcatCollectFn
+from .collect_fn import ConcatCollectFn
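Note: with the `CollectFn` base class gone from the public API, a collect function is now any plain callable. A minimal sketch of the contract documented later in this diff (the field name `words` is hypothetical):

.. code-block:: python

    import numpy as np

    def my_collect_fn(ins_list):
        # ins_list is [(idx, instance), ...] for the sampled batch
        seq_len = np.array([len(ins['words']) for idx, ins in ins_list])  # 'words' is a hypothetical field
        # return (batch_x_dict, batch_y_dict); fastNLP merges these into the batch as-is
        return {'seq_len': seq_len}, {}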
@@ -9,17 +9,16 @@ __all__ = [
 ]

 import atexit
-from numbers import Number
+import abc
+from numbers import Number

 import numpy as np
 import torch
 import torch.utils.data
 from collections import defaultdict

-from ._logger import logger
 from .dataset import DataSet
 from .sampler import SequentialSampler
-from .field import _get_ele_type_and_dim

 _python_is_exit = False
@@ -33,6 +32,9 @@ atexit.register(_set_python_is_exit)

 class DataSetGetter:
+    """
+    Passed to torch.utils.data.DataLoader to fetch data; the DataLoader supplies an int idx
+    and retrieves a sample by calling __getitem__() here.
+    """
     def __init__(self, dataset: DataSet, as_numpy=False):
         self.dataset = dataset
         self.as_numpy = as_numpy
@@ -56,7 +58,6 @@ class DataSetGetter:
        :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
        :return:
        """
-       # TODO: support defining collate_fn in DataSet itself, since fields sometimes need to be fused, e.g. for BERT
        indices = []
        sin_x, sin_y = defaultdict(list), defaultdict(list)
        for idx, ins in ins_list:
@@ -67,24 +68,6 @@ class DataSetGetter: | |||||
if n in self.y_names: | if n in self.y_names: | ||||
sin_y[n].append(v) | sin_y[n].append(v) | ||||
def may_to_tensor(data): | |||||
dtype, dim = _get_ele_type_and_dim(data) | |||||
# print(dtype, type(dtype), str(dtype)) | |||||
if not self.as_numpy: | |||||
try: | |||||
data, flag = _to_tensor(data, dtype) | |||||
except TypeError as e: | |||||
logger.error(f"Field {n} cannot be converted to torch.tensor.") | |||||
raise e | |||||
# if torch.is_tensor(data): | |||||
# str_dtype = str(dtype) | |||||
# if 'float' in str_dtype: | |||||
# data = data.float() | |||||
# elif 'int' in str_dtype: | |||||
# data = data.long() | |||||
# print(data.dtype) | |||||
return data | |||||
def pad(batch_dict): | def pad(batch_dict): | ||||
result = {} | result = {} | ||||
for n, vlist in batch_dict.items(): | for n, vlist in batch_dict.items(): | ||||
@@ -98,25 +81,13 @@ class DataSetGetter:
        sin_x = pad(sin_x)
        sin_y = pad(sin_y)

-       bx, by = self.dataset._collect_batch(ins_list)
-
-       def convert_tensor(batch_dict):
-           for n, v in batch_dict.items():
-               batch_dict[n] = may_to_tensor(v)
-
-       # collect_fn replaces single field
-       sin_x.update(bx)
-       sin_y.update(by)
-       convert_tensor(sin_x)
-       convert_tensor(sin_y)
+       if not self.dataset.collector.is_empty():
+           bx, by = self.dataset._collect_batch(ins_list)
+           sin_x.update(bx)
+           sin_y.update(by)

        return (indices, sin_x, sin_y)

-   def set_idx_list(self, idx_list):
-       if len(idx_list) != len(self.idx_list):
-           raise ValueError
-       self.idx_list = idx_list
-
    def __getattr__(self, item):
        if hasattr(self.dataset, item):
            return getattr(self.dataset, item)
@@ -125,6 +96,10 @@ class DataSetGetter:

 class SamplerAdapter(torch.utils.data.Sampler):
+    """
+    Passed into torch.utils.data.DataLoader; the DataLoader calls __iter__() to obtain
+    indices, one int at a time.
+    """
     def __init__(self, sampler, dataset):
         super().__init__(dataset)
         self.sampler = sampler
@@ -138,6 +113,11 @@ class SamplerAdapter(torch.utils.data.Sampler):

 class BatchIter:
+    """
+    The class Trainer uses to iterate over data. Subclass it and implement the
+    get_num_batches(), get_batch_indices(), dataset(), num_batches() and __iter__()
+    methods.
+    """
     def __init__(self, dataset, batch_size=1, sampler=None,
                  num_workers=0, pin_memory=False, drop_last=False,
                  timeout=0, worker_init_fn=None, collate_fn=None):
@@ -145,6 +125,8 @@ class BatchIter:
             self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
         else:
             self.sampler = sampler
+
+        # DataLoader's collate_fn receives a List whose elements are whatever dataset[index] returns
         if collate_fn is None:
             # PyTorch <= 1.1 does not allow collate_fn=None
             self.dataiter = torch.utils.data.DataLoader(
@@ -160,17 +142,25 @@ class BatchIter:
                 timeout=timeout, worker_init_fn=worker_init_fn)

         # Go by the sampler's length: with DistributedSampler, each process sees only part of the data
-        self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
+        self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
         self.batch_size = batch_size
         self.cur_batch_indices = None

+    @property
+    def num_batches(self):
+        return self._num_batches
+
+    @num_batches.setter
+    def num_batches(self, value):
+        self._num_batches = value
+
     def init_iter(self):
         pass

     @staticmethod
     def get_num_batches(num_samples, batch_size, drop_last):
         """
-        Compute the number of batches.
+        Compute the number of batches; used by the front end to display progress.

         :param int num_samples:
         :param int batch_size:
@@ -184,7 +174,7 @@ class BatchIter:
     def get_batch_indices(self):
         """
-        Return the indices of the batch that has currently been produced.
+        Return the indices of the most recently produced batch; useful for tracing the current batch back to its source samples.

         :return:
         """
@@ -195,8 +185,22 @@ class BatchIter:
     @property
     def dataset(self):
+        """
+        Return the dataset currently being iterated over.
+
+        :return:
+        """
         return self.dataiter.dataset

+    @abc.abstractmethod
+    def __iter__(self):
+        """
+        The actual iteration loop; each yielded value must be two dicts, the first
+        treated as input and the second as target.
+
+        :return:
+        """
+        raise NotImplementedError
+

 class DataSetIter(BatchIter):
     """
@@ -4,7 +4,8 @@ from builtins import sorted
 import torch
 import numpy as np
 from .field import _get_ele_type_and_dim
-from collections import defaultdict
+from .utils import logger
+from copy import deepcopy


 def _check_type(batch_dict, fields):
@@ -36,127 +37,89 @@ def batching(samples, max_len=0, padding_val=0):

 class Collector:
+    """
+    Helper class through which a DataSet manages its collect_fn callables.
+    """
     def __init__(self):
-        self.fns = {}
-        self.input2fn = defaultdict(list)
-        self.output2fn = defaultdict(list)
-        self.fn2input = {}
-        self.fn2output = {}
-
-    def add_fn(self, fn, inputs, outputs, is_input, is_target):
-        for name in outputs:
-            if name in self.output2fn:
-                raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
-
-        if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs():
-            raise ValueError(
-                "Incorrect num of inputs, should be {} not {}".format(
-                    fn.num_inputs(), len(inputs)
-                ))
-
-        if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs():
-            raise ValueError("Incorrect num of inputs, should be {} not {}".format(
-                fn.num_outputs(), len(outputs)))
-
-        self.fns[fn] = {'is_input': is_input, 'is_target': is_target}
-        for i, field in enumerate(inputs):
-            self.input2fn[field].append((fn, i))
-        for i, name in enumerate(outputs):
-            self.output2fn[name].append((fn, i))
-
-    def _rebuild_fn2io(self):
-        def transpose(name2fn):
-            fn2names = defaultdict(list)
-            for name, vlist in name2fn.items():
-                for fn, i in vlist:
-                    fn2names[fn].append((name, i))
-            for fn, vlist in fn2names.items():
-                vlist = sorted(vlist, key=lambda x: x[1])
-                fn2names[fn] = [name for name, i in vlist]
-            return fn2names
-
-        self.fn2input = transpose(self.input2fn)
-        self.fn2output = transpose(self.output2fn)
-
-    def _clear_fn2io(self):
-        self.fn2input.clear()
-        self.fn2output.clear()
+        self.collect_fns = {}
+
+    def add_fn(self, fn, name=None):
+        """
+        Register a new collect_fn with the collector.
+
+        :param callable fn:
+        :param str,int name:
+        :return:
+        """
+        if name in self.collect_fns:
+            logger.warn(f"collect_fn:{name} will be overwritten.")
+        if name is None:
+            name = len(self.collect_fns)
+        self.collect_fns[name] = fn
+
+    def is_empty(self):
+        """
+        Return whether no collect_fn has been registered.
+
+        :return:
+        """
+        return len(self.collect_fns) == 0
+
+    def delete_fn(self, name=None):
+        """
+        Remove a collect_fn.
+
+        :param str,int name: if None, the most recently added collect_fn is removed
+        :return:
+        """
+        if not self.is_empty():
+            if name in self.collect_fns:
+                self.collect_fns.pop(name)
+            elif name is None:
+                last_key = list(self.collect_fns.keys())[-1]
+                self.collect_fns.pop(last_key)
     def collect_batch(self, ins_list):
-        if len(ins_list) == 0:
-            return {}, {}
-
-        if len(self.fn2output) == 0:
-            self._rebuild_fn2io()
-
-        bx = {}
-        by = {}
-        for fn, attr in self.fns.items():
-            inputs = self.fn2input.get(fn, None)
-            outputs = self.fn2output.get(fn, None)
-            res = fn.collect(ins_list, inputs, outputs)
-            if attr.get('is_input', False):
-                bx.update(res)
-            if attr.get('is_target', False):
-                by.update(res)
+        bx, by = {}, {}
+        for name, fn in self.collect_fns.items():
+            try:
+                batch_x, batch_y = fn(ins_list)
+            except BaseException as e:
+                logger.error(f"Exception `{e}` raised when calling collect_fn `{name}`.")
+                raise e
+            bx.update(batch_x)
+            by.update(batch_y)
         return bx, by
-    def rename_field(self, old_f, new_f):
-        if new_f in self.input2fn:
-            # name conflict
-            raise ValueError
-        if old_f not in self.input2fn:
-            # renamed field not affect collectors
-            return
-        self.input2fn[new_f] = self.input2fn[old_f]
-        self._clear_fn2io()
-
-    def drop_field(self, f):
-        if f in self.input2fn:
-            raise ValueError
-
-    def outputs(self):
-        return self.output2fn.keys()
-
     def copy_from(self, col):
         assert isinstance(col, Collector)
-        self.fns = col.fns.copy()
-        self.input2fn = col.input2fn.copy()
-        self.output2fn = col.output2fn.copy()
-        self._clear_fn2io()
-
-
-class CollectFn:
-    def __init__(self):
-        self.fields = []
-
-    def collect(self, ins_list, inputs, outputs):
-        raise NotImplementedError
-
-    def num_inputs(self):
-        return 0
-
-    def num_outputs(self):
-        return 0
-
-    @staticmethod
-    def get_batch_size(batch_dict):
-        if len(batch_dict) == 0:
-            return 0
-        return len(next(iter(batch_dict.values())))
+        new_col = Collector()
+        new_col.collect_fns = deepcopy(col.collect_fns)
+        return new_col

-class ConcatCollectFn(CollectFn):
+class ConcatCollectFn:
     """
-    A Fn that concatenates several fields in order and pads the result. All fields must share the same dim.
+    A collect_fn that concatenates several fields in order and pads the result.

+    :param List[str] inputs: the fields to concatenate; currently only 1d fields are supported
+    :param str output: the name of the concatenated field
     :param pad_val: the value used for padding
     :param max_len: the maximum length after concatenation
+    :param is_input: whether the generated output is set as input
+    :param is_target: whether the generated output is set as target
     """

-    def __init__(self, pad_val=0, max_len=0):
+    def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False):
         super().__init__()
+        assert isinstance(inputs, list)
+        self.inputs = inputs
+        self.output = output
         self.pad_val = pad_val
         self.max_len = max_len
+        self.is_input = is_input
+        self.is_target = is_target

     @staticmethod
     def _to_numpy(seq):
@@ -165,21 +128,18 @@ class ConcatCollectFn(CollectFn):
         else:
             return np.array(seq)

-    def collect(self, ins_list, inputs, outputs):
+    def __call__(self, ins_list):
         samples = []
         for i, ins in ins_list:
             sample = []
-            for i in inputs:
-                sample.append(self._to_numpy(ins[i]))
+            for input_name in self.inputs:
+                sample.append(self._to_numpy(ins[input_name]))
             samples.append(np.concatenate(sample, axis=0))
-        seq_len = [s.shape[0] for s in samples]
         batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
-        o1, o2 = outputs
-        return {o1: batch, o2: seq_len}
-
-    def num_inputs(self):
-        return 0
-
-    def num_outputs(self):
-        # (concat_words, seq_len)
-        return 2
+        b_x, b_y = {}, {}
+        if self.is_input:
+            b_x[self.output] = batch
+        if self.is_target:
+            b_y[self.output] = batch
+
+        return b_x, b_y
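After the refactor, the concatenation sources and target are fixed at construction time and the instance itself is the callable. Typical usage, mirroring the updated test further down (``ds``, ``words1`` and ``words2`` are illustrative names):

.. code-block:: python

    from fastNLP import ConcatCollectFn

    # concatenate two 1d fields into one padded 'words' matrix in batch_x
    fn = ConcatCollectFn(inputs=['words1', 'words2'], output='words',
                         pad_val=0, max_len=0, is_input=True, is_target=False)
    ds.add_collect_fn(fn, name='concat')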
@@ -281,6 +281,75 @@
     # the pad value can also be set
     dataset.set_pad_val('chars', -1)

+3.3 Building a new field out of several fields of a DataSet
+-----------------------------------------------------------
+
+    When a DataSet is batched, by default each field is only visible on its own, but some training
+    setups need more, for example: (1) two fields have to be concatenated into one; (2) negative
+    sampling has to happen inside the batch. Such cases require several fields at once while the
+    batch is built; DataSet.add_collect_fn() supports registering a custom collect_fn that involves
+    multiple fields. The example below concatenates two fields into one:
+
+    .. code-block::
+
+        from fastNLP import DataSet, DataSetIter, SequentialSampler
+        import torch
+
+        data = DataSet({
+            'x1': [[0, 1],
+                   [2]],
+            'x2': [[3],
+                   [2, 4, 5]],
+            'y': [0, 1]
+        })
+        data.set_target('y')
+
+        # Every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where
+        # ind1/ind2 are the indices of the instances in the dataset and instance1/instance2 are the
+        # samples drawn for this batch, containing all fields.
+        def concat_collect_fn(ins_list):
+            x1 = [ins['x1'] for ind, ins in ins_list]
+            x2 = [ins['x2'] for ind, ins in ins_list]
+            xs = []
+            for i in range(len(ins_list)):
+                xs.append(torch.LongTensor(x1[i] + x2[i]))
+            # pad and convert to tensor yourself, but there is no need to move anything to the gpu
+            arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
+            b_x = {'x': arr}
+            b_y = {}
+            # The return value must be two dicts: values of the first dict are treated as input and
+            # values of the second as target. If a name clashes with an existing input or target,
+            # the returned value takes precedence.
+            return b_x, b_y
+
+        data.add_collect_fn(concat_collect_fn)
+
+        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+            print("batch_x:", batch_x)
+            print("batch_y:", batch_y)
+            # batch_x: {'x': tensor([[0, 1, 3, 0],
+            #                        [2, 2, 4, 5]])}
+            # batch_y: {'y': array([0, 1])}
+
+        # If batching needs extra parameters, implement the collect_fn as a class
+        class ConCollectFn:
+            def __init__(self, max_len=3):
+                self.max_len = max_len
+
+            def __call__(self, ins_list):  # implement the class's __call__
+                x1 = [ins['x1'] for ind, ins in ins_list]
+                x2 = [ins['x2'] for ind, ins in ins_list]
+                xs = []
+                for i in range(len(ins_list)):
+                    xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
+                arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
+                b_x = {'x': arr}
+                b_y = {}
+                return b_x, b_y
+
+        data.delete_collect_fn()  # remove the previous collect_fn
+        data.add_collect_fn(ConCollectFn(max_len=3))
+        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+            print("batch_x:", batch_x)
+            print("batch_y:", batch_y)
+            # batch_x: {'x': tensor([[0, 1, 3],
+            #                        [2, 2, 4]])}
+            # batch_y: {'y': array([0, 1])}

 """

 __all__ = [
@@ -300,7 +369,6 @@ from .field import AutoPadder
 from .field import FieldArray
 from .field import SetInputOrTargetException
 from .instance import Instance
-from .utils import _get_func_signature
 from .utils import pretty_table_printer
 from .collect_fn import Collector
@@ -394,6 +462,7 @@ class DataSet(object):
             for field in self.field_arrays.values():
                 data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder,
                                    is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
+            data_set.collector = self.collector.copy_from(self.collector)
             return data_set
         elif isinstance(idx, str):
             if idx not in self:
@@ -407,6 +476,7 @@ class DataSet(object):
                 dataset.append(instance)
             for field_name, field in self.field_arrays.items():
                 dataset.field_arrays[field_name].to(field)
+            dataset.collector = self.collector.copy_from(self.collector)
             return dataset
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -575,7 +645,6 @@ class DataSet(object):
         :param str field_name: the name of the field to delete.
         """
         self.field_arrays.pop(field_name)
-        self.collector.drop_field(field_name)
         return self

     def copy_field(self, field_name, new_field_name):
@@ -648,7 +717,6 @@ class DataSet(object):
         if field_name in self.field_arrays:
             self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
             self.field_arrays[new_field_name].name = new_field_name
-            self.collector.rename_field(field_name, new_field_name)
         else:
             raise KeyError("DataSet has no field named {}.".format(field_name))
         return self
@@ -1040,30 +1108,30 @@ class DataSet(object):
         assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d

-    def add_collect_fn(self, fn, inputs, outputs, is_input, is_target):
+    def add_collect_fn(self, fn, name=None):
         """
-        Add a CollectFn that uses several fields to produce data in the batch
+        Add a collect_fn. A collect_fn can generate extra data on the fly while a batch is being
+        built (effective when DataSetIter is used as the iterator, which is the default). Several
+        collect_fns may be added in turn; for identical keys the result of a later collect_fn
+        overrides that of an earlier one.

-        :param CollectFn fn: defines how the data is produced
-        :param list inputs: the names of the generated data in the batch
-        :param list outputs: the fields used to produce the data, ordered
-        :param bool is_input: whether the result appears in the input batch; otherwise it appears in the target batch
-        :param bool is_target:
+        :param callable fn: a callable accepting List[(ind1, instance1), (ind2, instance2)] (all
+            indices and instances selected for a batch), where ind1/ind2 are the indices of the
+            instances in the dataset and instance1/instance2 are the samples drawn for this batch,
+            containing all fields. The return value must be two dicts; values of the first dict are
+            treated as input, values of the second as target, and at most one of the two dicts may
+            be empty. If a returned dict contains the name of a field that is set as input or
+            target, it overrides that field in the dataset. fastNLP does not pad or tensorize the
+            results of a collect_fn: padding and conversion to tensor must happen inside the
+            collect_fn (there is no need to move tensors to the gpu; pytorch tensors are moved to
+            the proper gpu by fastNLP automatically). Do not modify the data passed into the
+            collect_fn, or unexpected problems may follow.
+        :param str,int name: the name of the collect_fn. If not given, an auto-incremented integer
+            is used as the key. An identical name overrides the earlier collect_fn.
         """
-        def check_fields(fields):
-            for f in fields:
-                if f not in self.field_arrays:
-                    raise ValueError(f)
-
-        def check_name(names):
-            for name in names:
-                if name in self.field_arrays:
-                    logger.warning('name of collect_fn will cover the field name in dataset')
-
-        check_fields(inputs)
-        check_name(outputs)
-
-        self.collector.add_fn(fn, inputs, outputs, is_input, is_target)
+        assert callable(fn), "You must pass in a callable object."
+        self.collector.add_fn(fn, name=name)
+
+    def delete_collect_fn(self, name=None):
+        """
+        Remove a collect_fn.
+
+        :param str,int name: if None, the most recently added collect_fn is removed
+        :return:
+        """
+        self.collector.delete_fn(name)

     def _collect_batch(self, ins_list):
         return self.collector.collect_batch(ins_list)
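A minimal end-to-end sketch of the keyed registration introduced above (the key 'count' and the field names are illustrative):

.. code-block:: python

    import numpy as np
    from fastNLP import DataSet

    ds = DataSet({'a': [[1, 2], [3]], 'b': [0, 1]})

    def count_fn(ins_list):
        # first dict is merged into batch_x, second into batch_y
        return {'n_ins': np.array(len(ins_list))}, {}

    ds.add_collect_fn(count_fn, name='count')  # stored under the key 'count'
    assert not ds.collector.is_empty()
    ds.delete_collect_fn('count')              # delete_collect_fn() would drop the newest one
    assert ds.collector.is_empty()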
@@ -191,24 +191,31 @@ class FieldArray:
     def get(self, indices, pad=True):
         """
-        Return the contents at the given indices
+        Return the contents at the given indices.

         :param int,List[int] indices: the indices whose contents to fetch.
-        :param bool pad: whether to pad the result; only effective when indices is a List[int]
-        :return: the contents at the given indices, either a single value or a List
+        :param bool pad: whether to pad the result; only effective when (1) indices is a
+            List[int]; (2) the padder is not None; and (3) the field is set as input or target.
+        :return: the contents at the given indices, either a single value or an ndarray
         """
         if isinstance(indices, int):
             return self.content[indices]
-        if self.is_input is False and self.is_target is False:
-            raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name))

         contents = [self.content[i] for i in indices]
         if self.padder is None or pad is False:
             return np.array(contents)
-        else:
+        elif self.is_input or self.is_target:
             return self.pad(contents)
+        else:
+            return np.array(contents)

     def pad(self, contents):
+        """
+        Pad a list of contents with this field's padder; contents must have been taken
+        from this FieldArray.
+
+        :param list contents:
+        :return:
+        """
         return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)

     def set_padder(self, padder):
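The new dispatch in get() reduces to three cases; a sketch, assuming a FieldArray ``fa`` whose padder is set:

.. code-block:: python

    fa.get(0)                  # int index: the raw cell, never padded
    fa.get([0, 1])             # list index, field is input/target: padded ndarray
    fa.get([0, 1], pad=False)  # padding suppressed: np.array of the raw contents

A field that is neither input nor target no longer raises a RuntimeError; it now falls through to the unpadded np.array branch.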
@@ -71,7 +71,7 @@ class Tester(object):
     def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
         """
-        :param ~fastNLP.DataSet data: the dataset to test on
+        :param ~fastNLP.DataSet,~fastNLP.BatchIter data: the dataset to test on
         :param torch.nn.Module model: the model to use
         :param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: the metrics used during testing
         :param int batch_size: the batch size used during evaluation.
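Since ``data`` may now be a ``BatchIter``, a pre-built iterator can be handed to the Tester directly (a sketch; ``dataset``, ``model`` and ``metric`` are assumed to exist):

.. code-block:: python

    from fastNLP import DataSetIter, Tester

    tester = Tester(data=DataSetIter(dataset, batch_size=16), model=model, metrics=metric)
    tester.test()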
@@ -375,7 +375,7 @@ class Trainer(object):
                  callbacks=None, check_code_level=0, **kwargs):
         """
-        :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
+        :param train_data: the training set, a :class:`~fastNLP.DataSet` or a subclass of :class:`~fastNLP.BatchIter`
         :param nn.modules model: the model to train
         :param optimizer: a `torch.optim.Optimizer`. If None, Trainer uses the default optimizer Adam(model.parameters(), lr=4e-3)
         :param int batch_size: the batch size for training and validation.
@@ -624,9 +624,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
             if check_res.unused:
                 _tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}."
                 if _tmp:
-                    _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
+                    _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. '
                 else:
-                    _tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.'
+                    _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.'
+                if not dataset.collector.is_empty():
+                    _tmp += f'Or you need to add `{_miss}` to the output of your collect_fn. '
                 suggestions.append(_tmp)

         if check_res.duplicated:
@@ -683,12 +685,11 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
             else:
                 _miss_out_dataset.append(_miss)
         if _miss_in_dataset:
-            suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
+            suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ")
         if _miss_out_dataset:
-            _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
-            # if check_res.unused:
-            #     _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
-            #         f"rename the field in `unused field:`."
+            _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. "
+            if not dataset.collector.is_empty():
+                _tmp += f'Or you need to add `{_miss_out_dataset}` to the output of your collect_fn. '
             suggestions.append(_tmp)

     if check_res.unused:
@@ -158,20 +158,130 @@ class TestCase1(unittest.TestCase):
         dataset.set_input('1','2')
         dataset.set_target('0','3')

-        fn = ConcatCollectFn()
-        dataset.add_collect_fn(fn, inputs=['1', '2'],
-                               outputs=['12', 'seq_len'],
-                               is_input=True, is_target=False)
+        fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
+        dataset.add_collect_fn(fn, name='demo')
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
         for batch_x, batch_y in batch:
             for i in range(batch_size):
                 # print(i)
                 self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
-                self.assertEqual(
-                    batch_x['seq_len'][i],
-                    (batch_x['1'][i]!=0).sum() + (batch_x['2'][i]!=0).sum())
+        dataset.delete_collect_fn(name='demo')
+
+        # test the case where the source fields are not set as input
+        dataset.set_input('1', '2', flag=False)
+        fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
+        dataset.add_collect_fn(fn, name='demo')
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                self.assertTrue('12' in batch_x)
+        dataset.delete_collect_fn(name='demo')
+        dataset.set_input('1', '2', flag=True)
+
+        # test overriding another field
+        fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
+        dataset.add_collect_fn(fn, name='demo')
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                # print(i)
+                self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
+        dataset.delete_collect_fn(name='demo')
+
+        # test the case where the source fields are neither input nor target
+        dataset.set_input('1', '2', flag=False)
+        fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
+        dataset.add_collect_fn(fn, name='demo')
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                # print(i)
+                self.assertTrue('3' in batch_x)
+                self.assertTrue('3' in batch_y)
+        dataset.delete_collect_fn(name='demo')
+
+        # test adding an invalid fn
+        with self.assertRaises(AssertionError):
+            dataset.add_collect_fn(1)
+
+        # test a collect_fn whose return value is only a single dict
+        def demo_collect_fn(ins_list):
+            return {'3': 1}
+        dataset.add_collect_fn(demo_collect_fn, name='demo')
+        with self.assertRaises(BaseException):
+            batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+            for batch_x, batch_y in batch:
+                pass
+        dataset.delete_collect_fn(name='demo')
+
+        # test multiple collect_fns
+        dataset.add_collect_fn(demo_collect_fn, name='demo')
+        dataset.add_collect_fn(demo_collect_fn, name='demo')
+        # test deletion
+        dataset.delete_collect_fn()
+        dataset.delete_collect_fn()
+        self.assertTrue(dataset.collector.is_empty())
+    def test_demo(self):
+        import torch
+
+        data = DataSet({
+            'x1': [[0, 1],
+                   [2]],
+            'x2': [[3],
+                   [2, 4, 5]],
+            'y': [0, 1]
+        })
+        data.set_target('y')
+
+        # Every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where
+        # ind1/ind2 are the indices of the instances in the dataset and instance1/instance2 are the
+        # samples drawn for this batch, containing all fields.
+        def concat_collect_fn(ins_list):
+            x1 = [ins['x1'] for ind, ins in ins_list]
+            x2 = [ins['x2'] for ind, ins in ins_list]
+            xs = []
+            for i in range(len(ins_list)):
+                xs.append(torch.LongTensor(x1[i] + x2[i]))
+            # pad and convert to tensor yourself, but there is no need to move anything to the gpu
+            arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
+            b_x = {'x': arr}
+            b_y = {}
+            # The return value must be two dicts: values of the first dict are treated as input and
+            # values of the second as target. If a name clashes with an existing input or target,
+            # the returned value takes precedence.
+            return b_x, b_y
+
+        data.add_collect_fn(concat_collect_fn)
+
+        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+            print("batch_x:", batch_x)
+            print("batch_y:", batch_y)
+            # batch_x: {'x': tensor([[0, 1, 3, 0],
+            #                        [2, 2, 4, 5]])}
+            # batch_y: {'y': array([0, 1])}
+
+        # If batching needs extra parameters, implement the collect_fn as a class
+        class ConCollectFn:
+            def __init__(self, max_len=3):
+                self.max_len = max_len
+
+            def __call__(self, ins_list):
+                x1 = [ins['x1'] for ind, ins in ins_list]
+                x2 = [ins['x2'] for ind, ins in ins_list]
+                xs = []
+                for i in range(len(ins_list)):
+                    xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
+                arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
+                b_x = {'x': arr}
+                b_y = {}
+                return b_x, b_y
+
+        data.delete_collect_fn()  # remove the previous collect_fn
+        data.add_collect_fn(ConCollectFn(max_len=3))
+        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+            print("batch_x:", batch_x)
+            print("batch_y:", batch_y)
+            # batch_x: {'x': tensor([[0, 1, 3],
+            #                        [2, 2, 4]])}
+            # batch_y: {'y': array([0, 1])}
+
     def testTensorLoaderIter(self):
         class FakeData: