From 7f019713214417435b1ff616b4257116953159c6 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 19 Mar 2020 18:04:13 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=B6=89=E5=8F=8A=E5=88=B0?= =?UTF-8?q?=E5=A4=9A=E4=B8=AAfield=E5=8F=96batch=E7=9A=84=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/__init__.py | 3 +- fastNLP/core/batch.py | 86 +++++++++-------- fastNLP/core/collect_fn.py | 192 +++++++++++++++---------------------- fastNLP/core/dataset.py | 112 +++++++++++++++++----- fastNLP/core/field.py | 21 ++-- fastNLP/core/tester.py | 2 +- fastNLP/core/trainer.py | 2 +- fastNLP/core/utils.py | 15 +-- test/core/test_batch.py | 126 ++++++++++++++++++++++-- 9 files changed, 354 insertions(+), 205 deletions(-) diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index bda9c11e..c09d71d0 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -21,7 +21,6 @@ __all__ = [ "AutoPadder", "EngChar2DPadder", - "CollectFn", "ConcatCollectFn", "Vocabulary", @@ -97,4 +96,4 @@ from .tester import Tester from .trainer import Trainer from .utils import cache_results, seq_len_to_mask, get_seq_len from .vocabulary import Vocabulary -from .collect_fn import CollectFn, ConcatCollectFn +from .collect_fn import ConcatCollectFn diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 7090ea01..cbc9429d 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -9,17 +9,16 @@ __all__ = [ ] import atexit -from numbers import Number +import abc +from numbers import Number import numpy as np import torch import torch.utils.data from collections import defaultdict -from ._logger import logger from .dataset import DataSet from .sampler import SequentialSampler -from .field import _get_ele_type_and_dim _python_is_exit = False @@ -33,6 +32,9 @@ atexit.register(_set_python_is_exit) class DataSetGetter: + """ + 传递给torch.utils.data.DataLoader获取数据,DataLoder会传入int的idx获取数据(调用这里的__getitem__()函数)。 + """ def __init__(self, dataset: DataSet, as_numpy=False): self.dataset = dataset self.as_numpy = as_numpy @@ -56,7 +58,6 @@ class DataSetGetter: :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] :return: """ - # TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 indices = [] sin_x, sin_y = defaultdict(list), defaultdict(list) for idx, ins in ins_list: @@ -67,24 +68,6 @@ class DataSetGetter: if n in self.y_names: sin_y[n].append(v) - def may_to_tensor(data): - dtype, dim = _get_ele_type_and_dim(data) - # print(dtype, type(dtype), str(dtype)) - if not self.as_numpy: - try: - data, flag = _to_tensor(data, dtype) - except TypeError as e: - logger.error(f"Field {n} cannot be converted to torch.tensor.") - raise e - # if torch.is_tensor(data): - # str_dtype = str(dtype) - # if 'float' in str_dtype: - # data = data.float() - # elif 'int' in str_dtype: - # data = data.long() - # print(data.dtype) - return data - def pad(batch_dict): result = {} for n, vlist in batch_dict.items(): @@ -98,25 +81,13 @@ class DataSetGetter: sin_x = pad(sin_x) sin_y = pad(sin_y) - bx, by = self.dataset._collect_batch(ins_list) - def convert_tensor(batch_dict): - for n, v in batch_dict.items(): - batch_dict[n] = may_to_tensor(v) - - # collect_fn replaces single field - sin_x.update(bx) - sin_y.update(by) - - convert_tensor(sin_x) - convert_tensor(sin_y) + if not self.dataset.collector.is_empty(): + bx, by = self.dataset._collect_batch(ins_list) + sin_x.update(bx) + sin_y.update(by) return (indices, sin_x, sin_y) - def set_idx_list(self, idx_list): - if len(idx_list) != len(self.idx_list): - raise ValueError - self.idx_list = idx_list - def __getattr__(self, item): if hasattr(self.dataset, item): return getattr(self.dataset, item) @@ -125,6 +96,10 @@ class DataSetGetter: class SamplerAdapter(torch.utils.data.Sampler): + """ + 用于传入torch.utils.data.DataLoader中,DataLoader会调用__iter__()方法获取index(一次只取一个int) + + """ def __init__(self, sampler, dataset): super().__init__(dataset) self.sampler = sampler @@ -138,6 +113,11 @@ class SamplerAdapter(torch.utils.data.Sampler): class BatchIter: + """ + Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), dataset(), num_batches(), + __iter__()方法。 + + """ def __init__(self, dataset, batch_size=1, sampler=None, num_workers=0, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, collate_fn=None): @@ -145,6 +125,8 @@ class BatchIter: self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) else: self.sampler = sampler + + # DataLoader的collect_fn输入是List[],里面的元素是dataset[index]返回的结果 if collate_fn is None: # pytoch <= 1.1 中不能设置collate_fn=None self.dataiter = torch.utils.data.DataLoader( @@ -160,17 +142,25 @@ class BatchIter: timeout=timeout, worker_init_fn=worker_init_fn) # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 - self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) + self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) self.batch_size = batch_size self.cur_batch_indices = None + @property + def num_batches(self): + return self._num_batches + + @num_batches.setter + def num_batches(self, value): + self._num_batches = value + def init_iter(self): pass @staticmethod def get_num_batches(num_samples, batch_size, drop_last): """ - 计算batch的数量。 + 计算batch的数量。用于前端显示进度 :param int num_samples: :param int batch_size: @@ -184,7 +174,7 @@ class BatchIter: def get_batch_indices(self): """ - 获取当前已经输出的batch的index。 + 获取最近输出的batch的index。用于溯源当前batch的数据 :return: """ @@ -195,8 +185,22 @@ class BatchIter: @property def dataset(self): + """ + 获取正在参与iterate的dataset + + :return: + """ return self.dataiter.dataset + @abc.abstractmethod + def __iter__(self): + """ + 用于实际数据循环的类,返回值需要为两个dict, 第一个dict中的内容会认为是input, 第二个dict中的内容会认为是target + + :return: + """ + raise NotImplemented + class DataSetIter(BatchIter): """ diff --git a/fastNLP/core/collect_fn.py b/fastNLP/core/collect_fn.py index 29f19e2c..d80db154 100644 --- a/fastNLP/core/collect_fn.py +++ b/fastNLP/core/collect_fn.py @@ -4,7 +4,8 @@ from builtins import sorted import torch import numpy as np from .field import _get_ele_type_and_dim -from collections import defaultdict +from .utils import logger +from copy import deepcopy def _check_type(batch_dict, fields): @@ -36,127 +37,89 @@ def batching(samples, max_len=0, padding_val=0): class Collector: + """ + 辅助DataSet管理collect_fn的类 + + """ def __init__(self): - self.fns = {} - self.input2fn = defaultdict(list) - self.output2fn = defaultdict(list) - self.fn2input = {} - self.fn2output = {} - - def add_fn(self, fn, inputs, outputs, is_input, is_target): - for name in outputs: - if name in self.output2fn: - raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn)) - - if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs(): - raise ValueError( - "Incorrect num of inputs, should be {} not {}".format( - fn.num_inputs(), len(inputs) - )) - - if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs(): - raise ValueError("Incorrect num of inputs, should be {} not {}".format( - fn.num_outputs(), len(outputs))) - - self.fns[fn] = {'is_input': is_input, 'is_target': is_target} - for i, field in enumerate(inputs): - self.input2fn[field].append((fn, i)) - for i, name in enumerate(outputs): - self.output2fn[name].append((fn, i)) - - def _rebuild_fn2io(self): - def transpose(name2fn): - fn2names = defaultdict(list) - for name, vlist in name2fn.items(): - for fn, i in vlist: - fn2names[fn].append((name, i)) - for fn, vlist in fn2names.items(): - vlist = sorted(vlist, key=lambda x: x[1]) - fn2names[fn] = [name for name, i in vlist] - return fn2names - - self.fn2input = transpose(self.input2fn) - self.fn2output = transpose(self.output2fn) - - def _clear_fn2io(self): - self.fn2input.clear() - self.fn2output.clear() + self.collect_fns = {} + + def add_fn(self, fn, name=None): + """ + 向collector新增一个collect_fn函数 + + :param callable fn: + :param str,int name: + :return: + """ + if name in self.collect_fns: + logger.warn(f"collect_fn:{name} will be overwritten.") + if name is None: + name = len(self.collect_fns) + self.collect_fns[name] = fn + + def is_empty(self): + """ + 返回是否包含collect_fn + + :return: + """ + return len(self.collect_fns)==0 + + def delete_fn(self, name=None): + """ + 删除collect_fn + + :param str,int name: 如果为None就删除最近加入的collect_fn + :return: + """ + if not self.is_empty(): + if name in self.collect_fns: + self.collect_fns.pop(name) + elif name is None: + last_key = list(self.collect_fns.keys())[0] + self.collect_fns.pop(last_key) def collect_batch(self, ins_list): - if len(ins_list) == 0: - return {}, {} - - if len(self.fn2output) == 0: - self._rebuild_fn2io() - - bx = {} - by = {} - for fn, attr in self.fns.items(): - inputs = self.fn2input.get(fn, None) - outputs = self.fn2output.get(fn, None) - res = fn.collect(ins_list, inputs, outputs) - if attr.get('is_input', False): - bx.update(res) - if attr.get('is_target', False): - by.update(res) + bx, by = {}, {} + for name, fn in self.collect_fns.items(): + try: + batch_x, batch_y = fn(ins_list) + except BaseException as e: + logger.error(f"Exception:`{e}` happens when call collect_fn:`{name}`.") + raise e + bx.update(batch_x) + by.update(batch_y) return bx, by - def rename_field(self, old_f, new_f): - if new_f in self.input2fn: - # name conflict - raise ValueError - if old_f not in self.input2fn: - # renamed field not affect collectors - return - self.input2fn[new_f] = self.input2fn[old_f] - self._clear_fn2io() - - def drop_field(self, f): - if f in self.input2fn: - raise ValueError - - def outputs(self): - return self.output2fn.keys() - def copy_from(self, col): assert isinstance(col, Collector) - self.fns = col.fns.copy() - self.input2fn = col.input2fn.copy() - self.output2fn = col.output2fn.copy() - self._clear_fn2io() - -class CollectFn: - def __init__(self): - self.fields = [] - - def collect(self, ins_list, inputs, outputs): - raise NotImplementedError + new_col = Collector() + new_col.collect_fns = deepcopy(col) + return new_col - def num_inputs(self): - return 0 - def num_outputs(self): - return 0 - - @staticmethod - def get_batch_size(batch_dict): - if len(batch_dict) == 0: - return 0 - return len(next(iter(batch_dict.values()))) - - -class ConcatCollectFn(CollectFn): +class ConcatCollectFn: """ - field拼接Fn,将不同field按序拼接后,padding产生数据。所有field必须有相同的dim。 + field拼接collect_fn,将不同field按序拼接后,padding产生数据。 + :param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field + :param str output: 拼接后的field名称 :param pad_val: padding的数值 :param max_len: 拼接后最大长度 + :param is_input: 是否将生成的output设置为input + :param is_target: 是否将生成的output设置为target """ - def __init__(self, pad_val=0, max_len=0): + def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False): super().__init__() + assert isinstance(inputs, list) + self.inputs = inputs + self.output = output self.pad_val = pad_val self.max_len = max_len + self.is_input = is_input + self.is_target = is_target @staticmethod def _to_numpy(seq): @@ -165,21 +128,18 @@ class ConcatCollectFn(CollectFn): else: return np.array(seq) - def collect(self, ins_list, inputs, outputs): + def __call__(self, ins_list): samples = [] for i, ins in ins_list: sample = [] - for i in inputs: - sample.append(self._to_numpy(ins[i])) + for input_name in self.inputs: + sample.append(self._to_numpy(ins[input_name])) samples.append(np.concatenate(sample, axis=0)) - seq_len = [s.shape[0] for s in samples] batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val) - o1, o2 = outputs - return {o1: batch, o2: seq_len} - - def num_inputs(self): - return 0 + b_x, b_y = {}, {} + if self.is_input: + b_x[self.output] = batch + if self.is_target: + b_y[self.output] = batch - def num_outputs(self): - # (concat_words, seq_len) - return 2 + return b_x, b_y diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 74c0023c..8547f30c 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -281,6 +281,75 @@ # 也可以设置pad的value dataset.set_pad_val('chars', -1) +3.3 根据DataSet中多个field合成新的field +-------------------------------------- + + DataSet支持在进行batch时,默认只能看到当前的field的值,但在某些训练中可能存在以下的情况: (1)需要两个field拼接成为一个field; + (2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collect_fn()函数支持添加 + 自定义涉及多个field的collect_fn函数。例如下例中将两个field拼接成一个field的场景 + + .. code-block:: + + from fastNLP import DataSet, DataSetIter + import torch + + data = DataSet({ + 'x1': [[0, 1], + [2]], + 'x2': [[3], + [2, 4, 5]], + 'y': [0, 1] + }) + data.set_target('y') + + # 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 + # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. + def concat_collect_fn(ins_list): + x1 = [ins['x1'] for ind,ins in ins_list] + x2 = [ins['x2'] for ind,ins in ins_list] + xs = [] + for i in range(len(ins_list)): + xs.append(torch.LongTensor(x1[i] + x2[i])) + # 需要自行pad并转换为tensor,但不需要移动到gpu + arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0) + b_x = {'x': arr} + b_y = {} + # 返回值一定是两个dict,第一个dict的值会认为是input,第二个dict的值会认为是target. 若名称与已有input或target重复,则 + # 采用返回值。 + return b_x, b_y + + data.add_collect_fn(concat_collect_fn) + + for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): + print("batch_x:", batch_x) + print("batch_y:", batch_y) + # batch_x: {'x': tensor([[0, 1, 3, 0], + # [2, 2, 4, 5]])} + # batch_y: {'y': array([0, 1])} + + # 如果取batch过程含有一些参数,可以通过类来实现 + class ConCollectFn: + def __init__(self, max_len=3): + self.max_len = max_len + + def __call__(self, ins_list): # 实现该类的__call__函数 + x1 = [ins['x1'] for ind, ins in ins_list] + x2 = [ins['x2'] for ind, ins in ins_list] + xs = [] + for i in range(len(ins_list)): + xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len]) + arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0) + b_x = {'x': arr} + b_y = {} + return b_x, b_y + data.delete_collect_fn() # 删除之前的collect_fn + data.add_collect_fn(ConCollectFn(max_len=3)) + for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): + print("batch_x:", batch_x) + print("batch_y:", batch_y) + # batch_x: {'x': tensor([[0, 1, 3], + # [2, 2, 4]])} + # batch_y: {'y': array([0, 1])} """ __all__ = [ @@ -300,7 +369,6 @@ from .field import AutoPadder from .field import FieldArray from .field import SetInputOrTargetException from .instance import Instance -from .utils import _get_func_signature from .utils import pretty_table_printer from .collect_fn import Collector @@ -394,6 +462,7 @@ class DataSet(object): for field in self.field_arrays.values(): data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) + data_set.collector = self.collector.copy_from(self.collector) return data_set elif isinstance(idx, str): if idx not in self: @@ -407,6 +476,7 @@ class DataSet(object): dataset.append(instance) for field_name, field in self.field_arrays.items(): dataset.field_arrays[field_name].to(field) + dataset.collector = self.collector.copy_from(self.collector) return dataset else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) @@ -575,7 +645,6 @@ class DataSet(object): :param str field_name: 需要删除的field的名称. """ self.field_arrays.pop(field_name) - self.collector.drop_field(field_name) return self def copy_field(self, field_name, new_field_name): @@ -648,7 +717,6 @@ class DataSet(object): if field_name in self.field_arrays: self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) self.field_arrays[new_field_name].name = new_field_name - self.collector.rename_field(field_name, new_field_name) else: raise KeyError("DataSet has no field named {}.".format(field_name)) return self @@ -1040,30 +1108,30 @@ class DataSet(object): assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) return d - def add_collect_fn(self, fn, inputs, outputs, is_input, is_target): + def add_collect_fn(self, fn, name=None): """ - 添加 CollectFn,使用多个field产生batch中的数据 + 添加 CollectFn,collect_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 + 这个)。支持依次添加多个collect_fn, 如果相同的key,后面的collect_fn的结果覆盖前面的collect_fn的结果。 - :param CollectFn fn: 定义产生数据的方式 - :param list inputs: 生成的数据在batch中的名称 - :param list outputs: 用于产生数据的 fields,有序 - :param bool is_input: 是否出现在input中,为否则出现在target batch中 - :param bool is_target: + :param callable fn: 传入一个可调用的function, 该function可接受的参数为List[(ind1, instance1), (ind2, instance2)] + (某个batch被选中的所有的indice以及instance),其中ind1/ind2是该instance在dataset中的index,instance1/instance2是 + 这次batch取出来的数据,包含了所有的field。返回值需要为两个dict,第一个dict的值将被认为是input,第二个dict的值被认为是 + target,返回的值至多允许一个空dict。若返回的dict中包含了被设置为input或target的field的名称,将覆盖dataset中的field。 + fastNLP不会将collect_fn的返回结果pad和转换为tensor,需要在collect_fn中完成pad和转换为tensor(不需要将tensor移动到 + gpu中,如果是pytorch的tensor,fastNLP会自动将其移动到特定gpu)。不要修改传入collect_fn中的数据,否则可能导致未知问题。 + :param str,int name: collect_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collect_fn。 """ - def check_fields(fields): - for f in fields: - if f not in self.field_arrays: - raise ValueError(f) - - def check_name(names): - for name in names: - if name in self.field_arrays: - logger.warning('name of collect_fn will cover the field name in dataset') + assert callable(fn), "You must pass in a callable object." + self.collector.add_fn(fn, name=name) - check_fields(inputs) - check_name(outputs) + def delete_collect_fn(self, name=None): + """ + 删除某个collect_fn - self.collector.add_fn(fn, inputs, outputs, is_input, is_target) + :param str,int name: 如果为None,则删除最近加入的collect_fn + :return: + """ + self.collector.delete_fn(name) def _collect_batch(self, ins_list): return self.collector.collect_batch(ins_list) diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 1835bafa..5553322a 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -191,24 +191,31 @@ class FieldArray: def get(self, indices, pad=True): """ - 根据给定的indices返回内容 + 根据给定的indices返回内容。 :param int,List[int] indices: 获取indices对应的内容。 - :param bool pad: 是否对返回的结果进行padding。仅对indices为List[int]时有效 - :return: 根据给定的indices返回的内容,可能是单个值或List + :param bool pad: 是否对返回的结果进行padding。仅对: (1) indices为List[int]; (2)padder不为None; (3)field设置了input + 或target,有效 + :return: 根据给定的indices返回的内容,可能是单个值或ndarray """ if isinstance(indices, int): return self.content[indices] - if self.is_input is False and self.is_target is False: - raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) - + contents = [self.content[i] for i in indices] if self.padder is None or pad is False: return np.array(contents) - else: + elif self.is_input or self.is_target: return self.pad(contents) + else: + return np.array(contents) def pad(self, contents): + """ + 传入list的contents,将contents使用padder进行padding,contents必须为从本FieldArray中取出的。 + + :param list contents: + :return: + """ return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) def set_padder(self, padder): diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index be76a29c..78bdc478 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -71,7 +71,7 @@ class Tester(object): def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True): """ - :param ~fastNLP.DataSet data: 需要测试的数据集 + :param ~fastNLP.DataSet,~fastNLP.BatchIter data: 需要测试的数据集 :param torch.nn.Module model: 使用的模型 :param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics :param int batch_size: evaluation时使用的batch_size有多大。 diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 27e1e2b1..8f200ad8 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -375,7 +375,7 @@ class Trainer(object): callbacks=None, check_code_level=0, **kwargs): """ - :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 + :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型或 :class:`~fastNLP.BatchIter`的子类 :param nn.modules model: 待训练的模型 :param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 :param int batch_size: 训练和验证的时候的batch大小。 diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index b1d5f4e2..05722c48 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -624,9 +624,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re if check_res.unused: _tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}." if _tmp: - _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' + _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. ' else: - _tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' + _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.' + if not dataset.collector.is_empty(): + _tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. ' suggestions.append(_tmp) if check_res.duplicated: @@ -683,12 +685,11 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): else: _miss_out_dataset.append(_miss) if _miss_in_dataset: - suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") + suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ") if _miss_out_dataset: - _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " - # if check_res.unused: - # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ - # f"rename the field in `unused field:`." + _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. " + if not dataset.collector.is_empty(): + _tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. ' suggestions.append(_tmp) if check_res.unused: diff --git a/test/core/test_batch.py b/test/core/test_batch.py index dfdf28b7..ee61e239 100644 --- a/test/core/test_batch.py +++ b/test/core/test_batch.py @@ -158,20 +158,130 @@ class TestCase1(unittest.TestCase): dataset.set_input('1','2') dataset.set_target('0','3') - fn = ConcatCollectFn() - dataset.add_collect_fn(fn, inputs=['1', '2'], - outputs=['12', 'seq_len'], - is_input=True, is_target=False) - + fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) + dataset.add_collect_fn(fn, name='demo') batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) for batch_x, batch_y in batch: for i in range(batch_size): # print(i) self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) - self.assertEqual( - batch_x['seq_len'][i], - (batch_x['1'][i]!=0).sum() + (batch_x['2'][i]!=0).sum()) + dataset.delete_collect_fn(name='demo') + + # 测试非input的情况 + dataset.set_input('1', '2', flag=False) # + fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) + dataset.add_collect_fn(fn, name='demo') + batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) + for batch_x, batch_y in batch: + for i in range(batch_size): + self.assertTrue('12' in batch_x) + dataset.delete_collect_fn(name='demo') + dataset.set_input('1', '2', flag=True) # + + # 测试覆盖其它field的情况 + fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) + dataset.add_collect_fn(fn, name='demo') + batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) + for batch_x, batch_y in batch: + for i in range(batch_size): + # print(i) + self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) + dataset.delete_collect_fn(name='demo') + + # 测试非input,target的情况 + dataset.set_input('1', '2', flag=False) + fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) + dataset.add_collect_fn(fn, name='demo') + batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) + for batch_x, batch_y in batch: + for i in range(batch_size): + # print(i) + self.assertTrue('3' in batch_x) + self.assertTrue('3' in batch_y) + dataset.delete_collect_fn(name='demo') + + # 测试加入非法fn的请 + with self.assertRaises(AssertionError): + dataset.add_collect_fn(1) + + # 测试collect_fn返回值只有一个的情况 + def demo_collect_fn(ins_list): + return {'3':1} + dataset.add_collect_fn(demo_collect_fn, name='demo') + with self.assertRaises(BaseException): + batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) + for batch_x, batch_y in batch: + pass + dataset.delete_collect_fn(name='demo') + + # 测试多个collect_fn + dataset.add_collect_fn(demo_collect_fn, name='demo') + dataset.add_collect_fn(demo_collect_fn, name='demo') + # 测试删除 + dataset.delete_collect_fn() + dataset.delete_collect_fn() + self.assertTrue(dataset.collector.is_empty()) + + def test_demo(self): + import torch + + data = DataSet({ + 'x1': [[0, 1], + [2]], + 'x2': [[3], + [2, 4, 5] + ], + 'y': [0, 1] + }) + data.set_target('y') + + # 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 + # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. + def concat_collect_fn(ins_list): + x1 = [ins['x1'] for ind,ins in ins_list] + x2 = [ins['x2'] for ind,ins in ins_list] + xs = [] + for i in range(len(ins_list)): + xs.append(torch.LongTensor(x1[i] + x2[i])) + # 需要自行pad并转换为tensor,但不需要移动到gpu + arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0) + b_x = {'x': arr} + b_y = {} + # 返回值一定是两个dict,第一个dict的值会认为是input,第二个dict的值会认为是target. 若名称与已有input或target重复,则 + # 采用返回值。 + return b_x, b_y + + data.add_collect_fn(concat_collect_fn) + + for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): + print("batch_x:", batch_x) + print("batch_y:", batch_y) + # batch_x: {'x': tensor([[0, 1, 3, 0], + # [2, 2, 4, 5]])} + # batch_y: {'y': array([0, 1])} + # 如果取batch过程含有一些参数,可以通过类来实现 + class ConCollectFn: + def __init__(self, max_len=3): + self.max_len = max_len + def __call__(self, ins_list): + x1 = [ins['x1'] for ind, ins in ins_list] + x2 = [ins['x2'] for ind, ins in ins_list] + xs = [] + for i in range(len(ins_list)): + xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len]) + arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0) + b_x = {'x': arr} + b_y = {} + return b_x, b_y + data.delete_collect_fn() # 删除之前的collect_fn + data.add_collect_fn(ConCollectFn(max_len=3)) + for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): + print("batch_x:", batch_x) + print("batch_y:", batch_y) + # batch_x: {'x': tensor([[0, 1, 3], + # [2, 2, 4]])} + # batch_y: {'y': array([0, 1])} def testTensorLoaderIter(self): class FakeData: