From ba645f448384d9ca69d3c4f5d8baf7dd333dc97f Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sat, 4 Apr 2020 14:52:03 +0800 Subject: [PATCH] =?UTF-8?q?1.=E5=B0=86collect=5Ffn=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E4=B8=BAcollate=5Ffn=E4=BD=BF=E5=BE=97=E4=B8=8Epytorch?= =?UTF-8?q?=E4=BF=9D=E6=8C=81=E4=B8=80=E8=87=B4;=202.=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=AF=B9BatchSampler=E7=9A=84=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/__init__.py | 4 +- fastNLP/core/__init__.py | 4 +- fastNLP/core/batch.py | 65 ++++++--- fastNLP/core/{collect_fn.py => collate_fn.py} | 48 +++---- fastNLP/core/dataset.py | 54 +++---- fastNLP/core/utils.py | 8 +- fastNLP/io/data_bundle.py | 12 +- test/core/test_batch.py | 135 +++++++++++++----- test/core/test_trainer.py | 77 ++++++++-- 9 files changed, 275 insertions(+), 132 deletions(-) rename fastNLP/core/{collect_fn.py => collate_fn.py} (77%) diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 53517da0..a9d7efe7 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -44,8 +44,8 @@ __all__ = [ "AutoPadder", "EngChar2DPadder", - # "CollectFn", - "ConcatCollectFn", + # "CollateFn", + "ConcatCollateFn", "MetricBase", "AccuracyMetric", diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 89c6558d..9f61ae0c 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -21,7 +21,7 @@ __all__ = [ "AutoPadder", "EngChar2DPadder", - "ConcatCollectFn", + "ConcatCollateFn", "Vocabulary", @@ -99,4 +99,4 @@ from .tester import Tester from .trainer import Trainer from .utils import cache_results, seq_len_to_mask, get_seq_len from .vocabulary import Vocabulary -from .collect_fn import ConcatCollectFn +from .collate_fn import ConcatCollateFn diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 1cebd4bd..7c1e64ee 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -18,7 +18,7 @@ import torch.utils.data from collections import 
defaultdict from .dataset import DataSet -from .sampler import SequentialSampler +from .sampler import SequentialSampler, Sampler from ._logger import logger @@ -89,8 +89,8 @@ class DataSetGetter: sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy) sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy) - if not self.dataset.collector.is_empty(): - bx, by = self.dataset._collect_batch(ins_list) + if not self.dataset.collater.is_empty(): + bx, by = self.dataset._collate_batch(ins_list) sin_x.update(bx) sin_y.update(by) @@ -127,29 +127,35 @@ class BatchIter: """ def __init__(self, dataset, batch_size=1, sampler=None, num_workers=0, pin_memory=False, drop_last=False, - timeout=0, worker_init_fn=None, collate_fn=None): - if not isinstance(sampler, torch.utils.data.Sampler): - self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) - else: - self.sampler = sampler + timeout=0, worker_init_fn=None, collate_fn=None, + batch_sampler=None): + if isinstance(sampler, Sampler): # 如果时fastNLP的sampler需要adapt一下 + sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) + self.sampler = sampler + self.batch_sampler = batch_sampler - # DataLoader的collect_fn输入是List[],里面的元素是dataset[index]返回的结果 + # DataLoader的collate_fn输入是List[],里面的元素是dataset[index]返回的结果 if collate_fn is None: # pytoch <= 1.1 中不能设置collate_fn=None self.dataiter = torch.utils.data.DataLoader( dataset=dataset, batch_size=batch_size, sampler=self.sampler, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, - timeout=timeout, worker_init_fn=worker_init_fn) + timeout=timeout, worker_init_fn=worker_init_fn, + batch_sampler=batch_sampler) else: self.dataiter = torch.utils.data.DataLoader( dataset=dataset, batch_size=batch_size, sampler=self.sampler, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, - timeout=timeout, worker_init_fn=worker_init_fn) + timeout=timeout, 
worker_init_fn=worker_init_fn, + batch_sampler=batch_sampler) # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 - self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) + if self.batch_sampler is None: + self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) + else: + self._num_batches = len(self.batch_sampler) self.batch_size = batch_size self.cur_batch_indices = None @@ -222,7 +228,8 @@ class DataSetIter(BatchIter): """ def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False, drop_last=False, - timeout=0, worker_init_fn=None, collate_fn=None): + timeout=0, worker_init_fn=None, collate_fn=None, + batch_sampler=None): r""" :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 @@ -239,15 +246,21 @@ class DataSetIter(BatchIter): :param timeout: 生成一个batch的timeout值 :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 :param collate_fn: 用于将样本组合成batch的函数 + :param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。 + 当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。 """ assert isinstance(dataset, DataSet) dataset = DataSetGetter(dataset, as_numpy) collate_fn = dataset.collate_fn if collate_fn is None else collate_fn + if batch_sampler is not None: + batch_size = 1 + sampler = None + drop_last = False super().__init__( dataset=dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, - collate_fn=collate_fn + collate_fn=collate_fn, batch_sampler=batch_sampler ) def __iter__(self): @@ -384,12 +397,16 @@ class TorchLoaderIter(BatchIter): os.remove(tmp_file_path) """ - def __init__(self, dataset, batch_size=1, sampler=None, + def __init__(self, dataset, collate_fn, batch_size=1, sampler=None, num_workers=0, pin_memory=False, drop_last=False, - timeout=0, worker_init_fn=None, collate_fn=None): 
+ timeout=0, worker_init_fn=None, + batch_sampler=None): r""" - :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 + :param dataset: 实现了__getitem__和__len__方法的数据容器。 + :param callable collate_fn: 用于将样本组合成batch的函数。输入为[dataset[idx1], dataset[idx2], ...], 即dataset中 + __getitem__返回值组成的list,返回值必须为两个dict,其中第一个dict会被认为是input,第二个dict中的内容被认为是target。 + 需要转换为tensor的数据,需要在collate_fn中转化,但不需要转移到对应device。 :param int batch_size: 取出的batch大小 :param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. Default: ``None`` @@ -398,19 +415,21 @@ class TorchLoaderIter(BatchIter): :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 :param timeout: 生成一个batch的timeout值 :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 - :param collate_fn: 用于将样本组合成batch的函数。 + :param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。 + 当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。 """ assert len(dataset) > 0 - ins = dataset[0] - if (len(ins) != 2 or not isinstance(ins[0], dict) or not isinstance(ins[1], dict)) and collate_fn is None: - raise RuntimeError("If the provided dataset does not return two dicts when call __getitem__(), the" - " `collate_fn` must be provided.") + assert collate_fn is not None, "You must pass collate_fn to pad the batch." 
+ if batch_sampler is not None: + batch_size = 1 + sampler = None + drop_last = False super().__init__( dataset=dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, - collate_fn=collate_fn + collate_fn=collate_fn, batch_sampler=batch_sampler ) def __iter__(self): diff --git a/fastNLP/core/collect_fn.py b/fastNLP/core/collate_fn.py similarity index 77% rename from fastNLP/core/collect_fn.py rename to fastNLP/core/collate_fn.py index 71068106..7d7f9726 100644 --- a/fastNLP/core/collect_fn.py +++ b/fastNLP/core/collate_fn.py @@ -36,72 +36,72 @@ def batching(samples, max_len=0, padding_val=0): return batch -class Collector: +class Collater: r""" - 辅助DataSet管理collect_fn的类 + 辅助DataSet管理collate_fn的类 """ def __init__(self): - self.collect_fns = {} + self.collate_fns = {} def add_fn(self, fn, name=None): r""" - 向collector新增一个collect_fn函数 + 向collater新增一个collate_fn函数 :param callable fn: :param str,int name: :return: """ - if name in self.collect_fns: - logger.warn(f"collect_fn:{name} will be overwritten.") + if name in self.collate_fns: + logger.warn(f"collate_fn:{name} will be overwritten.") if name is None: - name = len(self.collect_fns) - self.collect_fns[name] = fn + name = len(self.collate_fns) + self.collate_fns[name] = fn def is_empty(self): r""" - 返回是否包含collect_fn + 返回是否包含collate_fn :return: """ - return len(self.collect_fns)==0 + return len(self.collate_fns) == 0 def delete_fn(self, name=None): r""" - 删除collect_fn + 删除collate_fn - :param str,int name: 如果为None就删除最近加入的collect_fn + :param str,int name: 如果为None就删除最近加入的collate_fn :return: """ if not self.is_empty(): - if name in self.collect_fns: - self.collect_fns.pop(name) + if name in self.collate_fns: + self.collate_fns.pop(name) elif name is None: - last_key = list(self.collect_fns.keys())[0] - self.collect_fns.pop(last_key) + last_key = list(self.collate_fns.keys())[0] + self.collate_fns.pop(last_key) - 
def collect_batch(self, ins_list): + def collate_batch(self, ins_list): bx, by = {}, {} - for name, fn in self.collect_fns.items(): + for name, fn in self.collate_fns.items(): try: batch_x, batch_y = fn(ins_list) except BaseException as e: - logger.error(f"Exception:`{e}` happens when call collect_fn:`{name}`.") + logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.") raise e bx.update(batch_x) by.update(batch_y) return bx, by def copy_from(self, col): - assert isinstance(col, Collector) - new_col = Collector() - new_col.collect_fns = deepcopy(col.collect_fns) + assert isinstance(col, Collater) + new_col = Collater() + new_col.collate_fns = deepcopy(col.collate_fns) return new_col -class ConcatCollectFn: +class ConcatCollateFn: r""" - field拼接collect_fn,将不同field按序拼接后,padding产生数据。 + field拼接collate_fn,将不同field按序拼接后,padding产生数据。 :param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field :param str output: 拼接后的field名称 diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 78120128..a6c6cde6 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -285,8 +285,8 @@ r""" ------------------------------------------------------------ DataSet支持在进行batch时,默认只能看到当前的field的值,但在某些训练中可能存在以下的情况: (1)需要两个field拼接成为一个field; - (2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collect_fn()函数支持添加 - 自定义涉及多个field的collect_fn函数。例如下例中将两个field拼接成一个field的场景 + (2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collate_fn()函数支持添加 + 自定义涉及多个field的collate_fn函数。例如下例中将两个field拼接成一个field的场景 .. code-block:: @@ -302,9 +302,9 @@ r""" }) data.set_target('y') - # 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 + # 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. 
- def concat_collect_fn(ins_list): + def concat_collate_fn(ins_list): x1 = [ins['x1'] for ind,ins in ins_list] x2 = [ins['x2'] for ind,ins in ins_list] xs = [] @@ -318,7 +318,7 @@ r""" # 采用返回值。 return b_x, b_y - data.add_collect_fn(concat_collect_fn) + data.add_collate_fn(concat_collate_fn) for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): print("batch_x:", batch_x) @@ -328,7 +328,7 @@ r""" # batch_y: {'y': array([0, 1])} # 如果取batch过程含有一些参数,可以通过类来实现 - class ConCollectFn: + class ConCollateFn: def __init__(self, max_len=3): self.max_len = max_len @@ -342,8 +342,8 @@ r""" b_x = {'x': arr} b_y = {} return b_x, b_y - data.delete_collect_fn() # 删除之前的collect_fn - data.add_collect_fn(ConCollectFn(max_len=3)) + data.delete_collate_fn() # 删除之前的collate_fn + data.add_collate_fn(ConCollateFn(max_len=3)) for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): print("batch_x:", batch_x) print("batch_y:", batch_y) @@ -370,7 +370,7 @@ from .field import FieldArray from .field import SetInputOrTargetException from .instance import Instance from .utils import pretty_table_printer -from .collect_fn import Collector +from .collate_fn import Collater class ApplyResultException(Exception): @@ -406,7 +406,7 @@ class DataSet(object): else: raise ValueError("data only be dict or list type.") - self.collector = Collector() + self.collater = Collater() def __contains__(self, item): return item in self.field_arrays @@ -462,7 +462,7 @@ class DataSet(object): for field in self.field_arrays.values(): data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) - data_set.collector = self.collector.copy_from(self.collector) + data_set.collater = self.collater.copy_from(self.collater) return data_set elif isinstance(idx, str): if idx not in self: @@ -476,7 +476,7 @@ class DataSet(object): dataset.append(instance) for 
field_name, field in self.field_arrays.items(): dataset.field_arrays[field_name].to(field) - dataset.collector = self.collector.copy_from(self.collector) + dataset.collater = self.collater.copy_from(self.collater) return dataset else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) @@ -1083,8 +1083,8 @@ class DataSet(object): train_set.field_arrays[field_name].to(self.field_arrays[field_name]) dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) - train_set.collector.copy_from(self.collector) - dev_set.collector.copy_from(self.collector) + train_set.collater.copy_from(self.collater) + dev_set.collater.copy_from(self.collater) return train_set, dev_set def save(self, path): @@ -1109,30 +1109,30 @@ class DataSet(object): assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) return d - def add_collect_fn(self, fn, name=None): + def add_collate_fn(self, fn, name=None): r""" - 添加 CollectFn,collect_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 - 这个)。支持依次添加多个collect_fn, 如果相同的key,后面的collect_fn的结果覆盖前面的collect_fn的结果。 + 添加 CollateFn,collate_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 + 这个)。支持依次添加多个collate_fn, 如果相同的key,后面的collate_fn的结果覆盖前面的collate_fn的结果。 :param callable fn: 传入一个可调用的function, 该function可接受的参数为List[(ind1, instance1), (ind2, instance2)] (某个batch被选中的所有的indice以及instance),其中ind1/ind2是该instance在dataset中的index,instance1/instance2是 这次batch取出来的数据,包含了所有的field。返回值需要为两个dict,第一个dict的值将被认为是input,第二个dict的值被认为是 target,返回的值至多允许一个空dict。若返回的dict中包含了被设置为input或target的field的名称,将覆盖dataset中的field。 - fastNLP不会将collect_fn的返回结果pad和转换为tensor,需要在collect_fn中完成pad和转换为tensor(不需要将tensor移动到 - gpu中,如果是pytorch的tensor,fastNLP会自动将其移动到特定gpu)。不要修改传入collect_fn中的数据,否则可能导致未知问题。 - :param str,int name: collect_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collect_fn。 + fastNLP不会将collate_fn的返回结果pad和转换为tensor,需要在collate_fn中完成pad和转换为tensor(不需要将tensor移动到 + 
gpu中,fastNLP会自动将其移动到特定gpu)。不要修改传入collate_fn中的数据,否则可能导致未知问题。 + :param str,int name: collate_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collate_fn。 """ assert callable(fn), "You must pass in a callable object." - self.collector.add_fn(fn, name=name) + self.collater.add_fn(fn, name=name) - def delete_collect_fn(self, name=None): + def delete_collate_fn(self, name=None): r""" - 删除某个collect_fn + 删除某个collate_fn - :param str,int name: 如果为None,则删除最近加入的collect_fn + :param str,int name: 如果为None,则删除最近加入的collate_fn :return: """ - self.collector.delete_fn(name) + self.collater.delete_fn(name) - def _collect_batch(self, ins_list): - return self.collector.collect_batch(ins_list) + def _collate_batch(self, ins_list): + return self.collater.collate_batch(ins_list) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index eee02b1b..797ddcda 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -718,8 +718,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. ' else: _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.' - if not dataset.collector.is_empty(): - _tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. ' + if not dataset.collater.is_empty(): - _tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. ' + _tmp += f'Or you need to add `{_miss}` in the output of your collate_fn. ' suggestions.append(_tmp) if check_res.duplicated: @@ -779,8 +779,8 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ") if _miss_out_dataset: _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. " - if not dataset.collector.is_empty(): - _tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. ' + if not dataset.collater.is_empty(): + _tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collate_fn. 
' suggestions.append(_tmp) if check_res.unused: diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index a105e30b..bcb8a211 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -348,26 +348,26 @@ class DataBundle: dataset.apply(func, new_field_name=new_field_name, **kwargs) return self - def add_collect_fn(self, fn, name=None): + def add_collate_fn(self, fn, name=None): r""" - 向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明. + 向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明. :param callable fn: :param name: :return: """ for _, dataset in self.datasets.items(): - dataset.add_collect_fn(fn=fn, name=name) + dataset.add_collate_fn(fn=fn, name=name) - def delete_collect_fn(self, name=None): + def delete_collate_fn(self, name=None): r""" - 删除DataSet中的collect_fn + 删除DataSet中的collate_fn :param name: :return: """ for _, dataset in self.datasets.items(): - dataset.delete_collect_fn(name=name) + dataset.delete_collate_fn(name=name) def __repr__(self): _str = '' diff --git a/test/core/test_batch.py b/test/core/test_batch.py index 0efe1550..18cbf59d 100644 --- a/test/core/test_batch.py +++ b/test/core/test_batch.py @@ -7,7 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter from fastNLP import DataSet from fastNLP import Instance from fastNLP import SequentialSampler -from fastNLP import ConcatCollectFn +from fastNLP import ConcatCollateFn def generate_fake_dataset(num_samples=1000): @@ -177,76 +177,76 @@ class TestCase1(unittest.TestCase): for con,t in zip(cons, test): self.assertEqual(alphas[:con], t) - def test_collect_fn(self): + def test_collate_fn(self): batch_size = 32 num_samples = 1000 dataset = generate_fake_dataset(num_samples) dataset.set_input('1','2') dataset.set_target('0','3') - fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) - dataset.add_collect_fn(fn, name='demo') + fn = ConcatCollateFn(inputs=['1', '2'], output='12', 
pad_val=0, max_len=0, is_input=True, is_target=False) + dataset.add_collate_fn(fn, name='demo') batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) for batch_x, batch_y in batch: for i in range(batch_size): # print(i) self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) - dataset.delete_collect_fn(name='demo') + dataset.delete_collate_fn(name='demo') # 测试非input的情况 dataset.set_input('1', '2', flag=False) # - fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) - dataset.add_collect_fn(fn, name='demo') + fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) + dataset.add_collate_fn(fn, name='demo') batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) for batch_x, batch_y in batch: for i in range(batch_size): self.assertTrue('12' in batch_x) - dataset.delete_collect_fn(name='demo') + dataset.delete_collate_fn(name='demo') dataset.set_input('1', '2', flag=True) # # 测试覆盖其它field的情况 - fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) - dataset.add_collect_fn(fn, name='demo') + fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) + dataset.add_collate_fn(fn, name='demo') batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) for batch_x, batch_y in batch: for i in range(batch_size): # print(i) self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) - dataset.delete_collect_fn(name='demo') + dataset.delete_collate_fn(name='demo') # 测试非input,target的情况 dataset.set_input('1', '2', flag=False) - fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) - dataset.add_collect_fn(fn, name='demo') + fn = ConcatCollateFn(inputs=['1', 
'2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) + dataset.add_collate_fn(fn, name='demo') batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) for batch_x, batch_y in batch: for i in range(batch_size): # print(i) self.assertTrue('3' in batch_x) self.assertTrue('3' in batch_y) - dataset.delete_collect_fn(name='demo') + dataset.delete_collate_fn(name='demo') # 测试加入非法fn的请 with self.assertRaises(AssertionError): - dataset.add_collect_fn(1) + dataset.add_collate_fn(1) - # 测试collect_fn返回值只有一个的情况 - def demo_collect_fn(ins_list): + # 测试collate_fn返回值只有一个的情况 + def demo_collate_fn(ins_list): return {'3':1} - dataset.add_collect_fn(demo_collect_fn, name='demo') + dataset.add_collate_fn(demo_collate_fn, name='demo') with self.assertRaises(BaseException): batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) for batch_x, batch_y in batch: pass - dataset.delete_collect_fn(name='demo') + dataset.delete_collate_fn(name='demo') - # 测试多个collect_fn - dataset.add_collect_fn(demo_collect_fn, name='demo') - dataset.add_collect_fn(demo_collect_fn, name='demo') + # 测试多个collate_fn + dataset.add_collate_fn(demo_collate_fn, name='demo') + dataset.add_collate_fn(demo_collate_fn, name='demo') # 测试删除 - dataset.delete_collect_fn() - dataset.delete_collect_fn() - self.assertTrue(dataset.collector.is_empty()) + dataset.delete_collate_fn() + dataset.delete_collate_fn() + self.assertTrue(dataset.collater.is_empty()) def test_demo(self): import torch @@ -261,9 +261,9 @@ class TestCase1(unittest.TestCase): }) data.set_target('y') - # 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 + # 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. 
- def concat_collect_fn(ins_list): + def concat_collate_fn(ins_list): x1 = [ins['x1'] for ind,ins in ins_list] x2 = [ins['x2'] for ind,ins in ins_list] xs = [] @@ -277,7 +277,7 @@ class TestCase1(unittest.TestCase): # 采用返回值。 return b_x, b_y - data.add_collect_fn(concat_collect_fn) + data.add_collate_fn(concat_collate_fn) for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): print("batch_x:", batch_x) @@ -287,7 +287,7 @@ class TestCase1(unittest.TestCase): # batch_y: {'y': array([0, 1])} # 如果取batch过程含有一些参数,可以通过类来实现 - class ConCollectFn: + class ConCollateFn: def __init__(self, max_len=3): self.max_len = max_len def __call__(self, ins_list): @@ -300,8 +300,8 @@ class TestCase1(unittest.TestCase): b_x = {'x': arr} b_y = {} return b_x, b_y - data.delete_collect_fn() # 删除之前的collect_fn - data.add_collect_fn(ConCollectFn(max_len=3)) + data.delete_collate_fn() # 删除之前的collate_fn + data.add_collate_fn(ConCollateFn(max_len=3)) for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): print("batch_x:", batch_x) print("batch_y:", batch_y) @@ -326,14 +326,77 @@ class TestCase1(unittest.TestCase): return x, y data1 = FakeData() - dataiter = TorchLoaderIter(data1, batch_size=2) + def collact_fn(ins_list): + xs = [ins[0]['x'] for ins in ins_list] + ys = [ins[1]['y'] for ins in ins_list] + return {'x':xs}, {'y':ys} + dataiter = TorchLoaderIter(data1, collate_fn=collact_fn, batch_size=2) for x, y in dataiter: print(x, y) - def func(): - data2 = FakeData(return_dict=False) - dataiter = TorchLoaderIter(data2, batch_size=2) - self.assertRaises(Exception, func) + def test_batch_sampler(self): + # 测试DataSetIter与TorchLoaderIter的batch_sampler能否正常工作 + # DataSetIter + ds = generate_fake_dataset(5) + ds.set_input('1') + class BatchSampler: + def __init__(self, dataset): + self.num_samples = len(dataset) + + def __iter__(self): + index = 0 + indexes = list(range(self.num_samples)) + np.random.shuffle(indexes) + start_idx = 0 + while 
index < self.num_samples: + if start_idx == 0: + end_index = self.num_samples//2 + else: + end_index = self.num_samples + yield indexes[start_idx:end_index] + index = end_index + start_idx = end_index + + def __len__(self): + return 2 + + batch_sampler = BatchSampler(ds) + + data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, + num_workers=0, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None, collate_fn=None, + batch_sampler=batch_sampler) + num_samples = [len(ds)//2, len(ds)-len(ds)//2] + for idx, (batch_x, batch_y) in enumerate(data_iter): + self.assertEqual(num_samples[idx], len(batch_x['1'])) + + # TorchLoaderIter + class FakeData: + def __init__(self): + self.x = [[1,2,3], [4,5,6], [1,2]] + + def __len__(self): + return len(self.x) + + def __getitem__(self, i): + x = self.x[i] + y = 0 + return x,y + + def collate_fn(ins_list): + xs = [ins[0] for ins in ins_list] + ys = [ins[1] for ins in ins_list] + return {'x':xs}, {'y':ys} + + ds = FakeData() + batch_sampler = BatchSampler(ds) + data_iter = TorchLoaderIter(ds, batch_size=10, sampler=batch_sampler, + num_workers=0, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None, collate_fn=collate_fn, + batch_sampler=batch_sampler) + num_samples = [len(ds)//2, len(ds)-len(ds)//2] + for idx, (batch_x, batch_y) in enumerate(data_iter): + self.assertEqual(num_samples[idx], len(batch_x['x'])) """ def test_multi_workers_batch(self): diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index e8d78e19..138d0462 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -243,7 +243,7 @@ class TrainerTestGround(unittest.TestCase): def __len__(self): return self.num_samples - def collect_fn(data_list): + def collate_fn(data_list): # [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list xs, ys = [], [] for l in data_list: @@ -254,7 +254,7 @@ class TrainerTestGround(unittest.TestCase): return {'x':x, 'y':y}, {'y':y} dataset = 
UdfDataSet(10) - dataset = TorchLoaderIter(dataset, collate_fn=collect_fn) + dataset = TorchLoaderIter(dataset, collate_fn=collate_fn) class Model(nn.Module): def __init__(self): super().__init__() @@ -268,6 +268,67 @@ class TrainerTestGround(unittest.TestCase): metrics=AccuracyMetric(target='y'), use_tqdm=False) trainer.train(load_best_model=False) + def test_batch_sampler_dataiter(self): + import random + import torch + class BatchSampler: + def __init__(self, dataset): + self.num_samples = len(dataset) + + def __iter__(self): + index = 0 + indexes = list(range(self.num_samples)) + np.random.shuffle(indexes) + start_idx = 0 + while index < self.num_samples: + if start_idx == 0: + end_index = self.num_samples//2 + else: + end_index = self.num_samples + yield indexes[start_idx:end_index] + index = end_index + start_idx = end_index + def __len__(self): + return 2 + + class UdfDataSet: + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + x = [random.random() for _ in range(3)] + y = random.random() + return x,y + + def __len__(self): + return self.num_samples + + def collate_fn(data_list): + # [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list + xs, ys = [], [] + for l in data_list: + x, y = l + xs.append(x) + ys.append(y) + x,y = torch.FloatTensor(xs), torch.FloatTensor(ys) + return {'x':x, 'y':y}, {'y':y} + + dataset = UdfDataSet(11) + batch_sampler = BatchSampler(dataset) + dataset = TorchLoaderIter(dataset, collate_fn=collate_fn, batch_sampler=batch_sampler) + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(3, 1) + def forward(self, x, y): + return {'loss':torch.pow(self.fc(x).squeeze(-1)-y, 2).sum()} + def predict(self, x): + return {'pred':self.fc(x).squeeze(-1)} + model = Model() + trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, + metrics=AccuracyMetric(target='y'), use_tqdm=False) + 
trainer.train(load_best_model=False) + def test_onthefly_iter(self): import tempfile import random @@ -333,7 +394,7 @@ class TrainerTestGround(unittest.TestCase): return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()} def predict(self, x): - return {'pred': self.fc(x).squeeze(0)} + return {'pred': self.fc(x).squeeze(-1)} model = Model() trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, @@ -356,7 +417,7 @@ class TrainerTestGround(unittest.TestCase): x.append(ins['x1']+ins['x2']) x = torch.FloatTensor(x) return {'x':x}, {} - dataset.add_collect_fn(fn) + dataset.add_collate_fn(fn) class Model(nn.Module): def __init__(self): @@ -377,7 +438,7 @@ class TrainerTestGround(unittest.TestCase): dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False) trainer.train() - def test_collect_fn2(self): + def test_collate_fn2(self): """测试能否实现batch_x, batch_y""" dataset = prepare_fake_dataset2('x1', 'x2') dataset.set_input('x1', 'x2') @@ -389,7 +450,7 @@ class TrainerTestGround(unittest.TestCase): x.append(ins['x1']+ins['x2']) x = torch.FloatTensor(x) return {'x':x}, {'target':x[:, :4].argmax(dim=-1)} - dataset.add_collect_fn(fn) + dataset.add_collate_fn(fn) class Model(nn.Module): def __init__(self): @@ -410,7 +471,7 @@ class TrainerTestGround(unittest.TestCase): dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False) trainer.train() - def test_collect_fn3(self): + def test_collate_fn3(self): """ 测试应该会覆盖 @@ -426,7 +487,7 @@ class TrainerTestGround(unittest.TestCase): x.append(ins['x1']+ins['x2']) x = torch.FloatTensor(x) return {'x1':torch.zeros_like(x)}, {'target':torch.zeros(x.size(0)).long(), 'y':x} - dataset.add_collect_fn(fn) + dataset.add_collate_fn(fn) class Model(nn.Module): def __init__(self):