@@ -44,8 +44,8 @@ __all__ = [ | |||||
"AutoPadder", | "AutoPadder", | ||||
"EngChar2DPadder", | "EngChar2DPadder", | ||||
# "CollectFn", | |||||
"ConcatCollectFn", | |||||
# "CollateFn", | |||||
"ConcatCollateFn", | |||||
"MetricBase", | "MetricBase", | ||||
"AccuracyMetric", | "AccuracyMetric", | ||||
@@ -21,7 +21,7 @@ __all__ = [ | |||||
"AutoPadder", | "AutoPadder", | ||||
"EngChar2DPadder", | "EngChar2DPadder", | ||||
"ConcatCollectFn", | |||||
"ConcatCollateFn", | |||||
"Vocabulary", | "Vocabulary", | ||||
@@ -99,4 +99,4 @@ from .tester import Tester | |||||
from .trainer import Trainer | from .trainer import Trainer | ||||
from .utils import cache_results, seq_len_to_mask, get_seq_len | from .utils import cache_results, seq_len_to_mask, get_seq_len | ||||
from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
from .collect_fn import ConcatCollectFn | |||||
from .collate_fn import ConcatCollateFn |
@@ -18,7 +18,7 @@ import torch.utils.data | |||||
from collections import defaultdict | from collections import defaultdict | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .sampler import SequentialSampler | |||||
from .sampler import SequentialSampler, Sampler | |||||
from ._logger import logger | from ._logger import logger | ||||
@@ -89,8 +89,8 @@ class DataSetGetter: | |||||
sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy) | sin_x = _pad(sin_x, dataset=self.dataset, as_numpy=self.as_numpy) | ||||
sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy) | sin_y = _pad(sin_y, dataset=self.dataset, as_numpy=self.as_numpy) | ||||
if not self.dataset.collector.is_empty(): | |||||
bx, by = self.dataset._collect_batch(ins_list) | |||||
if not self.dataset.collater.is_empty(): | |||||
bx, by = self.dataset._collate_batch(ins_list) | |||||
sin_x.update(bx) | sin_x.update(bx) | ||||
sin_y.update(by) | sin_y.update(by) | ||||
@@ -127,29 +127,35 @@ class BatchIter: | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size=1, sampler=None, | def __init__(self, dataset, batch_size=1, sampler=None, | ||||
num_workers=0, pin_memory=False, drop_last=False, | num_workers=0, pin_memory=False, drop_last=False, | ||||
timeout=0, worker_init_fn=None, collate_fn=None): | |||||
if not isinstance(sampler, torch.utils.data.Sampler): | |||||
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||||
else: | |||||
self.sampler = sampler | |||||
timeout=0, worker_init_fn=None, collate_fn=None, | |||||
batch_sampler=None): | |||||
if isinstance(sampler, Sampler): # 如果时fastNLP的sampler需要adapt一下 | |||||
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||||
self.sampler = sampler | |||||
self.batch_sampler = batch_sampler | |||||
# DataLoader的collect_fn输入是List[],里面的元素是dataset[index]返回的结果 | |||||
# DataLoader的collate_fn输入是List[],里面的元素是dataset[index]返回的结果 | |||||
if collate_fn is None: | if collate_fn is None: | ||||
# pytoch <= 1.1 中不能设置collate_fn=None | # pytoch <= 1.1 中不能设置collate_fn=None | ||||
self.dataiter = torch.utils.data.DataLoader( | self.dataiter = torch.utils.data.DataLoader( | ||||
dataset=dataset, batch_size=batch_size, sampler=self.sampler, | dataset=dataset, batch_size=batch_size, sampler=self.sampler, | ||||
num_workers=num_workers, | num_workers=num_workers, | ||||
pin_memory=pin_memory, drop_last=drop_last, | pin_memory=pin_memory, drop_last=drop_last, | ||||
timeout=timeout, worker_init_fn=worker_init_fn) | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | |||||
batch_sampler=batch_sampler) | |||||
else: | else: | ||||
self.dataiter = torch.utils.data.DataLoader( | self.dataiter = torch.utils.data.DataLoader( | ||||
dataset=dataset, batch_size=batch_size, sampler=self.sampler, | dataset=dataset, batch_size=batch_size, sampler=self.sampler, | ||||
collate_fn=collate_fn, num_workers=num_workers, | collate_fn=collate_fn, num_workers=num_workers, | ||||
pin_memory=pin_memory, drop_last=drop_last, | pin_memory=pin_memory, drop_last=drop_last, | ||||
timeout=timeout, worker_init_fn=worker_init_fn) | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | |||||
batch_sampler=batch_sampler) | |||||
# 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 | # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 | ||||
self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) | |||||
if self.batch_sampler is None: | |||||
self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) | |||||
else: | |||||
self._num_batches = len(self.batch_sampler) | |||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.cur_batch_indices = None | self.cur_batch_indices = None | ||||
@@ -222,7 +228,8 @@ class DataSetIter(BatchIter): | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | ||||
num_workers=0, pin_memory=False, drop_last=False, | num_workers=0, pin_memory=False, drop_last=False, | ||||
timeout=0, worker_init_fn=None, collate_fn=None): | |||||
timeout=0, worker_init_fn=None, collate_fn=None, | |||||
batch_sampler=None): | |||||
r""" | r""" | ||||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | ||||
@@ -239,15 +246,21 @@ class DataSetIter(BatchIter): | |||||
:param timeout: 生成一个batch的timeout值 | :param timeout: 生成一个batch的timeout值 | ||||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | ||||
:param collate_fn: 用于将样本组合成batch的函数 | :param collate_fn: 用于将样本组合成batch的函数 | ||||
:param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。 | |||||
当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。 | |||||
""" | """ | ||||
assert isinstance(dataset, DataSet) | assert isinstance(dataset, DataSet) | ||||
dataset = DataSetGetter(dataset, as_numpy) | dataset = DataSetGetter(dataset, as_numpy) | ||||
collate_fn = dataset.collate_fn if collate_fn is None else collate_fn | collate_fn = dataset.collate_fn if collate_fn is None else collate_fn | ||||
if batch_sampler is not None: | |||||
batch_size = 1 | |||||
sampler = None | |||||
drop_last = False | |||||
super().__init__( | super().__init__( | ||||
dataset=dataset, batch_size=batch_size, sampler=sampler, | dataset=dataset, batch_size=batch_size, sampler=sampler, | ||||
num_workers=num_workers, pin_memory=pin_memory, | num_workers=num_workers, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
collate_fn=collate_fn | |||||
collate_fn=collate_fn, batch_sampler=batch_sampler | |||||
) | ) | ||||
def __iter__(self): | def __iter__(self): | ||||
@@ -384,12 +397,16 @@ class TorchLoaderIter(BatchIter): | |||||
os.remove(tmp_file_path) | os.remove(tmp_file_path) | ||||
""" | """ | ||||
def __init__(self, dataset, batch_size=1, sampler=None, | |||||
def __init__(self, dataset, collate_fn, batch_size=1, sampler=None, | |||||
num_workers=0, pin_memory=False, drop_last=False, | num_workers=0, pin_memory=False, drop_last=False, | ||||
timeout=0, worker_init_fn=None, collate_fn=None): | |||||
timeout=0, worker_init_fn=None, | |||||
batch_sampler=None): | |||||
r""" | r""" | ||||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||||
:param dataset: 实现了__getitem__和__len__方法的数据容器。 | |||||
:param callable collate_fn: 用于将样本组合成batch的函数。输入为[dataset[idx1], dataset[idx2], ...], 即dataset中 | |||||
__getitem__返回值组成的list,返回值必须为两个dict,其中第一个dict会被认为是input,第二个dict中的内容被认为是target。 | |||||
需要转换为tensor的数据,需要在collate_fn中转化,但不需要转移到对应device。 | |||||
:param int batch_size: 取出的batch大小 | :param int batch_size: 取出的batch大小 | ||||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | :param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | ||||
Default: ``None`` | Default: ``None`` | ||||
@@ -398,19 +415,21 @@ class TorchLoaderIter(BatchIter): | |||||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | :param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | ||||
:param timeout: 生成一个batch的timeout值 | :param timeout: 生成一个batch的timeout值 | ||||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | ||||
:param collate_fn: 用于将样本组合成batch的函数。 | |||||
:param batch_sampler: 当每次batch取出的数据数量不一致时,可以使用该sampler。batch_sampler每次iter应该输出一个list的index。 | |||||
当batch_sampler不为None时,参数batch_size, sampler, drop_last会被忽略。 | |||||
""" | """ | ||||
assert len(dataset) > 0 | assert len(dataset) > 0 | ||||
ins = dataset[0] | |||||
if (len(ins) != 2 or not isinstance(ins[0], dict) or not isinstance(ins[1], dict)) and collate_fn is None: | |||||
raise RuntimeError("If the provided dataset does not return two dicts when call __getitem__(), the" | |||||
" `collate_fn` must be provided.") | |||||
assert collate_fn is not None, "You must pass collate_fn to pad the batch." | |||||
if batch_sampler is not None: | |||||
batch_size = 1 | |||||
sampler = None | |||||
drop_last = False | |||||
super().__init__( | super().__init__( | ||||
dataset=dataset, batch_size=batch_size, sampler=sampler, | dataset=dataset, batch_size=batch_size, sampler=sampler, | ||||
num_workers=num_workers, pin_memory=pin_memory, | num_workers=num_workers, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
collate_fn=collate_fn | |||||
collate_fn=collate_fn, batch_sampler=batch_sampler | |||||
) | ) | ||||
def __iter__(self): | def __iter__(self): | ||||
@@ -36,72 +36,72 @@ def batching(samples, max_len=0, padding_val=0): | |||||
return batch | return batch | ||||
class Collector: | |||||
class Collater: | |||||
r""" | r""" | ||||
辅助DataSet管理collect_fn的类 | |||||
辅助DataSet管理collate_fn的类 | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
self.collect_fns = {} | |||||
self.collate_fns = {} | |||||
def add_fn(self, fn, name=None): | def add_fn(self, fn, name=None): | ||||
r""" | r""" | ||||
向collector新增一个collect_fn函数 | |||||
向collater新增一个collate_fn函数 | |||||
:param callable fn: | :param callable fn: | ||||
:param str,int name: | :param str,int name: | ||||
:return: | :return: | ||||
""" | """ | ||||
if name in self.collect_fns: | |||||
logger.warn(f"collect_fn:{name} will be overwritten.") | |||||
if name in self.collate_fns: | |||||
logger.warn(f"collate_fn:{name} will be overwritten.") | |||||
if name is None: | if name is None: | ||||
name = len(self.collect_fns) | |||||
self.collect_fns[name] = fn | |||||
name = len(self.collate_fns) | |||||
self.collate_fns[name] = fn | |||||
def is_empty(self): | def is_empty(self): | ||||
r""" | r""" | ||||
返回是否包含collect_fn | |||||
返回是否包含collate_fn | |||||
:return: | :return: | ||||
""" | """ | ||||
return len(self.collect_fns)==0 | |||||
return len(self.collate_fns) == 0 | |||||
def delete_fn(self, name=None): | def delete_fn(self, name=None): | ||||
r""" | r""" | ||||
删除collect_fn | |||||
删除collate_fn | |||||
:param str,int name: 如果为None就删除最近加入的collect_fn | |||||
:param str,int name: 如果为None就删除最近加入的collate_fn | |||||
:return: | :return: | ||||
""" | """ | ||||
if not self.is_empty(): | if not self.is_empty(): | ||||
if name in self.collect_fns: | |||||
self.collect_fns.pop(name) | |||||
if name in self.collate_fns: | |||||
self.collate_fns.pop(name) | |||||
elif name is None: | elif name is None: | ||||
last_key = list(self.collect_fns.keys())[0] | |||||
self.collect_fns.pop(last_key) | |||||
last_key = list(self.collate_fns.keys())[0] | |||||
self.collate_fns.pop(last_key) | |||||
def collect_batch(self, ins_list): | |||||
def collate_batch(self, ins_list): | |||||
bx, by = {}, {} | bx, by = {}, {} | ||||
for name, fn in self.collect_fns.items(): | |||||
for name, fn in self.collate_fns.items(): | |||||
try: | try: | ||||
batch_x, batch_y = fn(ins_list) | batch_x, batch_y = fn(ins_list) | ||||
except BaseException as e: | except BaseException as e: | ||||
logger.error(f"Exception:`{e}` happens when call collect_fn:`{name}`.") | |||||
logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.") | |||||
raise e | raise e | ||||
bx.update(batch_x) | bx.update(batch_x) | ||||
by.update(batch_y) | by.update(batch_y) | ||||
return bx, by | return bx, by | ||||
def copy_from(self, col): | def copy_from(self, col): | ||||
assert isinstance(col, Collector) | |||||
new_col = Collector() | |||||
new_col.collect_fns = deepcopy(col.collect_fns) | |||||
assert isinstance(col, Collater) | |||||
new_col = Collater() | |||||
new_col.collate_fns = deepcopy(col.collate_fns) | |||||
return new_col | return new_col | ||||
class ConcatCollectFn: | |||||
class ConcatCollateFn: | |||||
r""" | r""" | ||||
field拼接collect_fn,将不同field按序拼接后,padding产生数据。 | |||||
field拼接collate_fn,将不同field按序拼接后,padding产生数据。 | |||||
:param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field | :param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field | ||||
:param str output: 拼接后的field名称 | :param str output: 拼接后的field名称 |
@@ -285,8 +285,8 @@ r""" | |||||
------------------------------------------------------------ | ------------------------------------------------------------ | ||||
DataSet支持在进行batch时,默认只能看到当前的field的值,但在某些训练中可能存在以下的情况: (1)需要两个field拼接成为一个field; | DataSet支持在进行batch时,默认只能看到当前的field的值,但在某些训练中可能存在以下的情况: (1)需要两个field拼接成为一个field; | ||||
(2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collect_fn()函数支持添加 | |||||
自定义涉及多个field的collect_fn函数。例如下例中将两个field拼接成一个field的场景 | |||||
(2)需要在batch中进行负采样。这时候就需要能够同时利用多个field进行batch的操作,DataSet中的add_collate_fn()函数支持添加 | |||||
自定义涉及多个field的collate_fn函数。例如下例中将两个field拼接成一个field的场景 | |||||
.. code-block:: | .. code-block:: | ||||
@@ -302,9 +302,9 @@ r""" | |||||
}) | }) | ||||
data.set_target('y') | data.set_target('y') | ||||
# 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 | |||||
# 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 | |||||
# 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. | # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. | ||||
def concat_collect_fn(ins_list): | |||||
def concat_collate_fn(ins_list): | |||||
x1 = [ins['x1'] for ind,ins in ins_list] | x1 = [ins['x1'] for ind,ins in ins_list] | ||||
x2 = [ins['x2'] for ind,ins in ins_list] | x2 = [ins['x2'] for ind,ins in ins_list] | ||||
xs = [] | xs = [] | ||||
@@ -318,7 +318,7 @@ r""" | |||||
# 采用返回值。 | # 采用返回值。 | ||||
return b_x, b_y | return b_x, b_y | ||||
data.add_collect_fn(concat_collect_fn) | |||||
data.add_collate_fn(concat_collate_fn) | |||||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | ||||
print("batch_x:", batch_x) | print("batch_x:", batch_x) | ||||
@@ -328,7 +328,7 @@ r""" | |||||
# batch_y: {'y': array([0, 1])} | # batch_y: {'y': array([0, 1])} | ||||
# 如果取batch过程含有一些参数,可以通过类来实现 | # 如果取batch过程含有一些参数,可以通过类来实现 | ||||
class ConCollectFn: | |||||
class ConCollateFn: | |||||
def __init__(self, max_len=3): | def __init__(self, max_len=3): | ||||
self.max_len = max_len | self.max_len = max_len | ||||
@@ -342,8 +342,8 @@ r""" | |||||
b_x = {'x': arr} | b_x = {'x': arr} | ||||
b_y = {} | b_y = {} | ||||
return b_x, b_y | return b_x, b_y | ||||
data.delete_collect_fn() # 删除之前的collect_fn | |||||
data.add_collect_fn(ConCollectFn(max_len=3)) | |||||
data.delete_collate_fn() # 删除之前的collate_fn | |||||
data.add_collate_fn(ConCollateFn(max_len=3)) | |||||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | ||||
print("batch_x:", batch_x) | print("batch_x:", batch_x) | ||||
print("batch_y:", batch_y) | print("batch_y:", batch_y) | ||||
@@ -370,7 +370,7 @@ from .field import FieldArray | |||||
from .field import SetInputOrTargetException | from .field import SetInputOrTargetException | ||||
from .instance import Instance | from .instance import Instance | ||||
from .utils import pretty_table_printer | from .utils import pretty_table_printer | ||||
from .collect_fn import Collector | |||||
from .collate_fn import Collater | |||||
class ApplyResultException(Exception): | class ApplyResultException(Exception): | ||||
@@ -406,7 +406,7 @@ class DataSet(object): | |||||
else: | else: | ||||
raise ValueError("data only be dict or list type.") | raise ValueError("data only be dict or list type.") | ||||
self.collector = Collector() | |||||
self.collater = Collater() | |||||
def __contains__(self, item): | def __contains__(self, item): | ||||
return item in self.field_arrays | return item in self.field_arrays | ||||
@@ -462,7 +462,7 @@ class DataSet(object): | |||||
for field in self.field_arrays.values(): | for field in self.field_arrays.values(): | ||||
data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, | data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, | ||||
is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) | is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) | ||||
data_set.collector = self.collector.copy_from(self.collector) | |||||
data_set.collater = self.collater.copy_from(self.collater) | |||||
return data_set | return data_set | ||||
elif isinstance(idx, str): | elif isinstance(idx, str): | ||||
if idx not in self: | if idx not in self: | ||||
@@ -476,7 +476,7 @@ class DataSet(object): | |||||
dataset.append(instance) | dataset.append(instance) | ||||
for field_name, field in self.field_arrays.items(): | for field_name, field in self.field_arrays.items(): | ||||
dataset.field_arrays[field_name].to(field) | dataset.field_arrays[field_name].to(field) | ||||
dataset.collector = self.collector.copy_from(self.collector) | |||||
dataset.collater = self.collater.copy_from(self.collater) | |||||
return dataset | return dataset | ||||
else: | else: | ||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
@@ -1083,8 +1083,8 @@ class DataSet(object): | |||||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | ||||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | ||||
train_set.collector.copy_from(self.collector) | |||||
dev_set.collector.copy_from(self.collector) | |||||
train_set.collater.copy_from(self.collater) | |||||
dev_set.collater.copy_from(self.collater) | |||||
return train_set, dev_set | return train_set, dev_set | ||||
def save(self, path): | def save(self, path): | ||||
@@ -1109,30 +1109,30 @@ class DataSet(object): | |||||
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | ||||
return d | return d | ||||
def add_collect_fn(self, fn, name=None): | |||||
def add_collate_fn(self, fn, name=None): | |||||
r""" | r""" | ||||
添加 CollectFn,collect_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 | |||||
这个)。支持依次添加多个collect_fn, 如果相同的key,后面的collect_fn的结果覆盖前面的collect_fn的结果。 | |||||
添加 CollateFn,collate_fn允许在生成的batch的过程中动态生成一些数据(在DataSetIter作为迭代器的情况下有效,默认情况下就是用的 | |||||
这个)。支持依次添加多个collate_fn, 如果相同的key,后面的collate_fn的结果覆盖前面的collate_fn的结果。 | |||||
:param callable fn: 传入一个可调用的function, 该function可接受的参数为List[(ind1, instance1), (ind2, instance2)] | :param callable fn: 传入一个可调用的function, 该function可接受的参数为List[(ind1, instance1), (ind2, instance2)] | ||||
(某个batch被选中的所有的indice以及instance),其中ind1/ind2是该instance在dataset中的index,instance1/instance2是 | (某个batch被选中的所有的indice以及instance),其中ind1/ind2是该instance在dataset中的index,instance1/instance2是 | ||||
这次batch取出来的数据,包含了所有的field。返回值需要为两个dict,第一个dict的值将被认为是input,第二个dict的值被认为是 | 这次batch取出来的数据,包含了所有的field。返回值需要为两个dict,第一个dict的值将被认为是input,第二个dict的值被认为是 | ||||
target,返回的值至多允许一个空dict。若返回的dict中包含了被设置为input或target的field的名称,将覆盖dataset中的field。 | target,返回的值至多允许一个空dict。若返回的dict中包含了被设置为input或target的field的名称,将覆盖dataset中的field。 | ||||
fastNLP不会将collect_fn的返回结果pad和转换为tensor,需要在collect_fn中完成pad和转换为tensor(不需要将tensor移动到 | |||||
gpu中,如果是pytorch的tensor,fastNLP会自动将其移动到特定gpu)。不要修改传入collect_fn中的数据,否则可能导致未知问题。 | |||||
:param str,int name: collect_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collect_fn。 | |||||
fastNLP不会将collate_fn的返回结果pad和转换为tensor,需要在collate_fn中完成pad和转换为tensor(不需要将tensor移动到 | |||||
gpu中,fastNLP会自动将其移动到特定gpu)。不要修改传入collate_fn中的数据,否则可能导致未知问题。 | |||||
:param str,int name: collate_fn的名称,如果不传入,默认使用自增长的数字作为key。相同的name会覆盖之前的collate_fn。 | |||||
""" | """ | ||||
assert callable(fn), "You must pass in a callable object." | assert callable(fn), "You must pass in a callable object." | ||||
self.collector.add_fn(fn, name=name) | |||||
self.collater.add_fn(fn, name=name) | |||||
def delete_collect_fn(self, name=None): | |||||
def delete_collate_fn(self, name=None): | |||||
r""" | r""" | ||||
删除某个collect_fn | |||||
删除某个collate_fn | |||||
:param str,int name: 如果为None,则删除最近加入的collect_fn | |||||
:param str,int name: 如果为None,则删除最近加入的collate_fn | |||||
:return: | :return: | ||||
""" | """ | ||||
self.collector.delete_fn(name) | |||||
self.collater.delete_fn(name) | |||||
def _collect_batch(self, ins_list): | |||||
return self.collector.collect_batch(ins_list) | |||||
def _collate_batch(self, ins_list): | |||||
return self.collater.collate_batch(ins_list) |
@@ -718,8 +718,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
_tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. ' | _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. ' | ||||
else: | else: | ||||
_tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.' | _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.' | ||||
if not dataset.collector.is_empty(): | |||||
_tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. ' | |||||
if not dataset.collater.is_empty(): | |||||
_tmp += f'Or you need to add `{_miss}` in the output of your collate_fn. ' | |||||
suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
if check_res.duplicated: | if check_res.duplicated: | ||||
@@ -779,8 +779,8 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ") | suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ") | ||||
if _miss_out_dataset: | if _miss_out_dataset: | ||||
_tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. " | _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. " | ||||
if not dataset.collector.is_empty(): | |||||
_tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. ' | |||||
if not dataset.collator.is_empty(): | |||||
_tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collate_fn. ' | |||||
suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
if check_res.unused: | if check_res.unused: | ||||
@@ -348,26 +348,26 @@ class DataBundle: | |||||
dataset.apply(func, new_field_name=new_field_name, **kwargs) | dataset.apply(func, new_field_name=new_field_name, **kwargs) | ||||
return self | return self | ||||
def add_collect_fn(self, fn, name=None): | |||||
def add_collate_fn(self, fn, name=None): | |||||
r""" | r""" | ||||
向所有DataSet增加collect_fn, collect_fn详见 :class:`~fastNLP.DataSet` 中相关说明. | |||||
向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明. | |||||
:param callable fn: | :param callable fn: | ||||
:param name: | :param name: | ||||
:return: | :return: | ||||
""" | """ | ||||
for _, dataset in self.datasets.items(): | for _, dataset in self.datasets.items(): | ||||
dataset.add_collect_fn(fn=fn, name=name) | |||||
dataset.add_collate_fn(fn=fn, name=name) | |||||
def delete_collect_fn(self, name=None): | |||||
def delete_collate_fn(self, name=None): | |||||
r""" | r""" | ||||
删除DataSet中的collect_fn | |||||
删除DataSet中的collate_fn | |||||
:param name: | :param name: | ||||
:return: | :return: | ||||
""" | """ | ||||
for _, dataset in self.datasets.items(): | for _, dataset in self.datasets.items(): | ||||
dataset.delete_collect_fn(name=name) | |||||
dataset.delete_collate_fn(name=name) | |||||
def __repr__(self): | def __repr__(self): | ||||
_str = '' | _str = '' | ||||
@@ -7,7 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter | |||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
from fastNLP import SequentialSampler | from fastNLP import SequentialSampler | ||||
from fastNLP import ConcatCollectFn | |||||
from fastNLP import ConcatCollateFn | |||||
def generate_fake_dataset(num_samples=1000): | def generate_fake_dataset(num_samples=1000): | ||||
@@ -177,76 +177,76 @@ class TestCase1(unittest.TestCase): | |||||
for con,t in zip(cons, test): | for con,t in zip(cons, test): | ||||
self.assertEqual(alphas[:con], t) | self.assertEqual(alphas[:con], t) | ||||
def test_collect_fn(self): | |||||
def test_collate_fn(self): | |||||
batch_size = 32 | batch_size = 32 | ||||
num_samples = 1000 | num_samples = 1000 | ||||
dataset = generate_fake_dataset(num_samples) | dataset = generate_fake_dataset(num_samples) | ||||
dataset.set_input('1','2') | dataset.set_input('1','2') | ||||
dataset.set_target('0','3') | dataset.set_target('0','3') | ||||
fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) | |||||
dataset.add_collect_fn(fn, name='demo') | |||||
fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) | |||||
dataset.add_collate_fn(fn, name='demo') | |||||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
for i in range(batch_size): | for i in range(batch_size): | ||||
# print(i) | # print(i) | ||||
self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) | self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) | ||||
dataset.delete_collect_fn(name='demo') | |||||
dataset.delete_collate_fn(name='demo') | |||||
# 测试非input的情况 | # 测试非input的情况 | ||||
dataset.set_input('1', '2', flag=False) # | dataset.set_input('1', '2', flag=False) # | ||||
fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) | |||||
dataset.add_collect_fn(fn, name='demo') | |||||
fn = ConcatCollateFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False) | |||||
dataset.add_collate_fn(fn, name='demo') | |||||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
for i in range(batch_size): | for i in range(batch_size): | ||||
self.assertTrue('12' in batch_x) | self.assertTrue('12' in batch_x) | ||||
dataset.delete_collect_fn(name='demo') | |||||
dataset.delete_collate_fn(name='demo') | |||||
dataset.set_input('1', '2', flag=True) # | dataset.set_input('1', '2', flag=True) # | ||||
# 测试覆盖其它field的情况 | # 测试覆盖其它field的情况 | ||||
fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) | |||||
dataset.add_collect_fn(fn, name='demo') | |||||
fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) | |||||
dataset.add_collate_fn(fn, name='demo') | |||||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
for i in range(batch_size): | for i in range(batch_size): | ||||
# print(i) | # print(i) | ||||
self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) | self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) | ||||
dataset.delete_collect_fn(name='demo') | |||||
dataset.delete_collate_fn(name='demo') | |||||
# 测试非input,target的情况 | # 测试非input,target的情况 | ||||
dataset.set_input('1', '2', flag=False) | dataset.set_input('1', '2', flag=False) | ||||
fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) | |||||
dataset.add_collect_fn(fn, name='demo') | |||||
fn = ConcatCollateFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True) | |||||
dataset.add_collate_fn(fn, name='demo') | |||||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
for i in range(batch_size): | for i in range(batch_size): | ||||
# print(i) | # print(i) | ||||
self.assertTrue('3' in batch_x) | self.assertTrue('3' in batch_x) | ||||
self.assertTrue('3' in batch_y) | self.assertTrue('3' in batch_y) | ||||
dataset.delete_collect_fn(name='demo') | |||||
dataset.delete_collate_fn(name='demo') | |||||
# 测试加入非法fn的请 | # 测试加入非法fn的请 | ||||
with self.assertRaises(AssertionError): | with self.assertRaises(AssertionError): | ||||
dataset.add_collect_fn(1) | |||||
dataset.add_collate_fn(1) | |||||
# 测试collect_fn返回值只有一个的情况 | |||||
def demo_collect_fn(ins_list): | |||||
# 测试collate_fn返回值只有一个的情况 | |||||
def demo_collate_fn(ins_list): | |||||
return {'3':1} | return {'3':1} | ||||
dataset.add_collect_fn(demo_collect_fn, name='demo') | |||||
dataset.add_collate_fn(demo_collate_fn, name='demo') | |||||
with self.assertRaises(BaseException): | with self.assertRaises(BaseException): | ||||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
pass | pass | ||||
dataset.delete_collect_fn(name='demo') | |||||
dataset.delete_collate_fn(name='demo') | |||||
# 测试多个collect_fn | |||||
dataset.add_collect_fn(demo_collect_fn, name='demo') | |||||
dataset.add_collect_fn(demo_collect_fn, name='demo') | |||||
# 测试多个collate_fn | |||||
dataset.add_collate_fn(demo_collate_fn, name='demo') | |||||
dataset.add_collate_fn(demo_collate_fn, name='demo') | |||||
# 测试删除 | # 测试删除 | ||||
dataset.delete_collect_fn() | |||||
dataset.delete_collect_fn() | |||||
self.assertTrue(dataset.collector.is_empty()) | |||||
dataset.delete_collate_fn() | |||||
dataset.delete_collate_fn() | |||||
self.assertTrue(dataset.collater.is_empty()) | |||||
def test_demo(self): | def test_demo(self): | ||||
import torch | import torch | ||||
@@ -261,9 +261,9 @@ class TestCase1(unittest.TestCase): | |||||
}) | }) | ||||
data.set_target('y') | data.set_target('y') | ||||
# 所有的collect_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 | |||||
# 所有的collate_fn函数都接受list[(ind1, instance1), (ind2, instance2), ...]作为输入,其中ind1/ind2是该instance在dataset中 | |||||
# 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. | # 的index,instance1/instance2是这次batch取出来的数据,包含了所有的field. | ||||
def concat_collect_fn(ins_list): | |||||
def concat_collate_fn(ins_list): | |||||
x1 = [ins['x1'] for ind,ins in ins_list] | x1 = [ins['x1'] for ind,ins in ins_list] | ||||
x2 = [ins['x2'] for ind,ins in ins_list] | x2 = [ins['x2'] for ind,ins in ins_list] | ||||
xs = [] | xs = [] | ||||
@@ -277,7 +277,7 @@ class TestCase1(unittest.TestCase): | |||||
# 采用返回值。 | # 采用返回值。 | ||||
return b_x, b_y | return b_x, b_y | ||||
data.add_collect_fn(concat_collect_fn) | |||||
data.add_collate_fn(concat_collate_fn) | |||||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | ||||
print("batch_x:", batch_x) | print("batch_x:", batch_x) | ||||
@@ -287,7 +287,7 @@ class TestCase1(unittest.TestCase): | |||||
# batch_y: {'y': array([0, 1])} | # batch_y: {'y': array([0, 1])} | ||||
# 如果取batch过程含有一些参数,可以通过类来实现 | # 如果取batch过程含有一些参数,可以通过类来实现 | ||||
class ConCollectFn: | |||||
class ConCollateFn: | |||||
def __init__(self, max_len=3): | def __init__(self, max_len=3): | ||||
self.max_len = max_len | self.max_len = max_len | ||||
def __call__(self, ins_list): | def __call__(self, ins_list): | ||||
@@ -300,8 +300,8 @@ class TestCase1(unittest.TestCase): | |||||
b_x = {'x': arr} | b_x = {'x': arr} | ||||
b_y = {} | b_y = {} | ||||
return b_x, b_y | return b_x, b_y | ||||
data.delete_collect_fn() # 删除之前的collect_fn | |||||
data.add_collect_fn(ConCollectFn(max_len=3)) | |||||
data.delete_collate_fn() # 删除之前的collate_fn | |||||
data.add_collate_fn(ConCollateFn(max_len=3)) | |||||
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2): | ||||
print("batch_x:", batch_x) | print("batch_x:", batch_x) | ||||
print("batch_y:", batch_y) | print("batch_y:", batch_y) | ||||
@@ -326,14 +326,77 @@ class TestCase1(unittest.TestCase): | |||||
return x, y | return x, y | ||||
data1 = FakeData() | data1 = FakeData() | ||||
dataiter = TorchLoaderIter(data1, batch_size=2) | |||||
def collact_fn(ins_list): | |||||
xs = [ins[0]['x'] for ins in ins_list] | |||||
ys = [ins[1]['y'] for ins in ins_list] | |||||
return {'x':xs}, {'y':ys} | |||||
dataiter = TorchLoaderIter(data1, collate_fn=collact_fn, batch_size=2) | |||||
for x, y in dataiter: | for x, y in dataiter: | ||||
print(x, y) | print(x, y) | ||||
def func(): | |||||
data2 = FakeData(return_dict=False) | |||||
dataiter = TorchLoaderIter(data2, batch_size=2) | |||||
self.assertRaises(Exception, func) | |||||
def test_batch_sampler(self): | |||||
# 测试DataSetIter与TorchLoaderIter的batch_sampler能否正常工作 | |||||
# DataSetIter | |||||
ds = generate_fake_dataset(5) | |||||
ds.set_input('1') | |||||
class BatchSampler: | |||||
def __init__(self, dataset): | |||||
self.num_samples = len(dataset) | |||||
def __iter__(self): | |||||
index = 0 | |||||
indexes = list(range(self.num_samples)) | |||||
np.random.shuffle(indexes) | |||||
start_idx = 0 | |||||
while index < self.num_samples: | |||||
if start_idx == 0: | |||||
end_index = self.num_samples//2 | |||||
else: | |||||
end_index = self.num_samples | |||||
yield indexes[start_idx:end_index] | |||||
index = end_index | |||||
start_idx = end_index | |||||
def __len__(self): | |||||
return 2 | |||||
batch_sampler = BatchSampler(ds) | |||||
data_iter = DataSetIter(ds, batch_size=10, sampler=batch_sampler, as_numpy=False, | |||||
num_workers=0, pin_memory=False, drop_last=False, | |||||
timeout=0, worker_init_fn=None, collate_fn=None, | |||||
batch_sampler=batch_sampler) | |||||
num_samples = [len(ds)//2, len(ds)-len(ds)//2] | |||||
for idx, (batch_x, batch_y) in enumerate(data_iter): | |||||
self.assertEqual(num_samples[idx], len(batch_x['1'])) | |||||
# TorchLoaderIter | |||||
class FakeData: | |||||
def __init__(self): | |||||
self.x = [[1,2,3], [4,5,6], [1,2]] | |||||
def __len__(self): | |||||
return len(self.x) | |||||
def __getitem__(self, i): | |||||
x = self.x[i] | |||||
y = 0 | |||||
return x,y | |||||
def collate_fn(ins_list): | |||||
xs = [ins[0] for ins in ins_list] | |||||
ys = [ins[1] for ins in ins_list] | |||||
return {'x':xs}, {'y':ys} | |||||
ds = FakeData() | |||||
batch_sampler = BatchSampler(ds) | |||||
data_iter = TorchLoaderIter(ds, batch_size=10, sampler=batch_sampler, | |||||
num_workers=0, pin_memory=False, drop_last=False, | |||||
timeout=0, worker_init_fn=None, collate_fn=collate_fn, | |||||
batch_sampler=batch_sampler) | |||||
num_samples = [len(ds)//2, len(ds)-len(ds)//2] | |||||
for idx, (batch_x, batch_y) in enumerate(data_iter): | |||||
self.assertEqual(num_samples[idx], len(batch_x['x'])) | |||||
""" | """ | ||||
def test_multi_workers_batch(self): | def test_multi_workers_batch(self): | ||||
@@ -243,7 +243,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
def __len__(self): | def __len__(self): | ||||
return self.num_samples | return self.num_samples | ||||
def collect_fn(data_list): | |||||
def collate_fn(data_list): | |||||
# [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list | # [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list | ||||
xs, ys = [], [] | xs, ys = [], [] | ||||
for l in data_list: | for l in data_list: | ||||
@@ -254,7 +254,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
return {'x':x, 'y':y}, {'y':y} | return {'x':x, 'y':y}, {'y':y} | ||||
dataset = UdfDataSet(10) | dataset = UdfDataSet(10) | ||||
dataset = TorchLoaderIter(dataset, collate_fn=collect_fn) | |||||
dataset = TorchLoaderIter(dataset, collate_fn=collate_fn) | |||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
@@ -268,6 +268,67 @@ class TrainerTestGround(unittest.TestCase): | |||||
metrics=AccuracyMetric(target='y'), use_tqdm=False) | metrics=AccuracyMetric(target='y'), use_tqdm=False) | ||||
trainer.train(load_best_model=False) | trainer.train(load_best_model=False) | ||||
def test_batch_sampler_dataiter(self): | |||||
import random | |||||
import torch | |||||
class BatchSampler: | |||||
def __init__(self, dataset): | |||||
self.num_samples = len(dataset) | |||||
def __iter__(self): | |||||
index = 0 | |||||
indexes = list(range(self.num_samples)) | |||||
np.random.shuffle(indexes) | |||||
start_idx = 0 | |||||
while index < self.num_samples: | |||||
if start_idx == 0: | |||||
end_index = self.num_samples//2 | |||||
else: | |||||
end_index = self.num_samples | |||||
yield indexes[start_idx:end_index] | |||||
index = end_index | |||||
start_idx = end_index | |||||
def __len__(self): | |||||
return 2 | |||||
class UdfDataSet: | |||||
def __init__(self, num_samples): | |||||
self.num_samples = num_samples | |||||
def __getitem__(self, idx): | |||||
x = [random.random() for _ in range(3)] | |||||
y = random.random() | |||||
return x,y | |||||
def __len__(self): | |||||
return self.num_samples | |||||
def collate_fn(data_list): | |||||
# [(x1,y1), (x2,y2), ...], 这里的输入实际上是将UdfDataSet的__getitem__输入结合为list | |||||
xs, ys = [], [] | |||||
for l in data_list: | |||||
x, y = l | |||||
xs.append(x) | |||||
ys.append(y) | |||||
x,y = torch.FloatTensor(xs), torch.FloatTensor(ys) | |||||
return {'x':x, 'y':y}, {'y':y} | |||||
dataset = UdfDataSet(11) | |||||
batch_sampler = BatchSampler(dataset) | |||||
dataset = TorchLoaderIter(dataset, collate_fn=collate_fn, batch_sampler=batch_sampler) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(3, 1) | |||||
def forward(self, x, y): | |||||
return {'loss':torch.pow(self.fc(x).squeeze(-1)-y, 2).sum()} | |||||
def predict(self, x): | |||||
return {'pred':self.fc(x).squeeze(-1)} | |||||
model = Model() | |||||
trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, | |||||
metrics=AccuracyMetric(target='y'), use_tqdm=False) | |||||
trainer.train(load_best_model=False) | |||||
def test_onthefly_iter(self): | def test_onthefly_iter(self): | ||||
import tempfile | import tempfile | ||||
import random | import random | ||||
@@ -333,7 +394,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()} | return {'loss': torch.pow(self.fc(x).squeeze(-1) - y, 2).sum()} | ||||
def predict(self, x): | def predict(self, x): | ||||
return {'pred': self.fc(x).squeeze(0)} | |||||
return {'pred': self.fc(x).squeeze(-1)} | |||||
model = Model() | model = Model() | ||||
trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, | trainer = Trainer(train_data=dataset, model=model, loss=None, print_every=2, dev_data=dataset, | ||||
@@ -356,7 +417,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
x.append(ins['x1']+ins['x2']) | x.append(ins['x1']+ins['x2']) | ||||
x = torch.FloatTensor(x) | x = torch.FloatTensor(x) | ||||
return {'x':x}, {} | return {'x':x}, {} | ||||
dataset.add_collect_fn(fn) | |||||
dataset.add_collate_fn(fn) | |||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -377,7 +438,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False) | dev_data=dataset, metrics=AccuracyMetric(target='y'), use_tqdm=False) | ||||
trainer.train() | trainer.train() | ||||
def test_collect_fn2(self): | |||||
def test_collate_fn2(self): | |||||
"""测试能否实现batch_x, batch_y""" | """测试能否实现batch_x, batch_y""" | ||||
dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||
dataset.set_input('x1', 'x2') | dataset.set_input('x1', 'x2') | ||||
@@ -389,7 +450,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
x.append(ins['x1']+ins['x2']) | x.append(ins['x1']+ins['x2']) | ||||
x = torch.FloatTensor(x) | x = torch.FloatTensor(x) | ||||
return {'x':x}, {'target':x[:, :4].argmax(dim=-1)} | return {'x':x}, {'target':x[:, :4].argmax(dim=-1)} | ||||
dataset.add_collect_fn(fn) | |||||
dataset.add_collate_fn(fn) | |||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
@@ -410,7 +471,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False) | dev_data=dataset, metrics=AccuracyMetric(), use_tqdm=False) | ||||
trainer.train() | trainer.train() | ||||
def test_collect_fn3(self): | |||||
def test_collate_fn3(self): | |||||
""" | """ | ||||
测试应该会覆盖 | 测试应该会覆盖 | ||||
@@ -426,7 +487,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
x.append(ins['x1']+ins['x2']) | x.append(ins['x1']+ins['x2']) | ||||
x = torch.FloatTensor(x) | x = torch.FloatTensor(x) | ||||
return {'x1':torch.zeros_like(x)}, {'target':torch.zeros(x.size(0)).long(), 'y':x} | return {'x1':torch.zeros_like(x)}, {'target':torch.zeros(x.size(0)).long(), 'y':x} | ||||
dataset.add_collect_fn(fn) | |||||
dataset.add_collate_fn(fn) | |||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||