Browse Source

1.将collect_fn修改为collate_fn使得与pytorch保持一致; 2.增加对BatchSampler的支持

tags/v0.5.5
yh_cc 4 years ago
parent
commit
ba645f4483
9 changed files with 275 additions and 132 deletions
  1. +2
    -2
      fastNLP/__init__.py
  2. +2
    -2
      fastNLP/core/__init__.py
  3. +42
    -23
      fastNLP/core/batch.py
  4. +24
    -24
      fastNLP/core/collate_fn.py
  5. +27
    -27
      fastNLP/core/dataset.py
  6. +4
    -4
      fastNLP/core/utils.py
  7. +6
    -6
      fastNLP/io/data_bundle.py
  8. +99
    -36
      test/core/test_batch.py
  9. +69
    -8
      test/core/test_trainer.py

+ 2
- 2
fastNLP/__init__.py View File

@@ -44,8 +44,8 @@ __all__ = [
"AutoPadder", "AutoPadder",
"EngChar2DPadder", "EngChar2DPadder",


# "CollectFn",
"ConcatCollectFn",
# "CollateFn",
"ConcatCollateFn",


"MetricBase", "MetricBase",
"AccuracyMetric", "AccuracyMetric",


+ 2
- 2
fastNLP/core/__init__.py View File

@@ -21,7 +21,7 @@ __all__ = [
"AutoPadder", "AutoPadder",
"EngChar2DPadder", "EngChar2DPadder",


"ConcatCollectFn",
"ConcatCollateFn",
"Vocabulary", "Vocabulary",
@@ -99,4 +99,4 @@ from .tester import Tester
from .trainer import Trainer from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary from .vocabulary import Vocabulary
from .collect_fn import ConcatCollectFn
from .collate_fn import ConcatCollateFn

+ 42
- 23
fastNLP/core/batch.py View File

@@ -18,7 +18,7 @@ import torch.utils.data
from collections import defaultdict from collections import defaultdict


from .dataset import DataSet from .dataset import DataSet
from .sampler import SequentialSampler
from .sampler import SequentialSampler, Sampler
from ._logger import logger from ._logger import logger




@@ -89,8 +89,8 @@ class DataSetGetter:
sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy) sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy)
sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy) sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy)


if not self.dataset.collector.is_empty():
bx, by = self.dataset._collect_batch(ins_list)
if not self.dataset.collater.is_empty():
bx, by = self.dataset._collate_batch(ins_list)
sin_x.update(bx) sin_x.update(bx)
sin_y.update(by) sin_y.update(by)


@@ -127,29 +127,35 @@ class BatchIter:
""" """
def __init__(self, dataset, batch_size=1, sampler=None, def __init__(self, dataset, batch_size=1, sampler=None,
num_workers=0, pin_memory=False, drop_last=False, num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=None):
if not isinstance(sampler, torch.utils.data.Sampler):
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
else:
self.sampler = sampler
timeout=0, worker_init_fn=None, collate_fn=None,
batch_sampler=None):
if isinstance(sampler, Sampler): # 如果时fastNLP的sampler需要adapt一下
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
self.sampler = sampler
self.batch_sampler = batch_sampler


# DataLoader的collect_fn输入是List[],里面的元素是dataset[index]返回的结果
# DataLoader的collate_fn输入是List[],里面的元素是dataset[index]返回的结果
if collate_fn is None: if collate_fn is None:
# pytoch <= 1.1 中不能设置collate_fn=None # pytoch <= 1.1 中不能设置collate_fn=None
self.dataiter = torch.utils.data.DataLoader( self.dataiter = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, sampler=self.sampler, dataset=dataset, batch_size=batch_size, sampler=self.sampler,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last, pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
timeout=timeout, worker_init_fn=worker_init_fn,
batch_sampler=batch_sampler)
else: else:
self.dataiter = torch.utils.data.DataLoader( self.dataiter = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, sampler=self.sampler, dataset=dataset, batch_size=batch_size, sampler=self.sampler,
collate_fn=collate_fn, num_workers=num_workers, collate_fn=collate_fn, num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last, pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
timeout=timeout, worker_init_fn=worker_init_fn,
batch_sampler=batch_sampler)


# 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了
self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
if self.batch_sampler is None:
self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
else:
self._num_batches = len(self.batch_sampler)
self.batch_size = batch_size self.batch_size = batch_size
self.cur_batch_indices = None self.cur_batch_indices = None


