@@ -44,8 +44,8 @@ __all__ = [ | |||
"AutoPadder", | |||
"EngChar2DPadder", | |||
# "CollectFn", | |||
"ConcatCollectFn", | |||
# "CollateFn", | |||
"ConcatCollateFn", | |||
"MetricBase", | |||
"AccuracyMetric", | |||
@@ -21,7 +21,7 @@ __all__ = [ | |||
"AutoPadder", | |||
"EngChar2DPadder", | |||
"ConcatCollectFn", | |||
"ConcatCollateFn", | |||
"Vocabulary", | |||
@@ -99,4 +99,4 @@ from .tester import Tester | |||
from .trainer import Trainer | |||
from .utils import cache_results, seq_len_to_mask, get_seq_len | |||
from .vocabulary import Vocabulary | |||
from .collect_fn import ConcatCollectFn | |||
from .collate_fn import ConcatCollateFn |
@@ -18,7 +18,7 @@ import torch.utils.data | |||
from collections import defaultdict | |||
from .dataset import DataSet | |||
from .sampler import SequentialSampler | |||
from .sampler import SequentialSampler, Sampler | |||
from ._logger import logger | |||
@@ -89,8 +89,8 @@ class DataSetGetter: | |||
sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy) | |||
sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy) | |||
if not self.dataset.collector.is_empty(): | |||
bx, by = self.dataset._collect_batch(ins_list) | |||
if not self.dataset.collater.is_empty(): | |||
bx, by = self.dataset._collate_batch(ins_list) | |||
sin_x.update(bx) | |||
sin_y.update(by) | |||
@@ -127,29 +127,35 @@ class BatchIter: | |||
""" | |||
def __init__(self, dataset, batch_size=1, sampler=None, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, collate_fn=None): | |||
if not isinstance(sampler, torch.utils.data.Sampler): | |||
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||
else: | |||
self.sampler = sampler | |||
timeout=0, worker_init_fn=None, collate_fn=None, | |||
batch_sampler=None): | |||
if isinstance(sampler, Sampler): # 如果时fastNLP的sampler需要adapt一下 | |||
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||
self.sampler = sampler | |||
self.batch_sampler = batch_sampler | |||
# DataLoader的collect_fn输入是List[],里面的元素是dataset[index]返回的结果 | |||
# DataLoader的collate_fn输入是List[],里面的元素是dataset[index]返回的结果 | |||
if collate_fn is None: | |||
# pytoch <= 1.1 中不能设置collate_fn=None | |||
self.dataiter = torch.utils.data.DataLoader( | |||
dataset=dataset, batch_size=batch_size, sampler=self.sampler, | |||
num_workers=num_workers, | |||
pin_memory=pin_memory, drop_last=drop_last, | |||
timeout=timeout, worker_init_fn=worker_init_fn) | |||
timeout=timeout, worker_init_fn=worker_init_fn, | |||
batch_sampler=batch_sampler) | |||
else: | |||
self.dataiter = torch.utils.data.DataLoader( | |||
dataset=dataset, batch_size=batch_size, sampler=self.sampler, | |||
collate_fn=collate_fn, num_workers=num_workers, | |||
pin_memory=pin_memory, drop_last=drop_last, | |||
timeout=timeout, worker_init_fn=worker_init_fn) | |||
timeout=timeout, worker_init_fn=worker_init_fn, | |||
batch_sampler=batch_sampler) | |||
# 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 | |||
self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) | |||
if self.batch_sampler is None: | |||
self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) | |||
else: | |||
self._num_batches = len(self.batch_sampler) | |||
self.batch_size = batch_size | |||
self.cur_batch_indices = None | |||
@@ -222,7 +228,8 @@ class DataSetIter(BatchIter): | |||
""" | |||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, collate_fn=None): | |||
timeout=0, worker_init_fn=None, collate_fn=None, | |||
batch_sampler=None): | |||
r""" | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
@@ -239,15 +246,21 @@ class DataSetIter(BatchIter): | |||
:param timeout: 生成一个batch的timeout值 | |||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||
:param collate_fn: 用于将样本组合成batch的函数 | |||
:param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。 | |||
当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。 | |||
""" | |||
assert isinstance(dataset, DataSet) | |||
dataset = DataSetGetter(dataset, as_numpy) | |||
collate_fn = dataset.collate_fn if collate_fn is None else collate_fn | |||
if batch_sampler is not None: | |||
batch_size = 1 | |||
sampler = None | |||
drop_last = False | |||
super().__init__( | |||
dataset=dataset, batch_size=batch_size, sampler=sampler, | |||
num_workers=num_workers, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
collate_fn=collate_fn | |||
collate_fn=collate_fn, batch_sampler=batch_sampler | |||
) | |||
def __iter__(self): | |||
@@ -384,12 +397,16 @@ class TorchLoaderIter(BatchIter): | |||
os.remove(tmp_file_path) | |||
""" | |||
def __init__(self, dataset, batch_size=1, sampler=None, | |||
def __init__(self, dataset, collate_fn, batch_size=1, sampler=None, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, collate_fn=None): | |||
timeout=0, worker_init_fn=None, | |||
batch_sampler=None): | |||
r""" | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param dataset: 实现了__getitem__和__len__方法的数据容器。 | |||
:param callable collate_fn: 用于将样本组合成batch的函数。输入为[dataset[idx1], dataset[idx2], ...], 即dataset中 | |||
__getitem__返回值组成的list,返回值必须为两个dict,其中第一个dict会被认为是input,第二个dict中的内容被认为是target。 | |||
需要转换为tensor的数据,需要在collate_fn中转化,但不需要转移到对应device。 | |||
:param int batch_size: 取出的batch大小 | |||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | |||
Default: ``None`` | |||
@@ -398,19 +415,21 @@ class TorchLoaderIter(BatchIter): | |||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | |||
:param timeout: 生成一个batch的timeout值 | |||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||
:param collate_fn: 用于将样本组合成batch的函数。 | |||
:param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。 | |||
当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。 | |||
""" | |||
assert len(dataset) > 0 | |||
ins = dataset[0] | |||
if (len(ins) != 2 or not isinstance(ins[0], dict) or not isinstance(ins[1], dict)) and collate_fn is None: | |||
raise RuntimeError("If the provided dataset does not return two dicts when call __getitem__(), the" | |||
" `collate_fn` must be provided.") | |||
assert collate_fn is not None, "You must pass collate_fn to pad the batch." | |||
if batch_sampler is not None: | |||
batch_size = 1 | |||
sampler = None | |||
drop_last = False | |||
super().__init__( | |||
dataset=dataset, batch_size=batch_size, sampler=sampler, | |||
num_workers=num_workers, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
collate_fn=collate_fn | |||
collate_fn=collate_fn, batch_sampler=batch_sampler | |||
) | |||
def __iter__(self): | |||
@@ -36,72 +36,72 @@ def batching(samples, max_len=0, padding_val=0): | |||
return batch | |||
class Collector: | |||
class Collater: | |||
r""" | |||
辅助DataSet管理collect_fn的类 | |||
辅助DataSet管理collate_fn的类 | |||
""" | |||
def __init__(self): | |||
self.collect_fns = {} | |||
self.collate_fns = {} | |||
def add_fn(self, fn, name=None): | |||
r""" | |||
向collector新增一个collect_fn函数 | |||
向collater新增一个collate_fn函数 | |||
:param callable fn: | |||
:param str,int name: | |||
:return: | |||
""" | |||
if name in self.collect_fns: | |||
logger.warn(f"collect_fn:{name} will be overwritten.") | |||
if name in self.collate_fns: | |||
logger.warn(f"collate_fn:{name} will be overwritten.") | |||
if name is None: | |||
name = len(self.collect_fns) | |||
self.collect_fns[name] = fn | |||
name = len(self.collate_fns) | |||
self.collate_fns[name] = fn | |||
def is_empty(self): | |||
r""" | |||
返回是否包含collect_fn | |||
返回是否包含collate_fn | |||
:return: | |||
""" | |||
return len(self.collect_fns)==0 | |||
return len(self.collate_fns) == 0 | |||
def delete_fn(self, name=None): | |||
r""" | |||
删除collect_fn | |||
删除collate_fn | |||
:param str,int name: 如果为None就删除最近加入的collect_fn | |||
:param str,int name: 如果为None就删除最近加入的collate_fn | |||
:return: | |||
""" | |||
if not self.is_empty(): | |||
if name in self.collect_fns: | |||
self.collect_fns.pop(name) | |||
if name in self.collate_fns: | |||
self.collate_fns.pop(name) | |||
elif name is None: | |||
last_key = list(self.collect_fns.keys())[0] | |||
self.collect_fns.pop(last_key) | |||
last_key = list(self.collate_fns.keys())[0] | |||
self.collate_fns.pop(last_key) | |||
def collect_batch(self, ins_list): | |||
def collate_batch(self, ins_list): | |||
bx, by = {}, {} | |||
for name, fn in self.collect_fns.items(): | |||
for name, fn in self.collate_fns.items(): | |||
try: | |||
batch_x, batch_y = fn(ins_list) | |||
except BaseException as e: | |||
logger.error(f"Exception:`{e}` happens when call collect_fn:`{name}`.") | |||
logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.") | |||
raise e | |||
bx.update(batch_x) | |||
by.update(batch_y) | |||
return bx, by | |||
def copy_from(self, col): | |||
assert isinstance(col, Collector) | |||
new_col = Collector() | |||
new_col.collect_fns = deepcopy(col.collect_fns) | |||
assert isinstance(col, Collater) | |||
new_col = Collater() | |||
new_col.collate_fns = deepcopy(col.collate_fns) | |||
return new_col | |||
class ConcatCollectFn: | |||
class ConcatCollateFn: | |||
r""" | |||
field拼接collect_fn,将不同field按序拼接后,padding产生数据。 | |||
field拼接collate_fn,将不同field按序拼接后,padding产生数据。 | |||
:param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field | |||
:param str output: 拼接后的field名称 |
@@ -285,8 +285,8 @@ r""" | |||
------------------------------------------------------------ | |||
DataSet支持在进行batch时,默认只能看到当前的field的值,但在某些训练中可能存在以下的情况: (1)需要两个field拼接成为一个field; | |||
(2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collect_fn()函数支持添加 | |||
自定义涉及多个field的collect_fn函数。例如下例中将两个field拼接成一个field的场景 | |||
(2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collate_fn()函数支持添加 | |||
自定义涉及多个field的collate_fn函数。例如下例中将两个field拼接成一个field的场景 | |||
.. code-block:: | |||
@@ -302,9 +302,9 @@ r""" | |||
}) | |||
data.set_target('y') | |||
# 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 | |||
# 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 | |||
# 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. | |||
def concat_collect_fn(ins_list): | |||
def concat_collate_fn(ins_list): | |||
x1 = [ins['x1'] for ind,ins in ins_list] | |||
x2 = [ins['x2'] for ind,ins in ins_list] | |||
xs = [] | |||
@@ -318,7 +318,7 @@ r""" | |||
# 采用返回值。 | |||
return b_x, b_y | |||
data.add_collect_fn(concat_collect_fn) | |||
data.add_collate_fn(concat_collate_fn) | |||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | |||
print("batch_x:", batch_x) | |||
@@ -328,7 +328,7 @@ r""" | |||
# batch_y: {'y': array([0, 1])} | |||
# 如果取batch过程含有一些参数,可以通过类来实现 | |||
class ConCollectFn: | |||
class ConCollateFn: | |||
def __init__(self, max_len=3): | |||
self.max_len = max_len | |||
@@ -342,8 +342,8 @@ r""" | |||
b_x = {'x': arr} | |||
b_y = {} | |||
return b_x, b_y | |||
data.delete_collect_fn() # 删除之前的collect_fn | |||
data.add_collect_fn(ConCollectFn(max_len=3)) | |||
data.delete_collate_fn() # 删除之前的collate_fn | |||
data.add_collate_fn(ConCollateFn(max_len=3)) | |||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | |||
print("batch_x:", batch_x) | |||
print("batch_y:", batch_y) | |||
@@ -370,7 +370,7 @@ from .field import FieldArray | |||
from .field import SetInputOrTargetException | |||
from .instance import Instance | |||
from .utils import pretty_table_printer | |||
from .collect_fn import Collector | |||
from .collate_fn import Collater | |||
class ApplyResultException(Exception): | |||
@@ -406,7 +406,7 @@ class DataSet(object): | |||
else: | |||
raise ValueError("data only be dict or list type.") | |||
self.collector = Collector() | |||
self.collater = Collater() | |||
def __contains__(self, item): | |||
return item in self.field_arrays | |||
@@ -462,7 +462,7 @@ class DataSet(object): | |||
for field in self.field_arrays.values(): | |||
data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, | |||
is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) | |||
data_set.collector = self.collector.copy_from(self.collector) | |||
data_set.collater = self.collater.copy_from(self.collater) | |||
return data_set | |||
elif isinstance(idx, str): | |||
if idx not in self: | |||
@@ -476,7 +476,7 @@ class DataSet(object): | |||
dataset.append(instance) | |||
for field_name, field in self.field_arrays.items(): | |||
dataset.field_arrays[field_name].to(field) | |||
dataset.collector = self.collector.copy_from(self.collector) | |||
dataset.collater = self.collater.copy_from(self.collater) | |||
return dataset | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
@@ -1083,8 +1083,8 @@ class DataSet(object): | |||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||
train_set.collector.copy_from(self.collector) | |||
dev_set.collector.copy_from(self.collector) | |||
train_set.collater.copy_from(self.collater) | |||
dev_set.collater.copy_from(self.collater) | |||
return train_set, dev_set | |||
def save(self, path): | |||
@@ -1109,30 +1109,30 @@ class DataSet(object): | |||
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | |||
return d | |||
def add_collect_fn(self, fn, name=None): | |||
def add_collate_fn(self, fn, name=None): | |||
r""" | |||
添加 CollectFn,collect_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 | |||
这个)。支持依次添加多个collect_fn, 如果相同的key,后面的collect_fn的结果覆盖前面的collect_fn的结果。 | |||
添加 CollateFn,collate_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 | |||
这个)。支持依次添加多个collate_fn, 如果相同的key,后面的collate_fn的结果覆盖前面的collate_fn的结果。 | |||
:param callable fn: 传入一个可调用的function, 该function可接受的参数为List[(ind1, instance1), (ind2, instance2)] | |||
(某个batch被选中的所有的indice以及instance),其中ind1/ind2是该instance在dataset中的index,instance1/instance2是 | |||
这次batch取出来的数据,包含了所有的field。返回值需要为两个dict,第一个dict的值将被认为是input,第二个dict的值被认为是 | |||
target,返回的值至多允许一个空dict。若返回的dict中包含了被设置为input或target的field的名称,将覆盖dataset中的field。 | |||
fastNLP不会将collect_fn的返回结果pad和转换为tensor,需要在collect_fn中完成pad和转换为tensor(不需要将tensor移动到 | |||
gpu中,如果是pytorch的tensor,fastNLP会自动将其移动到特定gpu)。不要修改传入collect_fn中的数据,否则可能导致未知问题。 | |||
:param str,int name: collect_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collect_fn。 | |||
fastNLP不会将collate_fn的返回结果pad和转换为tensor,需要在collate_fn中完成pad和转换为tensor(不需要将tensor移动到 | |||
gpu中,fastNLP会自动将其移动到特定gpu)。不要修改传入collate_fn中的数据,否则可能导致未知问题。 | |||
:param str,int name: collate_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collate_fn。 | |||
""" | |||
assert callable(fn), "You must pass in a callable object." | |||
self.collector.add_fn(fn, name=name) | |||
self.collater.add_fn(fn, name=name) | |||
def delete_collect_fn(self, name=None): | |||
def delete_collate_fn(self, name=None): | |||
r""" | |||
删除某个collect_fn | |||
删除某个collate_fn | |||
:param str,int name: 如果为None,则删除最近加入的collect_fn | |||
:param str,int name: 如果为None,则删除最近加入的collate_fn | |||
:return: | |||
""" | |||
self.collector.delete_fn(name) | |||
self.collater.delete_fn(name) | |||
def _collect_batch(self, ins_list): | |||
return self.collector.collect_batch(ins_list) | |||
def _collate_batch(self, ins_list): | |||
return self.collater.collate_batch(ins_list) |
@@ -718,8 +718,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
_tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. ' | |||
else: | |||
_tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.' | |||
if not dataset.collector.is_empty(): | |||
_tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. ' | |||
if not dataset.collater.is_empty(): | |||
_tmp += f'Or you need to add `{_miss}` in the output of your collate_fn. ' | |||
suggestions.append(_tmp) | |||
if check_res.duplicated: | |||
@@ -779,8 +779,8 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ") | |||
if _miss_out_dataset: | |||
_tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. " | |||
if not dataset.collector.is_empty(): | |||
_tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. ' | |||
if not dataset.collator.is_empty(): | |||
_tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collate_fn. ' | |||
suggestions.append(_tmp) | |||
if check_res.unused: | |||
@@ -348,26 +348,26 @@ class DataBundle: | |||
dataset.apply(func, new_field_name=new_field_name, **kwargs) | |||
return self | |||
def add_collect_fn(self, fn, name=None): | |||
def add_collate_fn(self, fn, name=None): | |||
r""" | |||
向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明. | |||
向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明. | |||
:param callable fn: | |||
:param name: | |||
:return: | |||
""" | |||
for _, dataset in self.datasets.items(): | |||
dataset.add_collect_fn(fn=fn, name=name) | |||
dataset.add_collate_fn(fn=fn, name=name) | |||
def delete_collect_fn(self, name=None): | |||
def delete_collate_fn(self, name=None): | |||
r""" | |||
删除DataSet中的collect_fn | |||
删除DataSet中的collate_fn | |||
:param name: | |||
:return: | |||
""" | |||
for _, dataset in self.datasets.items(): | |||
dataset.delete_collect_fn(name=name) | |||
dataset.delete_collate_fn(name=name) | |||
def __repr__(self): | |||
_str = '' | |||
@@ -7,7 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import SequentialSampler | |||
from fastNLP import ConcatCollectFn | |||
from fastNLP import ConcatCollateFn | |||
def generate_fake_dataset(num_samples=1000): | |||
@@ -177,76 +177,76 @@ class TestCase1(unittest.TestCase): | |||
for con,t in zip(cons, test): | |||
self.assertEqual(alphas[:con], t) | |||
def test_collect_fn(self): | |||
def test_collate_fn(self): | |||
batch_size = 32 | |||
num_samples = 1000 | |||
dataset = generate_fake_dataset(num_samples) | |||
dataset.set_input('1','2') | |||
dataset.set_target('0','3') | |||
fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) | |||
dataset.add_collect_fn(fn, name='demo') | |||
fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) | |||
dataset.add_collate_fn(fn, name='demo') | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | |||
for batch_x, batch_y in batch: | |||
for i in range(batch_size): | |||
# print(i) | |||
self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) | |||
dataset.delete_collect_fn(name='demo') | |||
dataset.delete_collate_fn(name='demo') | |||
# 测试非input的情况 | |||
dataset.set_input('1', '2', flag=False) # | |||
fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) | |||
dataset.add_collect_fn(fn, name='demo') | |||
fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) | |||
dataset.add_collate_fn(fn, name='demo') | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | |||
for batch_x, batch_y in batch: | |||
for i in range(batch_size): | |||
self.assertTrue('12' in batch_x) | |||
dataset.delete_collect_fn(name='demo') | |||
dataset.delete_collate_fn(name='demo') | |||
dataset.set_input('1', '2', flag=True) # | |||
# 测试覆盖其它field的情况 | |||
fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) | |||
dataset.add_collect_fn(fn, name='demo') | |||
fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) | |||
dataset.add_collate_fn(fn, name='demo') | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | |||
for batch_x, batch_y in batch: | |||
for i in range(batch_size): | |||
# print(i) | |||
self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) | |||
dataset.delete_collect_fn(name='demo') | |||
dataset.delete_collate_fn(name='demo') | |||
# 测试非input,target的情况 | |||
dataset.set_input('1', '2', flag=False) | |||
fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) | |||
dataset.add_collect_fn(fn, name='demo') | |||
fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) | |||
dataset.add_collate_fn(fn, name='demo') | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | |||
for batch_x, batch_y in batch: | |||
for i in range(batch_size): | |||
# print(i) | |||
self.assertTrue('3' in batch_x) | |||
self.assertTrue('3' in batch_y) | |||
dataset.delete_collect_fn(name='demo') | |||
dataset.delete_collate_fn(name='demo') | |||
# 测试加入非法fn的请 | |||
with self.assertRaises(AssertionError): | |||
dataset.add_collect_fn(1) | |||
dataset.add_collate_fn(1) | |||
# 测试collect_fn返回值只有一个的情况 | |||
def demo_collect_fn(ins_list): | |||
# 测试collate_fn返回值只有一个的情况 | |||
def demo_collate_fn(ins_list): | |||
return {'3':1} | |||
dataset.add_collect_fn(demo_collect_fn, name='demo') | |||
dataset.add_collate_fn(demo_collate_fn, name='demo') | |||
with self.assertRaises(BaseException): | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | |||
for batch_x, batch_y in batch: | |||
pass | |||
dataset.delete_collect_fn(name='demo') | |||
dataset.delete_collate_fn(name='demo') | |||
# 测试多个collect_fn | |||
dataset.add_collect_fn(demo_collect_fn, name='demo') | |||
dataset.add_collect_fn(demo_collect_fn, name='demo') | |||
# 测试多个collate_fn | |||
dataset.add_collate_fn(demo_collate_fn, name='demo') | |||
dataset.add_collate_fn(demo_collate_fn, name='demo') | |||
# 测试删除 | |||
dataset.delete_collect_fn() | |||
dataset.delete_collect_fn() | |||
self.assertTrue(dataset.collector.is_empty()) | |||
dataset.delete_collate_fn() | |||
dataset.delete_collate_fn() | |||
self.assertTrue(dataset.collater.is_empty()) | |||
def test_demo(self): | |||
import torch | |||
@@ -261,9 +261,9 @@ class TestCase1(unittest.TestCase): | |||
}) | |||
data.set_target('y') | |||
# 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 | |||
# 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 | |||
# 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. | |||
def concat_collect_fn(ins_list): | |||
def concat_collate_fn(ins_list): | |||
x1 = [ins['x1'] for ind,ins in ins_list] | |||
x2 = [ins['x2'] for ind,ins in ins_list] | |||
xs = [] | |||
@@ -277,7 +277,7 @@ class TestCase1(unittest.TestCase): | |||
# 采用返回值。 | |||
return b_x, b_y | |||
data.add_collect_fn(concat_collect_fn) | |||
data.add_collate_fn(concat_collate_fn) | |||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | |||
print("batch_x:", batch_x) | |||
@@ -287,7 +287,7 @@ class TestCase1(unittest.TestCase): | |||
# batch_y: {'y': array([0, 1])} | |||
# 如果取batch过程含有一些参数,可以通过类来实现 | |||
class ConCollectFn: | |||
class ConCollateFn: | |||
def __init__(self, max_len=3): | |||
self.max_len = max_len | |||
def __call__(self, ins_list): | |||
@@ -300,8 +300,8 @@ class TestCase1(unittest.TestCase): | |||
b_x = {'x': arr} | |||
b_y = {} | |||
return b_x, b_y | |||
data.delete_collect_fn() # 删除之前的collect_fn | |||
data.add_collect_fn(ConCollectFn(max_len=3)) | |||
data.delete_collate_fn() # 删除之前的collate_fn | |||
data.add_collate_fn(ConCollateFn(max_len=3)) | |||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | |||
print("batch_x:", batch_x) | |||
print("batch_y:", batch_y) | |||
@@ -326,14 +326,77 @@ class TestCase1(unittest.TestCase): | |||
return x, y | |||
data1 = FakeData() | |||
dataiter = TorchLoaderIter(data1, batch_size=2) | |||
def collact_fn(ins_list): | |||
xs = [ins[0]['x'] for ins in ins_list] | |||
ys = [ins[1]['y'] for ins in ins_list] | |||
return {'x':xs}, {'y':ys} | |||
dataiter = TorchLoaderIter(data1, collate_fn=collact_fn, batch_size=2) | |||
for x, y in dataiter: | |||
print(x, y) | |||
def func(): | |||
data2 = FakeData(return_dict=False) | |||
dataiter = TorchLoaderIter(data2, batch_size=2) | |||
self.assertRaises(Exception, func) | |||
def test_batch_sampler(self): | |||
# 测试DataSetIter与TorchLoaderIter的batch_sampler能否正常工作 | |||
# DataSetIter | |||
ds = generate_fake_dataset(5) | |||
ds.set_input('1') | |||
class BatchSampler: | |||
def __init__(self, dataset): | |||
self.num_samples = len(dataset) | |||
def __iter__(self): | |||
index = 0 | |||
indexes = list(range(self.num_samples)) | |||
np.random.shuffle(indexes) | |||
start_idx = 0 | |||
while index < self.num_samples: | |||
if start_idx == 0: | |||
end_index = self.num_samples//2 | |||
else: | |||
end_index = self.num_samples | |||
yield indexes[start_idx:end_index] | |||
index = end_index | |||
start_idx = end_index | |||
def __len__(self): | |||
return 2 | |||
batch_sampler = BatchSampler(ds) | |||
data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, collate_fn=None, | |||
batch_sampler=batch_sampler) | |||
num_samples = [len(ds)//2, len(ds)-len(ds)//2] | |||
for idx, (batch_x, batch_y) in enumerate(data_iter): | |||
self.assertEqual(num_samples[idx], len(batch_x['1'])) | |||
# TorchLoaderIter | |||
class FakeData: | |||
def __init__(self): | |||
self.x = [[1,2,3], [4,5,6], [1,2]] | |||
def __len__(self): | |||
return len(self.x) | |||
def __getitem__(self, i): | |||
x = self.x[i] | |||
y = 0 | |||
return x,y | |||
def collate_fn(ins_list): | |||
xs = [ins[0] for ins in ins_list] | |||
ys = [ins[1] for ins in ins_list] | |||
return {'x':xs}, {'y':ys} | |||
ds = FakeData() | |||
batch_sampler = BatchSampler(ds) | |||
data_iter = TorchLoaderIter(ds, batch_size=10, sampler=batch_sampler, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, collate_fn=collate_fn, | |||
batch_sampler=batch_sampler) | |||
num_samples = [len(ds)//2, len(ds)-len(ds)//2] | |||
for idx, (batch_x, batch_y) in enumerate(data_iter): | |||
self.assertEqual(num_samples[idx], len(batch_x['x'])) | |||
""" | |||
def test_multi_workers_batch(self): | |||
@@ -243,7 +243,7 @@ class TrainerTestGround(unittest.TestCase): | |||
def __len__(self): | |||
return self.num_samples | |||
def collect_fn(data_list): | |||
def collate_fn(data_list): | |||
# [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list | |||
xs, ys = [], [] | |||
for l in data_list: | |||
@@ -254,7 +254,7 @@ class TrainerTestGround(unittest.TestCase): | |||
return {'x':x, 'y':y}, {'y':y} | |||
dataset = UdfDataSet(10) | |||
dataset = TorchLoaderIter(dataset, collate_fn=collect_fn) | |||
dataset = TorchLoaderIter(dataset, collate_fn=collate_fn) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
@@ -268,6 +268,67 @@ class TrainerTestGround(unittest.TestCase): | |||
metrics=AccuracyMetric(target='y'), use_tqdm=False) | |||
trainer.train(load_best_model=False) | |||
def test_batch_sampler_dataiter(self): | |||
import random | |||
import torch | |||
class BatchSampler: | |||
def __init__(self, dataset): | |||
self.num_samples = len(dataset) | |||
def __iter__(self): | |||
index = 0 | |||
indexes = list(range(self.num_samples)) | |||
np.random.shuffle(indexes) | |||
start_idx = 0 | |||
while index < self.num_samples: | |||
if start_idx == 0: | |||
end_index = self.num_samples//2 | |||
else: | |||
end_index = self.num_samples | |||
yield indexes[start_idx:end_index] | |||
index = end_index | |||
start_idx = end_index | |||
def __len__(self): | |||
return 2 | |||
class UdfDataSet: | |||
def __init__(self, num_samples): | |||
self.num_samples = num_samples | |||
def __getitem__(self, idx): | |||
x = [random.random() for _ in range(3)] | |||
y = random.random() | |||
return x,y | |||
def __len__(self): | |||
return self.num_samples | |||
def collate_fn(data_list): | |||
# [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list | |||
xs, ys = [], [] | |||
for l in data_list: | |||
x, y = l | |||
xs.append(x) | |||
ys.append(y) | |||
x,y = torch.FloatTensor(xs), torch.FloatTensor(ys) | |||
return {'x':x, 'y':y}, {'y':y} | |||
dataset = UdfDataSet(11) | |||
batch_sampler = BatchSampler(dataset) | |||
dataset = TorchLoaderIter(dataset, collate_fn=collate_fn, batch_sampler=batch_sampler) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(3, 1) | |||
def forward(self, x, y): | |||
return {'loss':torch.pow(self.fc(x).squeeze(-1)-y, 2).sum()} | |||
def predict(self, x): | |||
return {'pred':self.fc(x).squeeze(-1)} | |||
model = Model() | |||
trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, | |||
metrics=AccuracyMetric(target='y'), use_tqdm=False) | |||
trainer.train(load_best_model=False) | |||
def test_onthefly_iter(self): | |||
import tempfile | |||
import random | |||
@@ -333,7 +394,7 @@ class TrainerTestGround(unittest.TestCase): | |||
return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()} | |||
def predict(self, x): | |||
return {'pred': self.fc(x).squeeze(0)} | |||
return {'pred': self.fc(x).squeeze(-1)} | |||
model = Model() | |||
trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, | |||
@@ -356,7 +417,7 @@ class TrainerTestGround(unittest.TestCase): | |||
x.append(ins['x1']+ins['x2']) | |||
x = torch.FloatTensor(x) | |||
return {'x':x}, {} | |||
dataset.add_collect_fn(fn) | |||
dataset.add_collate_fn(fn) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
@@ -377,7 +438,7 @@ class TrainerTestGround(unittest.TestCase): | |||
dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False) | |||
trainer.train() | |||
def test_collect_fn2(self): | |||
def test_collate_fn2(self): | |||
"""测试能否实现batch_x, batch_y""" | |||
dataset = prepare_fake_dataset2('x1', 'x2') | |||
dataset.set_input('x1', 'x2') | |||
@@ -389,7 +450,7 @@ class TrainerTestGround(unittest.TestCase): | |||
x.append(ins['x1']+ins['x2']) | |||
x = torch.FloatTensor(x) | |||
return {'x':x}, {'target':x[:, :4].argmax(dim=-1)} | |||
dataset.add_collect_fn(fn) | |||
dataset.add_collate_fn(fn) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
@@ -410,7 +471,7 @@ class TrainerTestGround(unittest.TestCase): | |||
dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False) | |||
trainer.train() | |||
def test_collect_fn3(self): | |||
def test_collate_fn3(self): | |||
""" | |||
测试应该会覆盖 | |||
@@ -426,7 +487,7 @@ class TrainerTestGround(unittest.TestCase): | |||
x.append(ins['x1']+ins['x2']) | |||
x = torch.FloatTensor(x) | |||
return {'x1':torch.zeros_like(x)}, {'target':torch.zeros(x.size(0)).long(), 'y':x} | |||
dataset.add_collect_fn(fn) | |||
dataset.add_collate_fn(fn) | |||
class Model(nn.Module): | |||
def __init__(self): | |||