1. Rename collect_fn to collate_fn for consistency with pytorch; 2. Add support for BatchSampler

tags/v0.5.5
yh_cc committed 4 years ago · commit ba645f4483
9 changed files with 275 additions and 132 deletions:

  1. fastNLP/__init__.py         +2  -2
  2. fastNLP/core/__init__.py    +2  -2
  3. fastNLP/core/batch.py       +42 -23
  4. fastNLP/core/collate_fn.py  +24 -24  (renamed from fastNLP/core/collect_fn.py)
  5. fastNLP/core/dataset.py     +27 -27
  6. fastNLP/core/utils.py       +4  -4
  7. fastNLP/io/data_bundle.py   +6  -6
  8. test/core/test_batch.py     +99 -36
  9. test/core/test_trainer.py   +69 -8
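
For orientation, a minimal sketch of the renamed API (the toy fields below are assumptions; the call pattern mirrors the dataset.py docstring further down):

    from fastNLP import DataSet, DataSetIter, SequentialSampler

    data = DataSet({'x1': [[0, 1], [2]], 'x2': [[3], [4, 5]], 'y': [0, 1]})
    data.set_target('y')

    # a collate_fn receives [(index, instance), ...] and returns two dicts: (input, target)
    def concat_collate_fn(ins_list):
        xs = [ins['x1'] + ins['x2'] for _, ins in ins_list]
        return {'x': xs}, {}

    data.add_collate_fn(concat_collate_fn)    # formerly data.add_collect_fn(...)
    for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
        pass                                  # batch_x now contains the dynamically built 'x'
    data.delete_collate_fn()                  # formerly data.delete_collect_fn(...)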

fastNLP/__init__.py (+2 -2)

@@ -44,8 +44,8 @@ __all__ = [
     "AutoPadder",
     "EngChar2DPadder",
 
-    # "CollectFn",
-    "ConcatCollectFn",
+    # "CollateFn",
+    "ConcatCollateFn",
 
     "MetricBase",
     "AccuracyMetric",


fastNLP/core/__init__.py (+2 -2)

@@ -21,7 +21,7 @@ __all__ = [
     "AutoPadder",
     "EngChar2DPadder",
 
-    "ConcatCollectFn",
+    "ConcatCollateFn",
     "Vocabulary",
@@ -99,4 +99,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
-from .collect_fn import ConcatCollectFn
+from .collate_fn import ConcatCollateFn

fastNLP/core/batch.py (+42 -23)

@@ -18,7 +18,7 @@ import torch.utils.data
 from collections import defaultdict
 
 from .dataset import DataSet
-from .sampler import SequentialSampler
+from .sampler import SequentialSampler, Sampler
 from ._logger import logger
 
 
@@ -89,8 +89,8 @@ class DataSetGetter:
         sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy)
         sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy)
 
-        if not self.dataset.collector.is_empty():
-            bx, by = self.dataset._collect_batch(ins_list)
+        if not self.dataset.collater.is_empty():
+            bx, by = self.dataset._collate_batch(ins_list)
             sin_x.update(bx)
             sin_y.update(by)
 
@@ -127,29 +127,35 @@ class BatchIter:
     """
     def __init__(self, dataset, batch_size=1, sampler=None,
                  num_workers=0, pin_memory=False, drop_last=False,
-                 timeout=0, worker_init_fn=None, collate_fn=None):
-        if not isinstance(sampler, torch.utils.data.Sampler):
-            self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
-        else:
-            self.sampler = sampler
+                 timeout=0, worker_init_fn=None, collate_fn=None,
+                 batch_sampler=None):
+        if isinstance(sampler, Sampler):  # if it is a fastNLP sampler, it needs to be adapted
+            sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
+        self.sampler = sampler
+        self.batch_sampler = batch_sampler
 
-        # the input to DataLoader's collect_fn is a List[] whose elements are the results returned by dataset[index]
+        # the input to DataLoader's collate_fn is a List[] whose elements are the results returned by dataset[index]
         if collate_fn is None:
            # collate_fn=None cannot be set in pytorch <= 1.1
             self.dataiter = torch.utils.data.DataLoader(
                 dataset=dataset, batch_size=batch_size, sampler=self.sampler,
                 num_workers=num_workers,
                 pin_memory=pin_memory, drop_last=drop_last,
-                timeout=timeout, worker_init_fn=worker_init_fn)
+                timeout=timeout, worker_init_fn=worker_init_fn,
+                batch_sampler=batch_sampler)
         else:
             self.dataiter = torch.utils.data.DataLoader(
                 dataset=dataset, batch_size=batch_size, sampler=self.sampler,
                 collate_fn=collate_fn, num_workers=num_workers,
                 pin_memory=pin_memory, drop_last=drop_last,
-                timeout=timeout, worker_init_fn=worker_init_fn)
+                timeout=timeout, worker_init_fn=worker_init_fn,
+                batch_sampler=batch_sampler)
 
         # go by the sampler's length, because with DistributedSampler not all of the data is used on each process
-        self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
+        if self.batch_sampler is None:
+            self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
+        else:
+            self._num_batches = len(self.batch_sampler)
         self.batch_size = batch_size
         self.cur_batch_indices = None
 
@@ -222,7 +228,8 @@ class DataSetIter(BatchIter):
     """
     def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
                  num_workers=0, pin_memory=False, drop_last=False,
-                 timeout=0, worker_init_fn=None, collate_fn=None):
+                 timeout=0, worker_init_fn=None, collate_fn=None,
+                 batch_sampler=None):
         r"""
         :param dataset: a :class:`~fastNLP.DataSet` object, the dataset
@@ -239,15 +246,21 @@
         :param timeout: timeout value for producing one batch
         :param worker_init_fn: called once per worker at startup, receiving the worker's index
         :param collate_fn: function used to combine samples into a batch
+        :param batch_sampler: a sampler to use when the number of samples per batch varies. Each iteration of
+            batch_sampler should yield a list of indices. When batch_sampler is not None, the arguments
+            batch_size, sampler and drop_last are ignored.
         """
         assert isinstance(dataset, DataSet)
         dataset = DataSetGetter(dataset, as_numpy)
         collate_fn = dataset.collate_fn if collate_fn is None else collate_fn
+        if batch_sampler is not None:
+            batch_size = 1
+            sampler = None
+            drop_last = False
         super().__init__(
             dataset=dataset, batch_size=batch_size, sampler=sampler,
             num_workers=num_workers, pin_memory=pin_memory,
             drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-            collate_fn=collate_fn
+            collate_fn=collate_fn, batch_sampler=batch_sampler
         )
 
     def __iter__(self):
@@ -384,12 +397,16 @@ class TorchLoaderIter(BatchIter):
             os.remove(tmp_file_path)
     """
-    def __init__(self, dataset, batch_size=1, sampler=None,
+    def __init__(self, dataset, collate_fn, batch_size=1, sampler=None,
                  num_workers=0, pin_memory=False, drop_last=False,
-                 timeout=0, worker_init_fn=None, collate_fn=None):
+                 timeout=0, worker_init_fn=None,
+                 batch_sampler=None):
         r"""
 
-        :param dataset: a :class:`~fastNLP.DataSet` object, the dataset
+        :param dataset: a data container implementing __getitem__ and __len__.
+        :param callable collate_fn: function used to combine samples into a batch. Its input is
+            [dataset[idx1], dataset[idx2], ...], i.e. a list of the values returned by __getitem__; it must return
+            two dicts, the first of which is treated as input and the second as target. Data that needs to become
+            tensors must be converted inside collate_fn, but need not be moved to the corresponding device.
         :param int batch_size: size of the batches to draw
         :param sampler: the :class:`~fastNLP.Sampler` to use. If ``None``, :class:`~fastNLP.SequentialSampler` is used.
            Default: ``None``
@@ -398,19 +415,21 @@
         :param bool drop_last: drop the final batch if it contains fewer than batch_size samples
         :param timeout: timeout value for producing one batch
         :param worker_init_fn: called once per worker at startup, receiving the worker's index
-        :param collate_fn: function used to combine samples into a batch.
+        :param batch_sampler: a sampler to use when the number of samples per batch varies. Each iteration of
+            batch_sampler should yield a list of indices. When batch_sampler is not None, the arguments
+            batch_size, sampler and drop_last are ignored.
         """
         assert len(dataset) > 0
-        ins = dataset[0]
-        if (len(ins) != 2 or not isinstance(ins[0], dict) or not isinstance(ins[1], dict)) and collate_fn is None:
-            raise RuntimeError("If the provided dataset does not return two dicts when call __getitem__(), the"
-                               " `collate_fn` must be provided.")
+        assert collate_fn is not None, "You must pass collate_fn to pad the batch."
+        if batch_sampler is not None:
+            batch_size = 1
+            sampler = None
+            drop_last = False
 
         super().__init__(
             dataset=dataset, batch_size=batch_size, sampler=sampler,
             num_workers=num_workers, pin_memory=pin_memory,
             drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-            collate_fn=collate_fn
+            collate_fn=collate_fn, batch_sampler=batch_sampler
         )
 
     def __iter__(self):
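
Taken together, the batch.py changes make collate_fn a required argument of TorchLoaderIter and let both iterators accept a batch_sampler that yields lists of indices. A minimal sketch of the new calling convention, assuming a toy dataset (the names PairData and pad_collate are illustrative, not from the diff):

    import random
    from fastNLP import TorchLoaderIter

    class PairData:
        def __len__(self):
            return 10
        def __getitem__(self, i):
            return [random.random() for _ in range(3)], i % 2   # (x, y)

    def pad_collate(ins_list):
        xs = [x for x, y in ins_list]
        ys = [y for x, y in ins_list]
        return {'x': xs}, {'y': ys}   # first dict = input, second dict = target

    # collate_fn is now mandatory; passing batch_sampler would make
    # batch_size, sampler and drop_last be ignored, as documented above.
    it = TorchLoaderIter(PairData(), collate_fn=pad_collate, batch_size=4)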


fastNLP/core/collect_fn.py → fastNLP/core/collate_fn.py (+24 -24)

@@ -36,72 +36,72 @@ def batching(samples, max_len=0, padding_val=0):
     return batch
 
 
-class Collector:
+class Collater:
     r"""
-    Helper class that manages collect_fn for DataSet
+    Helper class that manages collate_fn for DataSet
 
     """
     def __init__(self):
-        self.collect_fns = {}
+        self.collate_fns = {}
 
     def add_fn(self, fn, name=None):
         r"""
-        Add a new collect_fn function to the collector
+        Add a new collate_fn function to the collater
 
         :param callable fn:
         :param str,int name:
         :return:
         """
-        if name in self.collect_fns:
-            logger.warn(f"collect_fn:{name} will be overwritten.")
+        if name in self.collate_fns:
+            logger.warn(f"collate_fn:{name} will be overwritten.")
         if name is None:
-            name = len(self.collect_fns)
-        self.collect_fns[name] = fn
+            name = len(self.collate_fns)
+        self.collate_fns[name] = fn
 
     def is_empty(self):
         r"""
-        Return whether any collect_fn is registered
+        Return whether any collate_fn is registered
 
         :return:
         """
-        return len(self.collect_fns)==0
+        return len(self.collate_fns) == 0
 
     def delete_fn(self, name=None):
         r"""
-        Delete a collect_fn
+        Delete a collate_fn
 
-        :param str,int name: if None, delete the most recently added collect_fn
+        :param str,int name: if None, delete the most recently added collate_fn
         :return:
         """
         if not self.is_empty():
-            if name in self.collect_fns:
-                self.collect_fns.pop(name)
+            if name in self.collate_fns:
+                self.collate_fns.pop(name)
             elif name is None:
-                last_key = list(self.collect_fns.keys())[0]
-                self.collect_fns.pop(last_key)
+                last_key = list(self.collate_fns.keys())[0]
+                self.collate_fns.pop(last_key)
 
-    def collect_batch(self, ins_list):
+    def collate_batch(self, ins_list):
         bx, by = {}, {}
-        for name, fn in self.collect_fns.items():
+        for name, fn in self.collate_fns.items():
             try:
                 batch_x, batch_y = fn(ins_list)
             except BaseException as e:
-                logger.error(f"Exception:`{e}` happens when call collect_fn:`{name}`.")
+                logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.")
                 raise e
             bx.update(batch_x)
             by.update(batch_y)
         return bx, by
 
     def copy_from(self, col):
-        assert isinstance(col, Collector)
-        new_col = Collector()
-        new_col.collect_fns = deepcopy(col.collect_fns)
+        assert isinstance(col, Collater)
+        new_col = Collater()
+        new_col.collate_fns = deepcopy(col.collate_fns)
        return new_col
 
 
-class ConcatCollectFn:
+class ConcatCollateFn:
     r"""
-    A field-concatenation collect_fn: concatenates the given fields in order, then pads them to produce the data.
+    A field-concatenation collate_fn: concatenates the given fields in order, then pads them to produce the data.
 
     :param List[str] inputs: which fields to concatenate; currently only 1d fields are supported
     :param str output: name of the concatenated field
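
The renamed ConcatCollateFn is exercised in test/core/test_batch.py below; as a quick reference, a usage sketch with the parameter values taken from that test (the toy DataSet is an assumption):

    from fastNLP import DataSet, ConcatCollateFn

    ds = DataSet({'1': [[1, 2], [3]], '2': [[4], [5, 6]]})
    ds.set_input('1', '2')
    # concatenate fields '1' and '2' into a new padded input field '12'
    fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0,
                         is_input=True, is_target=False)
    ds.add_collate_fn(fn, name='demo')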

fastNLP/core/dataset.py (+27 -27)

@@ -285,8 +285,8 @@ r"""
 ------------------------------------------------------------
 
 When batching, a DataSet by default only sees the values of the current field, but some training scenarios need more: (1) two fields must be concatenated into one field;
-(2) negative sampling must happen within the batch. Both require operating on several fields at once while building a batch; DataSet's add_collect_fn() function supports adding
-custom collect_fn functions that involve multiple fields, as in the example below, where two fields are concatenated into one field
+(2) negative sampling must happen within the batch. Both require operating on several fields at once while building a batch; DataSet's add_collate_fn() function supports adding
+custom collate_fn functions that involve multiple fields, as in the example below, where two fields are concatenated into one field
 
 .. code-block::
 
@@ -302,9 +302,9 @@ r"""
         })
     data.set_target('y')
 
-    # every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2 is the instance's
+    # every collate_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2 is the instance's
     # index in the dataset and instance1/instance2 is the data fetched for this batch, containing all fields.
-    def concat_collect_fn(ins_list):
+    def concat_collate_fn(ins_list):
         x1 = [ins['x1'] for ind,ins in ins_list]
         x2 = [ins['x2'] for ind,ins in ins_list]
         xs = []
@@ -318,7 +318,7 @@ r"""
         # use the return value.
         return b_x, b_y
 
-    data.add_collect_fn(concat_collect_fn)
+    data.add_collate_fn(concat_collate_fn)
 
     for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
         print("batch_x:", batch_x)
@@ -328,7 +328,7 @@ r"""
     # batch_y: {'y': array([0, 1])}
 
     # if batching involves some parameters, implement it as a class
-    class ConCollectFn:
+    class ConCollateFn:
         def __init__(self, max_len=3):
             self.max_len = max_len
 
@@ -342,8 +342,8 @@ r"""
             b_x = {'x': arr}
             b_y = {}
             return b_x, b_y
-    data.delete_collect_fn()  # delete the previous collect_fn
-    data.add_collect_fn(ConCollectFn(max_len=3))
+    data.delete_collate_fn()  # delete the previous collate_fn
+    data.add_collate_fn(ConCollateFn(max_len=3))
     for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
         print("batch_x:", batch_x)
         print("batch_y:", batch_y)
@@ -370,7 +370,7 @@ from .field import FieldArray
 from .field import SetInputOrTargetException
 from .instance import Instance
 from .utils import pretty_table_printer
-from .collect_fn import Collector
+from .collate_fn import Collater
 
 
 class ApplyResultException(Exception):
@@ -406,7 +406,7 @@ class DataSet(object):
 
         else:
             raise ValueError("data only be dict or list type.")
-        self.collector = Collector()
+        self.collater = Collater()
 
     def __contains__(self, item):
         return item in self.field_arrays
@@ -462,7 +462,7 @@ class DataSet(object):
             for field in self.field_arrays.values():
                 data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder,
                                    is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
-            data_set.collector = self.collector.copy_from(self.collector)
+            data_set.collater = self.collater.copy_from(self.collater)
             return data_set
         elif isinstance(idx, str):
             if idx not in self:
@@ -476,7 +476,7 @@ class DataSet(object):
                 dataset.append(instance)
             for field_name, field in self.field_arrays.items():
                 dataset.field_arrays[field_name].to(field)
-            dataset.collector = self.collector.copy_from(self.collector)
+            dataset.collater = self.collater.copy_from(self.collater)
             return dataset
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -1083,8 +1083,8 @@ class DataSet(object):
             train_set.field_arrays[field_name].to(self.field_arrays[field_name])
             dev_set.field_arrays[field_name].to(self.field_arrays[field_name])
 
-        train_set.collector.copy_from(self.collector)
-        dev_set.collector.copy_from(self.collector)
+        train_set.collater.copy_from(self.collater)
+        dev_set.collater.copy_from(self.collater)
         return train_set, dev_set
 
     def save(self, path):
@@ -1109,30 +1109,30 @@ class DataSet(object):
         assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d
 
-    def add_collect_fn(self, fn, name=None):
+    def add_collate_fn(self, fn, name=None):
         r"""
-        Add a CollectFn. A collect_fn can generate data dynamically while batches are built (effective when DataSetIter is used as the
-        iterator, which is the default). Multiple collect_fns can be added in turn; for identical keys, the result of a later collect_fn overrides that of an earlier one.
+        Add a CollateFn. A collate_fn can generate data dynamically while batches are built (effective when DataSetIter is used as the
+        iterator, which is the default). Multiple collate_fns can be added in turn; for identical keys, the result of a later collate_fn overrides that of an earlier one.
 
         :param callable fn: a callable that accepts List[(ind1, instance1), (ind2, instance2)]
             (all indices and instances selected for a batch), where ind1/ind2 is the instance's index in the dataset and
            instance1/instance2 is the data fetched for this batch, containing all fields. It must return two dicts: the values of
            the first dict are treated as input, those of the second as target; at most one of the two may be empty. If a returned
           dict contains the name of a field set as input or target, it overrides that field of the dataset.
-            fastNLP does not pad the results of collect_fn or convert them to tensors; padding and tensor conversion must be done inside
-            collect_fn (the tensors need not be moved to the gpu; if they are pytorch tensors, fastNLP moves them to the proper gpu automatically). Do not modify the data passed into collect_fn, or unknown problems may follow.
-        :param str,int name: name of the collect_fn; if omitted, an auto-incremented number is used as the key. An identical name overrides the previous collect_fn.
+            fastNLP does not pad the results of collate_fn or convert them to tensors; padding and tensor conversion must be done inside
+            collate_fn (the tensors need not be moved to the gpu; fastNLP moves them to the proper gpu automatically). Do not modify the data passed into collate_fn, or unknown problems may follow.
+        :param str,int name: name of the collate_fn; if omitted, an auto-incremented number is used as the key. An identical name overrides the previous collate_fn.
         """
         assert callable(fn), "You must pass in a callable object."
-        self.collector.add_fn(fn, name=name)
+        self.collater.add_fn(fn, name=name)
 
-    def delete_collect_fn(self, name=None):
+    def delete_collate_fn(self, name=None):
         r"""
-        Delete a collect_fn
+        Delete a collate_fn
 
-        :param str,int name: if None, delete the most recently added collect_fn
+        :param str,int name: if None, delete the most recently added collate_fn
         :return:
         """
-        self.collector.delete_fn(name)
+        self.collater.delete_fn(name)
 
-    def _collect_batch(self, ins_list):
-        return self.collector.collect_batch(ins_list)
+    def _collate_batch(self, ins_list):
+        return self.collater.collate_batch(ins_list)

fastNLP/core/utils.py (+4 -4)

@@ -718,8 +718,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
                 _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. '
             else:
                 _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.'
-            if not dataset.collector.is_empty():
-                _tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. '
+            if not dataset.collater.is_empty():
+                _tmp += f'Or you need to add `{_miss}` in the output of your collate_fn. '
             suggestions.append(_tmp)
 
     if check_res.duplicated:
@@ -779,8 +779,8 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
             suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ")
         if _miss_out_dataset:
             _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. "
-            if not dataset.collector.is_empty():
-                _tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. '
+            if not dataset.collater.is_empty():
+                _tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collate_fn. '
             suggestions.append(_tmp)
 
     if check_res.unused:


fastNLP/io/data_bundle.py (+6 -6)

@@ -348,26 +348,26 @@ class DataBundle:
             dataset.apply(func, new_field_name=new_field_name, **kwargs)
         return self
 
-    def add_collect_fn(self, fn, name=None):
+    def add_collate_fn(self, fn, name=None):
         r"""
-        Add a collect_fn to every DataSet; see :class:`~fastNLP.DataSet` for details on collect_fn.
+        Add a collate_fn to every DataSet; see :class:`~fastNLP.DataSet` for details on collate_fn.
 
         :param callable fn:
         :param name:
         :return:
         """
         for _, dataset in self.datasets.items():
-            dataset.add_collect_fn(fn=fn, name=name)
+            dataset.add_collate_fn(fn=fn, name=name)
 
-    def delete_collect_fn(self, name=None):
+    def delete_collate_fn(self, name=None):
         r"""
-        Delete the collect_fn of the DataSets
+        Delete the collate_fn of the DataSets
 
         :param name:
         :return:
         """
         for _, dataset in self.datasets.items():
-            dataset.delete_collect_fn(name=name)
+            dataset.delete_collate_fn(name=name)
 
     def __repr__(self):
         _str = ''
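
Since DataBundle simply forwards to each contained DataSet, registering a collate_fn once applies it to every split; a small sketch (assuming `bundle` is a DataBundle and `concat_collate_fn` is defined as in the dataset.py docstring above):

    bundle.add_collate_fn(concat_collate_fn, name='concat')   # added to train/dev/test alike
    bundle.delete_collate_fn(name='concat')                   # removed from all of them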


test/core/test_batch.py (+99 -36)

@@ -7,7 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SequentialSampler
-from fastNLP import ConcatCollectFn
+from fastNLP import ConcatCollateFn
 
 
 def generate_fake_dataset(num_samples=1000):
@@ -177,76 +177,76 @@ class TestCase1(unittest.TestCase):
             for con,t in zip(cons, test):
                 self.assertEqual(alphas[:con], t)
 
-    def test_collect_fn(self):
+    def test_collate_fn(self):
         batch_size = 32
         num_samples = 1000
         dataset = generate_fake_dataset(num_samples)
         dataset.set_input('1','2')
         dataset.set_target('0','3')
 
-        fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
-        dataset.add_collect_fn(fn, name='demo')
+        fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
+        dataset.add_collate_fn(fn, name='demo')
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
         for batch_x, batch_y in batch:
             for i in range(batch_size):
                 # print(i)
                 self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
-        dataset.delete_collect_fn(name='demo')
+        dataset.delete_collate_fn(name='demo')
 
         # test the non-input case
         dataset.set_input('1', '2', flag=False)  #
-        fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
-        dataset.add_collect_fn(fn, name='demo')
+        fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
+        dataset.add_collate_fn(fn, name='demo')
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
         for batch_x, batch_y in batch:
             for i in range(batch_size):
                 self.assertTrue('12' in batch_x)
-        dataset.delete_collect_fn(name='demo')
+        dataset.delete_collate_fn(name='demo')
         dataset.set_input('1', '2', flag=True)  #
 
         # test overriding another field
-        fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
-        dataset.add_collect_fn(fn, name='demo')
+        fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
+        dataset.add_collate_fn(fn, name='demo')
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
         for batch_x, batch_y in batch:
             for i in range(batch_size):
                 # print(i)
                 self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
-        dataset.delete_collect_fn(name='demo')
+        dataset.delete_collate_fn(name='demo')
 
         # test the case where fields are neither input nor target
         dataset.set_input('1', '2', flag=False)
-        fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
-        dataset.add_collect_fn(fn, name='demo')
+        fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
+        dataset.add_collate_fn(fn, name='demo')
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
         for batch_x, batch_y in batch:
             for i in range(batch_size):
                 # print(i)
                 self.assertTrue('3' in batch_x)
                 self.assertTrue('3' in batch_y)
-        dataset.delete_collect_fn(name='demo')
+        dataset.delete_collate_fn(name='demo')
 
         # test adding an invalid fn
         with self.assertRaises(AssertionError):
-            dataset.add_collect_fn(1)
+            dataset.add_collate_fn(1)
 
-        # test the case where collect_fn returns only one value
-        def demo_collect_fn(ins_list):
+        # test the case where collate_fn returns only one value
+        def demo_collate_fn(ins_list):
             return {'3':1}
-        dataset.add_collect_fn(demo_collect_fn, name='demo')
+        dataset.add_collate_fn(demo_collate_fn, name='demo')
         with self.assertRaises(BaseException):
             batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
             for batch_x, batch_y in batch:
                 pass
-        dataset.delete_collect_fn(name='demo')
+        dataset.delete_collate_fn(name='demo')
 
-        # test multiple collect_fns
-        dataset.add_collect_fn(demo_collect_fn, name='demo')
-        dataset.add_collect_fn(demo_collect_fn, name='demo')
+        # test multiple collate_fns
+        dataset.add_collate_fn(demo_collate_fn, name='demo')
+        dataset.add_collate_fn(demo_collate_fn, name='demo')
         # test deletion
-        dataset.delete_collect_fn()
-        dataset.delete_collect_fn()
-        self.assertTrue(dataset.collector.is_empty())
+        dataset.delete_collate_fn()
+        dataset.delete_collate_fn()
+        self.assertTrue(dataset.collater.is_empty())
 
     def test_demo(self):
         import torch
@@ -261,9 +261,9 @@ class TestCase1(unittest.TestCase):
             })
         data.set_target('y')
 
-        # every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2 is the
+        # every collate_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2 is the
         # instance's index in the dataset and instance1/instance2 is the data fetched for this batch, with all fields.
-        def concat_collect_fn(ins_list):
+        def concat_collate_fn(ins_list):
             x1 = [ins['x1'] for ind,ins in ins_list]
             x2 = [ins['x2'] for ind,ins in ins_list]
             xs = []
@@ -277,7 +277,7 @@ class TestCase1(unittest.TestCase):
             # use the return value.
             return b_x, b_y
 
-        data.add_collect_fn(concat_collect_fn)
+        data.add_collate_fn(concat_collate_fn)
 
         for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
             print("batch_x:", batch_x)
@@ -287,7 +287,7 @@ class TestCase1(unittest.TestCase):
         # batch_y: {'y': array([0, 1])}
 
         # if batching involves some parameters, implement it as a class
-        class ConCollectFn:
+        class ConCollateFn:
             def __init__(self, max_len=3):
                 self.max_len = max_len
             def __call__(self, ins_list):
@@ -300,8 +300,8 @@ class TestCase1(unittest.TestCase):
                 b_x = {'x': arr}
                 b_y = {}
                 return b_x, b_y
-        data.delete_collect_fn()  # delete the previous collect_fn
-        data.add_collect_fn(ConCollectFn(max_len=3))
+        data.delete_collate_fn()  # delete the previous collate_fn
+        data.add_collate_fn(ConCollateFn(max_len=3))
         for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
             print("batch_x:", batch_x)
             print("batch_y:", batch_y)
@@ -326,14 +326,77 @@ class TestCase1(unittest.TestCase):
                 return x, y
 
         data1 = FakeData()
-        dataiter = TorchLoaderIter(data1, batch_size=2)
+        def collate_fn(ins_list):
+            xs = [ins[0]['x'] for ins in ins_list]
+            ys = [ins[1]['y'] for ins in ins_list]
+            return {'x':xs}, {'y':ys}
+        dataiter = TorchLoaderIter(data1, collate_fn=collate_fn, batch_size=2)
         for x, y in dataiter:
             print(x, y)
 
-        def func():
-            data2 = FakeData(return_dict=False)
-            dataiter = TorchLoaderIter(data2, batch_size=2)
-        self.assertRaises(Exception, func)
+    def test_batch_sampler(self):
+        # test that batch_sampler works with DataSetIter and TorchLoaderIter
+        # DataSetIter
+        ds = generate_fake_dataset(5)
+        ds.set_input('1')
+        class BatchSampler:
+            def __init__(self, dataset):
+                self.num_samples = len(dataset)
+
+            def __iter__(self):
+                index = 0
+                indexes = list(range(self.num_samples))
+                np.random.shuffle(indexes)
+                start_idx = 0
+                while index < self.num_samples:
+                    if start_idx == 0:
+                        end_index = self.num_samples//2
+                    else:
+                        end_index = self.num_samples
+                    yield indexes[start_idx:end_index]
+                    index = end_index
+                    start_idx = end_index
+
+            def __len__(self):
+                return 2
+
+        batch_sampler = BatchSampler(ds)
+
+        data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False,
+                                num_workers=0, pin_memory=False, drop_last=False,
+                                timeout=0, worker_init_fn=None, collate_fn=None,
+                                batch_sampler=batch_sampler)
+        num_samples = [len(ds)//2, len(ds)-len(ds)//2]
+        for idx, (batch_x, batch_y) in enumerate(data_iter):
+            self.assertEqual(num_samples[idx], len(batch_x['1']))
+
+        # TorchLoaderIter
+        class FakeData:
+            def __init__(self):
+                self.x = [[1,2,3], [4,5,6], [1,2]]
+
+            def __len__(self):
+                return len(self.x)
+
+            def __getitem__(self, i):
+                x = self.x[i]
+                y = 0
+                return x,y
+
+        def collate_fn(ins_list):
+            xs = [ins[0] for ins in ins_list]
+            ys = [ins[1] for ins in ins_list]
+            return {'x':xs}, {'y':ys}
+
+        ds = FakeData()
+        batch_sampler = BatchSampler(ds)
+        data_iter = TorchLoaderIter(ds, batch_size=10, sampler=batch_sampler,
+                                    num_workers=0, pin_memory=False, drop_last=False,
+                                    timeout=0, worker_init_fn=None, collate_fn=collate_fn,
+                                    batch_sampler=batch_sampler)
+        num_samples = [len(ds)//2, len(ds)-len(ds)//2]
+        for idx, (batch_x, batch_y) in enumerate(data_iter):
+            self.assertEqual(num_samples[idx], len(batch_x['x']))
 
     """
     def test_multi_workers_batch(self):


test/core/test_trainer.py (+69 -8)

@@ -243,7 +243,7 @@ class TrainerTestGround(unittest.TestCase):
             def __len__(self):
                 return self.num_samples
 
-        def collect_fn(data_list):
+        def collate_fn(data_list):
            # [(x1,y1), (x2,y2), ...]; the input here is actually the __getitem__ outputs of UdfDataSet gathered into a list
             xs, ys = [], []
             for l in data_list:
@@ -254,7 +254,7 @@ class TrainerTestGround(unittest.TestCase):
             return {'x':x, 'y':y}, {'y':y}
 
         dataset = UdfDataSet(10)
-        dataset = TorchLoaderIter(dataset, collate_fn=collect_fn)
+        dataset = TorchLoaderIter(dataset, collate_fn=collate_fn)
         class Model(nn.Module):
             def __init__(self):
                 super().__init__()
@@ -268,6 +268,67 @@ class TrainerTestGround(unittest.TestCase):
                           metrics=AccuracyMetric(target='y'), use_tqdm=False)
         trainer.train(load_best_model=False)
 
+    def test_batch_sampler_dataiter(self):
+        import random
+        import torch
+        class BatchSampler:
+            def __init__(self, dataset):
+                self.num_samples = len(dataset)
+
+            def __iter__(self):
+                index = 0
+                indexes = list(range(self.num_samples))
+                np.random.shuffle(indexes)
+                start_idx = 0
+                while index < self.num_samples:
+                    if start_idx == 0:
+                        end_index = self.num_samples//2
+                    else:
+                        end_index = self.num_samples
+                    yield indexes[start_idx:end_index]
+                    index = end_index
+                    start_idx = end_index
+            def __len__(self):
+                return 2
+
+        class UdfDataSet:
+            def __init__(self, num_samples):
+                self.num_samples = num_samples
+
+            def __getitem__(self, idx):
+                x = [random.random() for _ in range(3)]
+                y = random.random()
+                return x,y
+
+            def __len__(self):
+                return self.num_samples
+
+        def collate_fn(data_list):
+            # [(x1,y1), (x2,y2), ...]; the input here is actually the __getitem__ outputs of UdfDataSet gathered into a list
+            xs, ys = [], []
+            for l in data_list:
+                x, y = l
+                xs.append(x)
+                ys.append(y)
+            x,y = torch.FloatTensor(xs), torch.FloatTensor(ys)
+            return {'x':x, 'y':y}, {'y':y}
+
+        dataset = UdfDataSet(11)
+        batch_sampler = BatchSampler(dataset)
+        dataset = TorchLoaderIter(dataset, collate_fn=collate_fn, batch_sampler=batch_sampler)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(3, 1)
+            def forward(self, x, y):
+                return {'loss':torch.pow(self.fc(x).squeeze(-1)-y, 2).sum()}
+            def predict(self, x):
+                return {'pred':self.fc(x).squeeze(-1)}
+        model = Model()
+        trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
+                          metrics=AccuracyMetric(target='y'), use_tqdm=False)
+        trainer.train(load_best_model=False)
+
     def test_onthefly_iter(self):
         import tempfile
         import random
@@ -333,7 +394,7 @@ class TrainerTestGround(unittest.TestCase):
                 return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}
 
             def predict(self, x):
-                return {'pred': self.fc(x).squeeze(0)}
+                return {'pred': self.fc(x).squeeze(-1)}
 
         model = Model()
         trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
@@ -356,7 +417,7 @@ class TrainerTestGround(unittest.TestCase):
                 x.append(ins['x1']+ins['x2'])
             x = torch.FloatTensor(x)
             return {'x':x}, {}
-        dataset.add_collect_fn(fn)
+        dataset.add_collate_fn(fn)
 
         class Model(nn.Module):
             def __init__(self):
@@ -377,7 +438,7 @@ class TrainerTestGround(unittest.TestCase):
                           dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False)
         trainer.train()
 
-    def test_collect_fn2(self):
+    def test_collate_fn2(self):
        """test whether batch_x and batch_y can be produced"""
         dataset = prepare_fake_dataset2('x1', 'x2')
         dataset.set_input('x1', 'x2')
@@ -389,7 +450,7 @@ class TrainerTestGround(unittest.TestCase):
                 x.append(ins['x1']+ins['x2'])
             x = torch.FloatTensor(x)
             return {'x':x}, {'target':x[:, :4].argmax(dim=-1)}
-        dataset.add_collect_fn(fn)
+        dataset.add_collate_fn(fn)
 
         class Model(nn.Module):
             def __init__(self):
@@ -410,7 +471,7 @@ class TrainerTestGround(unittest.TestCase):
                           dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False)
         trainer.train()
 
-    def test_collect_fn3(self):
+    def test_collate_fn3(self):
         """
         test that existing fields should be overridden
 
@@ -426,7 +487,7 @@ class TrainerTestGround(unittest.TestCase):
                 x.append(ins['x1']+ins['x2'])
             x = torch.FloatTensor(x)
             return {'x1':torch.zeros_like(x)}, {'target':torch.zeros(x.size(0)).long(), 'y':x}
-        dataset.add_collect_fn(fn)
+        dataset.add_collate_fn(fn)
 
         class Model(nn.Module):
             def __init__(self):