@@ -222,7 +228,8 @@ class DataSetIter(BatchIter):
""" """
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False, num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=None):
timeout=0, worker_init_fn=None, collate_fn=None,
batch_sampler=None):
r""" r"""
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集
@@ -239,15 +246,21 @@ class DataSetIter(BatchIter):
:param timeout: 生成一个batch的timeout值 :param timeout: 生成一个batch的timeout值
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
:param collate_fn: 用于将样本组合成batch的函数 :param collate_fn: 用于将样本组合成batch的函数
:param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。
当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。
""" """
assert isinstance(dataset, DataSet) assert isinstance(dataset, DataSet)
dataset = DataSetGetter(dataset, as_numpy) dataset = DataSetGetter(dataset, as_numpy)
collate_fn = dataset.collate_fn if collate_fn is None else collate_fn collate_fn = dataset.collate_fn if collate_fn is None else collate_fn
if batch_sampler is not None:
batch_size = 1
sampler = None
drop_last = False
super().__init__( super().__init__(
dataset=dataset, batch_size=batch_size, sampler=sampler, dataset=dataset, batch_size=batch_size, sampler=sampler,
num_workers=num_workers, pin_memory=pin_memory, num_workers=num_workers, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
collate_fn=collate_fn
collate_fn=collate_fn, batch_sampler=batch_sampler
) )


def __iter__(self): def __iter__(self):
@@ -384,12 +397,16 @@ class TorchLoaderIter(BatchIter):
os.remove(tmp_file_path) os.remove(tmp_file_path)
""" """
def __init__(self, dataset, batch_size=1, sampler=None,
def __init__(self, dataset, collate_fn, batch_size=1, sampler=None,
num_workers=0, pin_memory=False, drop_last=False, num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=None):
timeout=0, worker_init_fn=None,
batch_sampler=None):
r""" r"""


:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集
:param dataset: 实现了__getitem__和__len__方法的数据容器。
:param callable collate_fn: 用于将样本组合成batch的函数。输入为[dataset[idx1], dataset[idx2], ...], 即dataset中
__getitem__返回值组成的list,返回值必须为两个dict,其中第一个dict会被认为是input,第二个dict中的内容被认为是target。
需要转换为tensor的数据,需要在collate_fn中转化,但不需要转移到对应device。
:param int batch_size: 取出的batch大小 :param int batch_size: 取出的batch大小
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. :param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`.
Default: ``None`` Default: ``None``
@@ -398,19 +415,21 @@ class TorchLoaderIter(BatchIter):
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个
:param timeout: 生成一个batch的timeout值 :param timeout: 生成一个batch的timeout值
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
:param collate_fn: 用于将样本组合成batch的函数。
:param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。
当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。
""" """
assert len(dataset) > 0 assert len(dataset) > 0
ins = dataset[0]
if (len(ins) != 2 or not isinstance(ins[0], dict) or not isinstance(ins[1], dict)) and collate_fn is None:
raise RuntimeError("If the provided dataset does not return two dicts when call __getitem__(), the"
" `collate_fn` must be provided.")
assert collate_fn is not None, "You must pass collate_fn to pad the batch."
if batch_sampler is not None:
batch_size = 1
sampler = None
drop_last = False


super().__init__( super().__init__(
dataset=dataset, batch_size=batch_size, sampler=sampler, dataset=dataset, batch_size=batch_size, sampler=sampler,
num_workers=num_workers, pin_memory=pin_memory, num_workers=num_workers, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
collate_fn=collate_fn
collate_fn=collate_fn, batch_sampler=batch_sampler
) )


def __iter__(self): def __iter__(self):


fastNLP/core/collect_fn.py → fastNLP/core/collate_fn.py View File

@@ -36,72 +36,72 @@ def batching(samples, max_len=0, padding_val=0):
return batch return batch




class Collector:
class Collater:
r""" r"""
辅助DataSet管理collect_fn的类
辅助DataSet管理collate_fn的类


""" """
def __init__(self): def __init__(self):
self.collect_fns = {}
self.collate_fns = {}


def add_fn(self, fn, name=None): def add_fn(self, fn, name=None):
r""" r"""
向collector新增一个collect_fn函数
向collater新增一个collate_fn函数


:param callable fn: :param callable fn:
:param str,int name: :param str,int name:
:return: :return:
""" """
if name in self.collect_fns:
logger.warn(f"collect_fn:{name} will be overwritten.")
if name in self.collate_fns:
logger.warn(f"collate_fn:{name} will be overwritten.")
if name is None: if name is None:
name = len(self.collect_fns)
self.collect_fns[name] = fn
name = len(self.collate_fns)
self.collate_fns[name] = fn


def is_empty(self): def is_empty(self):
r""" r"""
返回是否包含collect_fn
返回是否包含collate_fn


:return: :return:
""" """
return len(self.collect_fns)==0
return len(self.collate_fns) == 0


def delete_fn(self, name=None): def delete_fn(self, name=None):
r""" r"""
删除collect_fn
删除collate_fn


:param str,int name: 如果为None就删除最近加入的collect_fn
:param str,int name: 如果为None就删除最近加入的collate_fn
:return: :return:
""" """
if not self.is_empty(): if not self.is_empty():
if name in self.collect_fns:
self.collect_fns.pop(name)
if name in self.collate_fns:
self.collate_fns.pop(name)
elif name is None: elif name is None:
last_key = list(self.collect_fns.keys())[0]
self.collect_fns.pop(last_key)
last_key = list(self.collate_fns.keys())[0]
self.collate_fns.pop(last_key)


def collect_batch(self, ins_list):
def collate_batch(self, ins_list):
bx, by = {}, {} bx, by = {}, {}
for name, fn in self.collect_fns.items():
for name, fn in self.collate_fns.items():
try: try:
batch_x, batch_y = fn(ins_list) batch_x, batch_y = fn(ins_list)
except BaseException as e: except BaseException as e:
logger.error(f"Exception:`{e}` happens when call collect_fn:`{name}`.")
logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.")
raise e raise e
bx.update(batch_x) bx.update(batch_x)
by.update(batch_y) by.update(batch_y)
return bx, by return bx, by


def copy_from(self, col): def copy_from(self, col):
assert isinstance(col, Collector)
new_col = Collector()
new_col.collect_fns = deepcopy(col.collect_fns)
assert isinstance(col, Collater)
new_col = Collater()
new_col.collate_fns = deepcopy(col.collate_fns)
return new_col return new_col




class ConcatCollectFn:
class ConcatCollateFn:
r""" r"""
field拼接collect_fn,将不同field按序拼接后,padding产生数据。
field拼接collate_fn,将不同field按序拼接后,padding产生数据。


:param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field :param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field
:param str output: 拼接后的field名称 :param str output: 拼接后的field名称

+ 27
- 27
fastNLP/core/dataset.py View File

@@ -285,8 +285,8 @@ r"""
------------------------------------------------------------ ------------------------------------------------------------


DataSet支持在进行batch时,默认只能看到当前的field的值,但在某些训练中可能存在以下的情况: (1)需要两个field拼接成为一个field; DataSet支持在进行batch时,默认只能看到当前的field的值,但在某些训练中可能存在以下的情况: (1)需要两个field拼接成为一个field;
(2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collect_fn()函数支持添加
自定义涉及多个field的collect_fn函数。例如下例中将两个field拼接成一个field的场景
(2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collate_fn()函数支持添加
自定义涉及多个field的collate_fn函数。例如下例中将两个field拼接成一个field的场景


.. code-block:: .. code-block::


@@ -302,9 +302,9 @@ r"""
}) })
data.set_target('y') data.set_target('y')


# 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中
# 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中
# 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field.
def concat_collect_fn(ins_list):
def concat_collate_fn(ins_list):
x1 = [ins['x1'] for ind,ins in ins_list] x1 = [ins['x1'] for ind,ins in ins_list]
x2 = [ins['x2'] for ind,ins in ins_list] x2 = [ins['x2'] for ind,ins in ins_list]
xs = [] xs = []
@@ -318,7 +318,7 @@ r"""
# 采用返回值。 # 采用返回值。
return b_x, b_y return b_x, b_y


data.add_collect_fn(concat_collect_fn)
data.add_collate_fn(concat_collate_fn)


for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
print("batch_x:", batch_x) print("batch_x:", batch_x)
@@ -328,7 +328,7 @@ r"""
# batch_y: {'y': array([0, 1])} # batch_y: {'y': array([0, 1])}


# 如果取batch过程含有一些参数,可以通过类来实现 # 如果取batch过程含有一些参数,可以通过类来实现
class ConCollectFn:
class ConCollateFn:
def __init__(self, max_len=3): def __init__(self, max_len=3):
self.max_len = max_len self.max_len = max_len


@@ -342,8 +342,8 @@ r"""
b_x = {'x': arr} b_x = {'x': arr}
b_y = {} b_y = {}
return b_x, b_y return b_x, b_y
data.delete_collect_fn() # 删除之前的collect_fn
data.add_collect_fn(ConCollectFn(max_len=3))
data.delete_collate_fn() # 删除之前的collate_fn
data.add_collate_fn(ConCollateFn(max_len=3))
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
print("batch_x:", batch_x) print("batch_x:", batch_x)
print("batch_y:", batch_y) print("batch_y:", batch_y)
@@ -370,7 +370,7 @@ from .field import FieldArray
from .field import SetInputOrTargetException from .field import SetInputOrTargetException
from .instance import Instance from .instance import Instance
from .utils import pretty_table_printer from .utils import pretty_table_printer
from .collect_fn import Collector
from .collate_fn import Collater




class ApplyResultException(Exception): class ApplyResultException(Exception):
@@ -406,7 +406,7 @@ class DataSet(object):


else: else:
raise ValueError("data only be dict or list type.") raise ValueError("data only be dict or list type.")
self.collector = Collector()
self.collater = Collater()


def __contains__(self, item): def __contains__(self, item):
return item in self.field_arrays return item in self.field_arrays
@@ -462,7 +462,7 @@ class DataSet(object):
for field in self.field_arrays.values(): for field in self.field_arrays.values():
data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder,
is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
data_set.collector = self.collector.copy_from(self.collector)
data_set.collater = self.collater.copy_from(self.collater)
return data_set return data_set
elif isinstance(idx, str): elif isinstance(idx, str):
if idx not in self: if idx not in self:
@@ -476,7 +476,7 @@ class DataSet(object):
dataset.append(instance) dataset.append(instance)
for field_name, field in self.field_arrays.items(): for field_name, field in self.field_arrays.items():
dataset.field_arrays[field_name].to(field) dataset.field_arrays[field_name].to(field)
dataset.collector = self.collector.copy_from(self.collector)
dataset.collater = self.collater.copy_from(self.collater)
return dataset return dataset
else: else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -1083,8 +1083,8 @@ class DataSet(object):
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) train_set.field_arrays[field_name].to(self.field_arrays[field_name])
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) dev_set.field_arrays[field_name].to(self.field_arrays[field_name])


train_set.collector.copy_from(self.collector)
dev_set.collector.copy_from(self.collector)
train_set.collater.copy_from(self.collater)
dev_set.collater.copy_from(self.collater)
return train_set, dev_set return train_set, dev_set


def save(self, path): def save(self, path):
@@ -1109,30 +1109,30 @@ class DataSet(object):
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
return d return d


def add_collect_fn(self, fn, name=None):
def add_collate_fn(self, fn, name=None):
r""" r"""
添加 CollectFn,collect_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的
这个)。支持依次添加多个collect_fn, 如果相同的key,后面的collect_fn的结果覆盖前面的collect_fn的结果。
添加 CollateFn,collate_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的
这个)。支持依次添加多个collate_fn, 如果相同的key,后面的collate_fn的结果覆盖前面的collate_fn的结果。


:param callable fn: 传入一个可调用的function, 该function可接受的参数为List[(ind1, instance1), (ind2, instance2)] :param callable fn: 传入一个可调用的function, 该function可接受的参数为List[(ind1, instance1), (ind2, instance2)]
(某个batch被选中的所有的indice以及instance),其中ind1/ind2是该instance在dataset中的index,instance1/instance2是 (某个batch被选中的所有的indice以及instance),其中ind1/ind2是该instance在dataset中的index,instance1/instance2是
这次batch取出来的数据,包含了所有的field。返回值需要为两个dict,第一个dict的值将被认为是input,第二个dict的值被认为是 这次batch取出来的数据,包含了所有的field。返回值需要为两个dict,第一个dict的值将被认为是input,第二个dict的值被认为是
target,返回的值至多允许一个空dict。若返回的dict中包含了被设置为input或target的field的名称,将覆盖dataset中的field。 target,返回的值至多允许一个空dict。若返回的dict中包含了被设置为input或target的field的名称,将覆盖dataset中的field。
fastNLP不会将collect_fn的返回结果pad和转换为tensor,需要在collect_fn中完成pad和转换为tensor(不需要将tensor移动到
gpu中,如果是pytorch的tensor,fastNLP会自动将其移动到特定gpu)。不要修改传入collect_fn中的数据,否则可能导致未知问题。
:param str,int name: collect_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collect_fn。
fastNLP不会将collate_fn的返回结果pad和转换为tensor,需要在collate_fn中完成pad和转换为tensor(不需要将tensor移动到
gpu中,fastNLP会自动将其移动到特定gpu)。不要修改传入collate_fn中的数据,否则可能导致未知问题。
:param str,int name: collate_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collate_fn。
""" """
assert callable(fn), "You must pass in a callable object." assert callable(fn), "You must pass in a callable object."
self.collector.add_fn(fn, name=name)
self.collater.add_fn(fn, name=name)


def delete_collect_fn(self, name=None):
def delete_collate_fn(self, name=None):
r""" r"""
删除某个collect_fn
删除某个collate_fn


:param str,int name: 如果为None,则删除最近加入的collect_fn
:param str,int name: 如果为None,则删除最近加入的collate_fn
:return: :return:
""" """
self.collector.delete_fn(name)
self.collater.delete_fn(name)


def _collect_batch(self, ins_list):
return self.collector.collect_batch(ins_list)
def _collate_batch(self, ins_list):
return self.collater.collate_batch(ins_list)

+ 4
- 4
fastNLP/core/utils.py View File

@@ -718,8 +718,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
_tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. ' _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. '
else: else:
_tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.' _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.'
if not dataset.collector.is_empty():
_tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. '
if not dataset.collater.is_empty():
_tmp += f'Or you need to add `{_miss}` in the output of your collate_fn. '
suggestions.append(_tmp) suggestions.append(_tmp)


if check_res.duplicated: if check_res.duplicated:
@@ -779,8 +779,8 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ") suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ")
if _miss_out_dataset: if _miss_out_dataset:
_tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. " _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. "
if not dataset.collector.is_empty():
_tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. '
if not dataset.collator.is_empty():
_tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collate_fn. '
suggestions.append(_tmp) suggestions.append(_tmp)


if check_res.unused: if check_res.unused:


+ 6
- 6
fastNLP/io/data_bundle.py View File

@@ -348,26 +348,26 @@ class DataBundle:
dataset.apply(func, new_field_name=new_field_name, **kwargs) dataset.apply(func, new_field_name=new_field_name, **kwargs)
return self return self


def add_collect_fn(self, fn, name=None):
def add_collate_fn(self, fn, name=None):
r""" r"""
向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明.
向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明.


:param callable fn: :param callable fn:
:param name: :param name:
:return: :return:
""" """
for _, dataset in self.datasets.items(): for _, dataset in self.datasets.items():
dataset.add_collect_fn(fn=fn, name=name)
dataset.add_collate_fn(fn=fn, name=name)


def delete_collect_fn(self, name=None):
def delete_collate_fn(self, name=None):
r""" r"""
删除DataSet中的collect_fn
删除DataSet中的collate_fn


:param name: :param name:
:return: :return:
""" """
for _, dataset in self.datasets.items(): for _, dataset in self.datasets.items():
dataset.delete_collect_fn(name=name)
dataset.delete_collate_fn(name=name)


def __repr__(self): def __repr__(self):
_str = '' _str = ''


+ 99
- 36
test/core/test_batch.py View File

@@ -7,7 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter
from fastNLP import DataSet from fastNLP import DataSet
from fastNLP import Instance from fastNLP import Instance
from fastNLP import SequentialSampler from fastNLP import SequentialSampler
from fastNLP import ConcatCollectFn
from fastNLP import ConcatCollateFn




def generate_fake_dataset(num_samples=1000): def generate_fake_dataset(num_samples=1000):
@@ -177,76 +177,76 @@ class TestCase1(unittest.TestCase):
for con,t in zip(cons, test): for con,t in zip(cons, test):
self.assertEqual(alphas[:con], t) self.assertEqual(alphas[:con], t)


def test_collect_fn(self):
def test_collate_fn(self):
batch_size = 32 batch_size = 32
num_samples = 1000 num_samples = 1000
dataset = generate_fake_dataset(num_samples) dataset = generate_fake_dataset(num_samples)
dataset.set_input('1','2') dataset.set_input('1','2')
dataset.set_target('0','3') dataset.set_target('0','3')


fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
dataset.add_collect_fn(fn, name='demo')
fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
dataset.add_collate_fn(fn, name='demo')
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch: for batch_x, batch_y in batch:
for i in range(batch_size): for i in range(batch_size):
# print(i) # print(i)
self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
dataset.delete_collect_fn(name='demo')
dataset.delete_collate_fn(name='demo')


# 测试非input的情况 # 测试非input的情况
dataset.set_input('1', '2', flag=False) # dataset.set_input('1', '2', flag=False) #
fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
dataset.add_collect_fn(fn, name='demo')
fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
dataset.add_collate_fn(fn, name='demo')
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch: for batch_x, batch_y in batch:
for i in range(batch_size): for i in range(batch_size):
self.assertTrue('12' in batch_x) self.assertTrue('12' in batch_x)
dataset.delete_collect_fn(name='demo')
dataset.delete_collate_fn(name='demo')
dataset.set_input('1', '2', flag=True) # dataset.set_input('1', '2', flag=True) #


# 测试覆盖其它field的情况 # 测试覆盖其它field的情况
fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
dataset.add_collect_fn(fn, name='demo')
fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
dataset.add_collate_fn(fn, name='demo')
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch: for batch_x, batch_y in batch:
for i in range(batch_size): for i in range(batch_size):
# print(i) # print(i)
self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
dataset.delete_collect_fn(name='demo')
dataset.delete_collate_fn(name='demo')


# 测试非input,target的情况 # 测试非input,target的情况
dataset.set_input('1', '2', flag=False) dataset.set_input('1', '2', flag=False)
fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
dataset.add_collect_fn(fn, name='demo')
fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
dataset.add_collate_fn(fn, name='demo')
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch: for batch_x, batch_y in batch:
for i in range(batch_size): for i in range(batch_size):
# print(i) # print(i)
self.assertTrue('3' in batch_x) self.assertTrue('3' in batch_x)
self.assertTrue('3' in batch_y) self.assertTrue('3' in batch_y)
dataset.delete_collect_fn(name='demo')
dataset.delete_collate_fn(name='demo')


# 测试加入非法fn的请 # 测试加入非法fn的请
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
dataset.add_collect_fn(1)
dataset.add_collate_fn(1)


# 测试collect_fn返回值只有一个的情况
def demo_collect_fn(ins_list):
# 测试collate_fn返回值只有一个的情况
def demo_collate_fn(ins_list):
return {'3':1} return {'3':1}
dataset.add_collect_fn(demo_collect_fn, name='demo')
dataset.add_collate_fn(demo_collate_fn, name='demo')
with self.assertRaises(BaseException): with self.assertRaises(BaseException):
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch: for batch_x, batch_y in batch:
pass pass
dataset.delete_collect_fn(name='demo')
dataset.delete_collate_fn(name='demo')


# 测试多个collect_fn
dataset.add_collect_fn(demo_collect_fn, name='demo')
dataset.add_collect_fn(demo_collect_fn, name='demo')
# 测试多个collate_fn
dataset.add_collate_fn(demo_collate_fn, name='demo')
dataset.add_collate_fn(demo_collate_fn, name='demo')
# 测试删除 # 测试删除
dataset.delete_collect_fn()
dataset.delete_collect_fn()
self.assertTrue(dataset.collector.is_empty())
dataset.delete_collate_fn()
dataset.delete_collate_fn()
self.assertTrue(dataset.collater.is_empty())


def test_demo(self): def test_demo(self):
import torch import torch
@@ -261,9 +261,9 @@ class TestCase1(unittest.TestCase):
}) })
data.set_target('y') data.set_target('y')


# 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中
# 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中
# 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field.
def concat_collect_fn(ins_list):
def concat_collate_fn(ins_list):
x1 = [ins['x1'] for ind,ins in ins_list] x1 = [ins['x1'] for ind,ins in ins_list]
x2 = [ins['x2'] for ind,ins in ins_list] x2 = [ins['x2'] for ind,ins in ins_list]
xs = [] xs = []
@@ -277,7 +277,7 @@ class TestCase1(unittest.TestCase):
# 采用返回值。 # 采用返回值。
return b_x, b_y return b_x, b_y


data.add_collect_fn(concat_collect_fn)
data.add_collate_fn(concat_collate_fn)


for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
print("batch_x:", batch_x) print("batch_x:", batch_x)
@@ -287,7 +287,7 @@ class TestCase1(unittest.TestCase):
# batch_y: {'y': array([0, 1])} # batch_y: {'y': array([0, 1])}


# 如果取batch过程含有一些参数,可以通过类来实现 # 如果取batch过程含有一些参数,可以通过类来实现
class ConCollectFn:
class ConCollateFn:
def __init__(self, max_len=3): def __init__(self, max_len=3):
self.max_len = max_len self.max_len = max_len
def __call__(self, ins_list): def __call__(self, ins_list):
@@ -300,8 +300,8 @@ class TestCase1(unittest.TestCase):
b_x = {'x': arr} b_x = {'x': arr}
b_y = {} b_y = {}
return b_x, b_y return b_x, b_y
data.delete_collect_fn() # 删除之前的collect_fn
data.add_collect_fn(ConCollectFn(max_len=3))
data.delete_collate_fn() # 删除之前的collate_fn
data.add_collate_fn(ConCollateFn(max_len=3))
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
print("batch_x:", batch_x) print("batch_x:", batch_x)
print("batch_y:", batch_y) print("batch_y:", batch_y)
@@ -326,14 +326,77 @@ class TestCase1(unittest.TestCase):
return x, y return x, y


data1 = FakeData() data1 = FakeData()
dataiter = TorchLoaderIter(data1, batch_size=2)
def collact_fn(ins_list):
xs = [ins[0]['x'] for ins in ins_list]
ys = [ins[1]['y'] for ins in ins_list]
return {'x':xs}, {'y':ys}
dataiter = TorchLoaderIter(data1, collate_fn=collact_fn, batch_size=2)
for x, y in dataiter: for x, y in dataiter:
print(x, y) print(x, y)


def func():
data2 = FakeData(return_dict=False)
dataiter = TorchLoaderIter(data2, batch_size=2)
self.assertRaises(Exception, func)
def test_batch_sampler(self):
# 测试DataSetIter与TorchLoaderIter的batch_sampler能否正常工作
# DataSetIter
ds = generate_fake_dataset(5)
ds.set_input('1')
class BatchSampler:
def __init__(self, dataset):
self.num_samples = len(dataset)

def __iter__(self):
index = 0
indexes = list(range(self.num_samples))
np.random.shuffle(indexes)
start_idx = 0
while index < self.num_samples:
if start_idx == 0:
end_index = self.num_samples//2
else:
end_index = self.num_samples
yield indexes[start_idx:end_index]
index = end_index
start_idx = end_index

def __len__(self):
return 2

batch_sampler = BatchSampler(ds)

data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=None,
batch_sampler=batch_sampler)
num_samples = [len(ds)//2, len(ds)-len(ds)//2]
for idx, (batch_x, batch_y) in enumerate(data_iter):
self.assertEqual(num_samples[idx], len(batch_x['1']))

# TorchLoaderIter
class FakeData:
def __init__(self):
self.x = [[1,2,3], [4,5,6], [1,2]]

def __len__(self):
return len(self.x)

def __getitem__(self, i):
x = self.x[i]
y = 0
return x,y

def collate_fn(ins_list):
xs = [ins[0] for ins in ins_list]
ys = [ins[1] for ins in ins_list]
return {'x':xs}, {'y':ys}

ds = FakeData()
batch_sampler = BatchSampler(ds)
data_iter = TorchLoaderIter(ds, batch_size=10, sampler=batch_sampler,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=collate_fn,
batch_sampler=batch_sampler)
num_samples = [len(ds)//2, len(ds)-len(ds)//2]
for idx, (batch_x, batch_y) in enumerate(data_iter):
self.assertEqual(num_samples[idx], len(batch_x['x']))


""" """
def test_multi_workers_batch(self): def test_multi_workers_batch(self):


+ 69
- 8
test/core/test_trainer.py View File

@@ -243,7 +243,7 @@ class TrainerTestGround(unittest.TestCase):
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples


def collect_fn(data_list):
def collate_fn(data_list):
# [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list # [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list
xs, ys = [], [] xs, ys = [], []
for l in data_list: for l in data_list:
@@ -254,7 +254,7 @@ class TrainerTestGround(unittest.TestCase):
return {'x':x, 'y':y}, {'y':y} return {'x':x, 'y':y}, {'y':y}


dataset = UdfDataSet(10) dataset = UdfDataSet(10)
dataset = TorchLoaderIter(dataset, collate_fn=collect_fn)
dataset = TorchLoaderIter(dataset, collate_fn=collate_fn)
class Model(nn.Module): class Model(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -268,6 +268,67 @@ class TrainerTestGround(unittest.TestCase):
metrics=AccuracyMetric(target='y'), use_tqdm=False) metrics=AccuracyMetric(target='y'), use_tqdm=False)
trainer.train(load_best_model=False) trainer.train(load_best_model=False)


def test_batch_sampler_dataiter(self):
import random
import torch
class BatchSampler:
def __init__(self, dataset):
self.num_samples = len(dataset)

def __iter__(self):
index = 0
indexes = list(range(self.num_samples))
np.random.shuffle(indexes)
start_idx = 0
while index < self.num_samples:
if start_idx == 0:
end_index = self.num_samples//2
else:
end_index = self.num_samples
yield indexes[start_idx:end_index]
index = end_index
start_idx = end_index
def __len__(self):
return 2

class UdfDataSet:
def __init__(self, num_samples):
self.num_samples = num_samples

def __getitem__(self, idx):
x = [random.random() for _ in range(3)]
y = random.random()
return x,y

def __len__(self):
return self.num_samples

def collate_fn(data_list):
# [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list
xs, ys = [], []
for l in data_list:
x, y = l
xs.append(x)
ys.append(y)
x,y = torch.FloatTensor(xs), torch.FloatTensor(ys)
return {'x':x, 'y':y}, {'y':y}

dataset = UdfDataSet(11)
batch_sampler = BatchSampler(dataset)
dataset = TorchLoaderIter(dataset, collate_fn=collate_fn, batch_sampler=batch_sampler)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(3, 1)
def forward(self, x, y):
return {'loss':torch.pow(self.fc(x).squeeze(-1)-y, 2).sum()}
def predict(self, x):
return {'pred':self.fc(x).squeeze(-1)}
model = Model()
trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
metrics=AccuracyMetric(target='y'), use_tqdm=False)
trainer.train(load_best_model=False)

def test_onthefly_iter(self): def test_onthefly_iter(self):
import tempfile import tempfile
import random import random
@@ -333,7 +394,7 @@ class TrainerTestGround(unittest.TestCase):
return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()} return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()}


def predict(self, x): def predict(self, x):
return {'pred': self.fc(x).squeeze(0)}
return {'pred': self.fc(x).squeeze(-1)}


model = Model() model = Model()
trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset,
@@ -356,7 +417,7 @@ class TrainerTestGround(unittest.TestCase):
x.append(ins['x1']+ins['x2']) x.append(ins['x1']+ins['x2'])
x = torch.FloatTensor(x) x = torch.FloatTensor(x)
return {'x':x}, {} return {'x':x}, {}
dataset.add_collect_fn(fn)
dataset.add_collate_fn(fn)


class Model(nn.Module): class Model(nn.Module):
def __init__(self): def __init__(self):
@@ -377,7 +438,7 @@ class TrainerTestGround(unittest.TestCase):
dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False) dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False)
trainer.train() trainer.train()


def test_collect_fn2(self):
def test_collate_fn2(self):
"""测试能否实现batch_x, batch_y""" """测试能否实现batch_x, batch_y"""
dataset = prepare_fake_dataset2('x1', 'x2') dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2') dataset.set_input('x1', 'x2')
@@ -389,7 +450,7 @@ class TrainerTestGround(unittest.TestCase):
x.append(ins['x1']+ins['x2']) x.append(ins['x1']+ins['x2'])
x = torch.FloatTensor(x) x = torch.FloatTensor(x)
return {'x':x}, {'target':x[:, :4].argmax(dim=-1)} return {'x':x}, {'target':x[:, :4].argmax(dim=-1)}
dataset.add_collect_fn(fn)
dataset.add_collate_fn(fn)


class Model(nn.Module): class Model(nn.Module):
def __init__(self): def __init__(self):
@@ -410,7 +471,7 @@ class TrainerTestGround(unittest.TestCase):
dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False) dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False)
trainer.train() trainer.train()


def test_collect_fn3(self):
def test_collate_fn3(self):
""" """
测试应该会覆盖 测试应该会覆盖


@@ -426,7 +487,7 @@ class TrainerTestGround(unittest.TestCase):
x.append(ins['x1']+ins['x2']) x.append(ins['x1']+ins['x2'])
x = torch.FloatTensor(x) x = torch.FloatTensor(x)
return {'x1':torch.zeros_like(x)}, {'target':torch.zeros(x.size(0)).long(), 'y':x} return {'x1':torch.zeros_like(x)}, {'target':torch.zeros(x.size(0)).long(), 'y':x}
dataset.add_collect_fn(fn)
dataset.add_collate_fn(fn)


class Model(nn.Module): class Model(nn.Module):
def __init__(self): def __init__(self):


Loading…
Cancel
Save