
Adjust the implementation of batching that involves multiple fields

tags/v0.5.5
yh_cc committed 5 years ago · commit 7f01971321
9 changed files with 354 additions and 205 deletions
  1. fastNLP/core/__init__.py (+1 -2)
  2. fastNLP/core/batch.py (+45 -41)
  3. fastNLP/core/collect_fn.py (+76 -116)
  4. fastNLP/core/dataset.py (+90 -22)
  5. fastNLP/core/field.py (+14 -7)
  6. fastNLP/core/tester.py (+1 -1)
  7. fastNLP/core/trainer.py (+1 -1)
  8. fastNLP/core/utils.py (+8 -7)
  9. test/core/test_batch.py (+118 -8)

fastNLP/core/__init__.py (+1 -2)

@@ -21,7 +21,6 @@ __all__ = [
"AutoPadder",
"EngChar2DPadder",

"CollectFn",
"ConcatCollectFn",
"Vocabulary",
@@ -97,4 +96,4 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
from .collect_fn import CollectFn, ConcatCollectFn
from .collect_fn import ConcatCollectFn

fastNLP/core/batch.py (+45 -41)

@@ -9,17 +9,16 @@ __all__ = [
]

import atexit
from numbers import Number
import abc

from numbers import Number
import numpy as np
import torch
import torch.utils.data
from collections import defaultdict

from ._logger import logger
from .dataset import DataSet
from .sampler import SequentialSampler
from .field import _get_ele_type_and_dim

_python_is_exit = False

@@ -33,6 +32,9 @@ atexit.register(_set_python_is_exit)


class DataSetGetter:
"""
Passed to torch.utils.data.DataLoader to fetch data; the DataLoader passes in an int idx to fetch an item (by calling the __getitem__() method here).
"""
def __init__(self, dataset: DataSet, as_numpy=False):
self.dataset = dataset
self.as_numpy = as_numpy
@@ -56,7 +58,6 @@ class DataSetGetter:
:param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
:return:
"""
# TODO support defining a collate_fn in the DataSet, since sometimes different fields need to be fused, e.g. in the BERT scenario
indices = []
sin_x, sin_y = defaultdict(list), defaultdict(list)
for idx, ins in ins_list:
@@ -67,24 +68,6 @@ class DataSetGetter:
if n in self.y_names:
sin_y[n].append(v)

def may_to_tensor(data):
dtype, dim = _get_ele_type_and_dim(data)
# print(dtype, type(dtype), str(dtype))
if not self.as_numpy:
try:
data, flag = _to_tensor(data, dtype)
except TypeError as e:
logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e
# if torch.is_tensor(data):
# str_dtype = str(dtype)
# if 'float' in str_dtype:
# data = data.float()
# elif 'int' in str_dtype:
# data = data.long()
# print(data.dtype)
return data

def pad(batch_dict):
result = {}
for n, vlist in batch_dict.items():
@@ -98,25 +81,13 @@ class DataSetGetter:
sin_x = pad(sin_x)
sin_y = pad(sin_y)

bx, by = self.dataset._collect_batch(ins_list)
def convert_tensor(batch_dict):
for n, v in batch_dict.items():
batch_dict[n] = may_to_tensor(v)

# collect_fn replaces single field
sin_x.update(bx)
sin_y.update(by)

convert_tensor(sin_x)
convert_tensor(sin_y)
if not self.dataset.collector.is_empty():
bx, by = self.dataset._collect_batch(ins_list)
sin_x.update(bx)
sin_y.update(by)

return (indices, sin_x, sin_y)

def set_idx_list(self, idx_list):
if len(idx_list) != len(self.idx_list):
raise ValueError
self.idx_list = idx_list

def __getattr__(self, item):
if hasattr(self.dataset, item):
return getattr(self.dataset, item)
@@ -125,6 +96,10 @@ class DataSetGetter:


class SamplerAdapter(torch.utils.data.Sampler):
"""
Passed into torch.utils.data.DataLoader; the DataLoader calls the __iter__() method to obtain indices (one int at a time).

"""
def __init__(self, sampler, dataset):
super().__init__(dataset)
self.sampler = sampler
@@ -138,6 +113,11 @@ class SamplerAdapter(torch.utils.data.Sampler):


class BatchIter:
"""
The class the Trainer uses to iterate over data. Subclass it and implement the get_num_batches(), get_batch_indices(), dataset(), num_batches(),
and __iter__() methods.

"""
def __init__(self, dataset, batch_size=1, sampler=None,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=None):
@@ -145,6 +125,8 @@ class BatchIter:
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
else:
self.sampler = sampler

# The collate_fn of the DataLoader receives a List[] whose elements are the results returned by dataset[index]
if collate_fn is None:
# pytorch <= 1.1 does not allow collate_fn=None
self.dataiter = torch.utils.data.DataLoader(
@@ -160,17 +142,25 @@ class BatchIter:
timeout=timeout, worker_init_fn=worker_init_fn)

# Follow the sampler's count: with a DistributedSampler, each process does not use all of the data
self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
self.batch_size = batch_size
self.cur_batch_indices = None

@property
def num_batches(self):
return self._num_batches

@num_batches.setter
def num_batches(self, value):
self._num_batches = value

def init_iter(self):
pass

@staticmethod
def get_num_batches(num_samples, batch_size, drop_last):
"""
Compute the number of batches.
Compute the number of batches. Used by the front end to display progress

:param int num_samples:
:param int batch_size:
@@ -184,7 +174,7 @@ class BatchIter:

def get_batch_indices(self):
"""
Get the index of the batch that has currently been output.
Get the index of the most recently output batch. Used to trace the data of the current batch

:return:
"""
@@ -195,8 +185,22 @@ class BatchIter:

@property
def dataset(self):
"""
Get the dataset that is being iterated over

:return:
"""
return self.dataiter.dataset

@abc.abstractmethod
def __iter__(self):
"""
Method used for the actual data loop; the return value must be two dicts: the contents of the first dict are treated as input, those of the second as target

:return:
"""
raise NotImplementedError


class DataSetIter(BatchIter):
"""


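The net effect of the batch.py changes: num_batches becomes a settable property backed by _num_batches, and collect_fn results are merged into the padded batch dicts only when the dataset's collector is non-empty. A minimal sketch of the resulting flow, assuming the fastNLP v0.5.5 API shown in this commit (the field names and the collect_fn are illustrative only):

    from fastNLP import DataSet, DataSetIter, SequentialSampler

    ds = DataSet({'words': [[1, 2], [3, 4, 5]], 'target': [0, 1]})
    ds.set_input('words')
    ds.set_target('target')

    # A collect_fn receives List[(index, Instance)] and returns two dicts: (batch_x, batch_y)
    def seq_len_fn(ins_list):
        return {'seq_len': [len(ins['words']) for _, ins in ins_list]}, {}

    ds.add_collect_fn(seq_len_fn, name='seq_len')

    it = DataSetIter(ds, batch_size=2, sampler=SequentialSampler())
    print(it.num_batches)  # read through the new num_batches property
    for batch_x, batch_y in it:
        # batch_x holds the padded 'words' tensor plus 'seq_len' produced by the collect_fn
        print(batch_x, batch_y)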
fastNLP/core/collect_fn.py (+76 -116)

@@ -4,7 +4,8 @@ from builtins import sorted
import torch
import numpy as np
from .field import _get_ele_type_and_dim
from collections import defaultdict
from .utils import logger
from copy import deepcopy


def _check_type(batch_dict, fields):
@@ -36,127 +37,89 @@ def batching(samples, max_len=0, padding_val=0):


class Collector:
"""
A helper class for DataSet that manages collect_fn functions

"""
def __init__(self):
self.fns = {}
self.input2fn = defaultdict(list)
self.output2fn = defaultdict(list)
self.fn2input = {}
self.fn2output = {}

def add_fn(self, fn, inputs, outputs, is_input, is_target):
for name in outputs:
if name in self.output2fn:
raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))

if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs():
raise ValueError(
"Incorrect num of inputs, should be {} not {}".format(
fn.num_inputs(), len(inputs)
))

if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs():
raise ValueError("Incorrect num of outputs, should be {} not {}".format(
fn.num_outputs(), len(outputs)))

self.fns[fn] = {'is_input': is_input, 'is_target': is_target}
for i, field in enumerate(inputs):
self.input2fn[field].append((fn, i))
for i, name in enumerate(outputs):
self.output2fn[name].append((fn, i))

def _rebuild_fn2io(self):
def transpose(name2fn):
fn2names = defaultdict(list)
for name, vlist in name2fn.items():
for fn, i in vlist:
fn2names[fn].append((name, i))
for fn, vlist in fn2names.items():
vlist = sorted(vlist, key=lambda x: x[1])
fn2names[fn] = [name for name, i in vlist]
return fn2names

self.fn2input = transpose(self.input2fn)
self.fn2output = transpose(self.output2fn)

def _clear_fn2io(self):
self.fn2input.clear()
self.fn2output.clear()
self.collect_fns = {}

def add_fn(self, fn, name=None):
"""
Add a collect_fn function to the collector

:param callable fn:
:param str,int name:
:return:
"""
if name in self.collect_fns:
logger.warn(f"collect_fn:{name} will be overwritten.")
if name is None:
name = len(self.collect_fns)
self.collect_fns[name] = fn

def is_empty(self):
"""
Return whether it contains any collect_fn

:return:
"""
return len(self.collect_fns)==0

def delete_fn(self, name=None):
"""
Delete a collect_fn

:param str,int name: if None, the most recently added collect_fn is deleted
:return:
"""
if not self.is_empty():
if name in self.collect_fns:
self.collect_fns.pop(name)
elif name is None:
last_key = list(self.collect_fns.keys())[-1]
self.collect_fns.pop(last_key)

def collect_batch(self, ins_list):
if len(ins_list) == 0:
return {}, {}

if len(self.fn2output) == 0:
self._rebuild_fn2io()

bx = {}
by = {}
for fn, attr in self.fns.items():
inputs = self.fn2input.get(fn, None)
outputs = self.fn2output.get(fn, None)
res = fn.collect(ins_list, inputs, outputs)
if attr.get('is_input', False):
bx.update(res)
if attr.get('is_target', False):
by.update(res)
bx, by = {}, {}
for name, fn in self.collect_fns.items():
try:
batch_x, batch_y = fn(ins_list)
except BaseException as e:
logger.error(f"Exception:`{e}` happens when call collect_fn:`{name}`.")
raise e
bx.update(batch_x)
by.update(batch_y)
return bx, by

def rename_field(self, old_f, new_f):
if new_f in self.input2fn:
# name conflict
raise ValueError
if old_f not in self.input2fn:
# the renamed field does not affect any collector
return
self.input2fn[new_f] = self.input2fn[old_f]
self._clear_fn2io()

def drop_field(self, f):
if f in self.input2fn:
raise ValueError

def outputs(self):
return self.output2fn.keys()

def copy_from(self, col):
assert isinstance(col, Collector)
self.fns = col.fns.copy()
self.input2fn = col.input2fn.copy()
self.output2fn = col.output2fn.copy()
self._clear_fn2io()

class CollectFn:
def __init__(self):
self.fields = []

def collect(self, ins_list, inputs, outputs):
raise NotImplementedError
new_col = Collector()
new_col.collect_fns = deepcopy(col.collect_fns)
return new_col

def num_inputs(self):
return 0

def num_outputs(self):
return 0

@staticmethod
def get_batch_size(batch_dict):
if len(batch_dict) == 0:
return 0
return len(next(iter(batch_dict.values())))


class ConcatCollectFn(CollectFn):
class ConcatCollectFn:
"""
A field-concatenating Fn: the given fields are concatenated in order and then padded to produce the data. All fields must have the same dim.
A field-concatenating collect_fn: the given fields are concatenated in order and then padded to produce the data.

:param List[str] inputs: which fields to concatenate; currently only 1d fields are supported
:param str output: the name of the concatenated field
:param pad_val: the value used for padding
:param max_len: the maximum length after concatenation
:param is_input: whether to set the generated output as input
:param is_target: whether to set the generated output as target
"""

def __init__(self, pad_val=0, max_len=0):
def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False):
super().__init__()
assert isinstance(inputs, list)
self.inputs = inputs
self.output = output
self.pad_val = pad_val
self.max_len = max_len
self.is_input = is_input
self.is_target = is_target

@staticmethod
def _to_numpy(seq):
@@ -165,21 +128,18 @@ class ConcatCollectFn(CollectFn):
else:
return np.array(seq)

def collect(self, ins_list, inputs, outputs):
def __call__(self, ins_list):
samples = []
for i, ins in ins_list:
sample = []
for i in inputs:
sample.append(self._to_numpy(ins[i]))
for input_name in self.inputs:
sample.append(self._to_numpy(ins[input_name]))
samples.append(np.concatenate(sample, axis=0))
seq_len = [s.shape[0] for s in samples]
batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
o1, o2 = outputs
return {o1: batch, o2: seq_len}
def num_inputs(self):
return 0
b_x, b_y = {}, {}
if self.is_input:
b_x[self.output] = batch
if self.is_target:
b_y[self.output] = batch

def num_outputs(self):
# (concat_words, seq_len)
return 2
return b_x, b_y
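
For reference, a small sketch of driving the reworked Collector directly, based only on the signatures in this diff (in normal use DataSet.add_collect_fn and _collect_batch wrap these calls):

    from fastNLP import Instance
    from fastNLP.core.collect_fn import Collector, ConcatCollectFn

    collector = Collector()
    fn = ConcatCollectFn(inputs=['x1', 'x2'], output='x', pad_val=0, is_input=True)
    collector.add_fn(fn, name='concat')

    ins_list = [(0, Instance(x1=[1, 2], x2=[3])),
                (1, Instance(x1=[4], x2=[5, 6]))]
    bx, by = collector.collect_batch(ins_list)  # bx: {'x': array([[1, 2, 3], [4, 5, 6]])}, by: {}

    collector.delete_fn('concat')
    assert collector.is_empty()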

fastNLP/core/dataset.py (+90 -22)

@@ -281,6 +281,75 @@
# 也可以设置pad的value
dataset.set_pad_val('chars', -1)

    3.3 Combining multiple fields of a DataSet into a new field
    -----------------------------------------------------------

    By default, a field only sees its own values while a DataSet is being batched, but some training setups need more than that, for example: (1) two fields have to be concatenated into one field; (2) negative sampling has to happen within the batch. Such cases require operating on several fields at once during batching, and DataSet's add_collect_fn() function supports adding custom collect_fn functions that involve multiple fields. The example below shows the scenario of concatenating two fields into one:

    .. code-block::

        from fastNLP import DataSet, DataSetIter, SequentialSampler
        import torch

        data = DataSet({
            'x1': [[0, 1],
                   [2]],
            'x2': [[3],
                   [2, 4, 5]],
            'y': [0, 1]
        })
        data.set_target('y')

        # Every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2 is the index of
        # the instance in the dataset and instance1/instance2 is the data drawn for this batch, containing all fields.
        def concat_collect_fn(ins_list):
            x1 = [ins['x1'] for ind, ins in ins_list]
            x2 = [ins['x2'] for ind, ins in ins_list]
            xs = []
            for i in range(len(ins_list)):
                xs.append(torch.LongTensor(x1[i] + x2[i]))
            # You have to pad and convert to tensor yourself, but there is no need to move the data to the gpu
            arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
            b_x = {'x': arr}
            b_y = {}
            # The return value must be two dicts: the values of the first dict are treated as input, those of the second as
            # target. If a name clashes with an existing input or target, the returned value takes precedence.
            return b_x, b_y

        data.add_collect_fn(concat_collect_fn)

        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
            print("batch_x:", batch_x)
            print("batch_y:", batch_y)
            # batch_x: {'x': tensor([[0, 1, 3, 0],
            #                        [2, 2, 4, 5]])}
            # batch_y: {'y': array([0, 1])}

        # If batching needs some parameters, implement the collect_fn as a class
        class ConCollectFn:
            def __init__(self, max_len=3):
                self.max_len = max_len

            def __call__(self, ins_list):  # implement the class's __call__ method
                x1 = [ins['x1'] for ind, ins in ins_list]
                x2 = [ins['x2'] for ind, ins in ins_list]
                xs = []
                for i in range(len(ins_list)):
                    xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
                arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
                b_x = {'x': arr}
                b_y = {}
                return b_x, b_y

        data.delete_collect_fn()  # delete the previous collect_fn
        data.add_collect_fn(ConCollectFn(max_len=3))
        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
            print("batch_x:", batch_x)
            print("batch_y:", batch_y)
            # batch_x: {'x': tensor([[0, 1, 3],
            #                        [2, 2, 4]])}
            # batch_y: {'y': array([0, 1])}

"""
__all__ = [
@@ -300,7 +369,6 @@ from .field import AutoPadder
from .field import FieldArray
from .field import SetInputOrTargetException
from .instance import Instance
from .utils import _get_func_signature
from .utils import pretty_table_printer
from .collect_fn import Collector

@@ -394,6 +462,7 @@ class DataSet(object):
for field in self.field_arrays.values():
data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder,
is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
data_set.collector = self.collector.copy_from(self.collector)
return data_set
elif isinstance(idx, str):
if idx not in self:
@@ -407,6 +476,7 @@ class DataSet(object):
dataset.append(instance)
for field_name, field in self.field_arrays.items():
dataset.field_arrays[field_name].to(field)
dataset.collector = self.collector.copy_from(self.collector)
return dataset
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -575,7 +645,6 @@ class DataSet(object):
:param str field_name: 需要删除的field的名称.
"""
self.field_arrays.pop(field_name)
self.collector.drop_field(field_name)
return self

def copy_field(self, field_name, new_field_name):
@@ -648,7 +717,6 @@ class DataSet(object):
if field_name in self.field_arrays:
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
self.field_arrays[new_field_name].name = new_field_name
self.collector.rename_field(field_name, new_field_name)
else:
raise KeyError("DataSet has no field named {}.".format(field_name))
return self
@@ -1040,30 +1108,30 @@ class DataSet(object):
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
return d

def add_collect_fn(self, fn, inputs, outputs, is_input, is_target):
def add_collect_fn(self, fn, name=None):
"""
Add a CollectFn that uses multiple fields to produce the data in a batch
Add a collect_fn. A collect_fn allows some data to be generated dynamically while a batch is being built (effective when
DataSetIter is used as the iterator, which is the default). Multiple collect_fn can be added in turn; when keys collide, the result of a later collect_fn overrides that of an earlier one.

:param CollectFn fn: defines how the data is produced
:param list inputs: the fields used to produce the data, in order
:param list outputs: the names the produced data takes in the batch
:param bool is_input: whether the output appears in the input batch; if not, it appears in the target batch
:param bool is_target:
:param callable fn: a callable object. It accepts List[(ind1, instance1), (ind2, instance2)] as its argument
    (all the indices selected for a batch together with their instances), where ind1/ind2 is the index of the instance
    in the dataset and instance1/instance2 is the data drawn for this batch, containing all fields. It must return two
    dicts: the values of the first dict are treated as input and those of the second as target; at most one of the two
    returned dicts may be empty. If a returned dict contains the name of a field that is set as input or target, it
    overrides that field of the dataset. fastNLP does not pad the collect_fn's return values or convert them to
    tensors; do the padding and tensor conversion inside the collect_fn (there is no need to move tensors to the gpu;
    if they are pytorch tensors, fastNLP automatically moves them to the proper gpu). Do not modify the data passed
    into the collect_fn, or unknown problems may occur.
:param str,int name: name of this collect_fn. If not given, an auto-incrementing number is used as the key. The same name overrides the previous collect_fn.
"""
def check_fields(fields):
for f in fields:
if f not in self.field_arrays:
raise ValueError(f)

def check_name(names):
for name in names:
if name in self.field_arrays:
logger.warning('name of collect_fn will cover the field name in dataset')
assert callable(fn), "You must pass in a callable object."
self.collector.add_fn(fn, name=name)

check_fields(inputs)
check_name(outputs)
def delete_collect_fn(self, name=None):
"""
Delete a collect_fn

self.collector.add_fn(fn, inputs, outputs, is_input, is_target)
:param str,int name: if None, the most recently added collect_fn is deleted
:return:
"""
self.collector.delete_fn(name)

def _collect_batch(self, ins_list):
return self.collector.collect_batch(ins_list)
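
One consequence of the dataset.py changes worth noting: since __getitem__ now calls collector.copy_from(), a sliced DataSet keeps its collect_fns. A small sketch, assuming the API in this commit (the lambda is illustrative only):

    from fastNLP import DataSet

    ds = DataSet({'x1': [[0, 1], [2], [3]], 'x2': [[3], [2, 4], [5]], 'y': [0, 1, 0]})
    ds.add_collect_fn(lambda ins_list: ({'n': len(ins_list)}, {}), name='count')

    sub = ds[:2]                           # slicing builds a new DataSet
    assert not sub.collector.is_empty()    # the collect_fn was carried over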

fastNLP/core/field.py (+14 -7)

@@ -191,24 +191,31 @@ class FieldArray:
def get(self, indices, pad=True):
"""
Return the contents corresponding to the given indices
Return the contents corresponding to the given indices

:param int,List[int] indices: get the contents at these indices.
:param bool pad: whether to pad the returned result. Only effective when indices is a List[int]
:return: the contents at the given indices; may be a single value or a List
:param bool pad: whether to pad the returned result. Only effective when (1) indices is a List[int]; (2) the padder
    is not None; and (3) the field is set as input or target
:return: the contents at the given indices; may be a single value or an ndarray
"""
if isinstance(indices, int):
return self.content[indices]
if self.is_input is False and self.is_target is False:
raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name))

contents = [self.content[i] for i in indices]
if self.padder is None or pad is False:
return np.array(contents)
else:
elif self.is_input or self.is_target:
return self.pad(contents)
else:
return np.array(contents)
def pad(self, contents):
"""
Pad the given list of contents with the padder; the contents must have been taken from this FieldArray.

:param list contents:
:return:
"""
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
def set_padder(self, padder):
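
A small sketch of the adjusted FieldArray.get() branches, assuming the v0.5.5 defaults (an AutoPadder is attached when the field is created): padding now applies only when the field is input or target and a padder is set; otherwise the raw contents come back as an ndarray instead of raising a RuntimeError:

    from fastNLP import DataSet

    ds = DataSet({'words': [[1, 2], [3, 4, 5]]})
    fa = ds['words']       # indexing a DataSet with a str returns the FieldArray

    print(fa.get([0, 1]))  # neither input nor target yet -> raw contents, no padding
    fa.is_input = True
    print(fa.get([0, 1]))  # now padded by the AutoPadder: [[1 2 0] [3 4 5]]
    print(fa.get(0))       # an int index returns the single element: [1, 2]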


fastNLP/core/tester.py (+1 -1)

@@ -71,7 +71,7 @@ class Tester(object):
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
"""
:param ~fastNLP.DataSet data: the dataset to test on
:param ~fastNLP.DataSet,~fastNLP.BatchIter data: the dataset to test on
:param torch.nn.Module model: the model to use
:param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: the metrics used during testing
:param int batch_size: the batch_size used during evaluation.


fastNLP/core/trainer.py (+1 -1)

@@ -375,7 +375,7 @@ class Trainer(object):
callbacks=None, check_code_level=0, **kwargs):
"""
:param train_data: the training set, of type :class:`~fastNLP.DataSet`
:param train_data: the training set, of type :class:`~fastNLP.DataSet` or a subclass of :class:`~fastNLP.BatchIter`
:param nn.modules model: the model to train
:param optimizer: a `torch.optim.Optimizer` optimizer. If None, the Trainer uses the default Adam(model.parameters(), lr=4e-3) optimizer
:param int batch_size: the batch size used for training and validation.


fastNLP/core/utils.py (+8 -7)

@@ -624,9 +624,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
if check_res.unused:
_tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}."
if _tmp:
_tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
_tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. '
else:
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.'
_tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.'
if not dataset.collector.is_empty():
_tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. '
suggestions.append(_tmp)

if check_res.duplicated:
@@ -683,12 +685,11 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
else:
_miss_out_dataset.append(_miss)
if _miss_in_dataset:
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ")
if _miss_out_dataset:
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
# if check_res.unused:
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
# f"rename the field in `unused field:`."
_tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. "
if not dataset.collector.is_empty():
_tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. '
suggestions.append(_tmp)

if check_res.unused:


test/core/test_batch.py (+118 -8)

@@ -158,20 +158,130 @@ class TestCase1(unittest.TestCase):
dataset.set_input('1','2')
dataset.set_target('0','3')

fn = ConcatCollectFn()
dataset.add_collect_fn(fn, inputs=['1', '2'],
outputs=['12', 'seq_len'],
is_input=True, is_target=False)

fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
dataset.add_collect_fn(fn, name='demo')
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch:
for i in range(batch_size):
# print(i)
self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
self.assertEqual(
batch_x['seq_len'][i],
(batch_x['1'][i]!=0).sum() + (batch_x['2'][i]!=0).sum())
dataset.delete_collect_fn(name='demo')

# test the case where the fields are not input
dataset.set_input('1', '2', flag=False)
fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
dataset.add_collect_fn(fn, name='demo')
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch:
for i in range(batch_size):
self.assertTrue('12' in batch_x)
dataset.delete_collect_fn(name='demo')
dataset.set_input('1', '2', flag=True)

# test the case where another field is overridden
fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
dataset.add_collect_fn(fn, name='demo')
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch:
for i in range(batch_size):
# print(i)
self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
dataset.delete_collect_fn(name='demo')

# test the case where the fields are neither input nor target
dataset.set_input('1', '2', flag=False)
fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
dataset.add_collect_fn(fn, name='demo')
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch:
for i in range(batch_size):
# print(i)
self.assertTrue('3' in batch_x)
self.assertTrue('3' in batch_y)
dataset.delete_collect_fn(name='demo')

# test adding an invalid fn
with self.assertRaises(AssertionError):
dataset.add_collect_fn(1)

# test the case where the collect_fn returns only a single value
def demo_collect_fn(ins_list):
return {'3':1}
dataset.add_collect_fn(demo_collect_fn, name='demo')
with self.assertRaises(BaseException):
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch:
pass
dataset.delete_collect_fn(name='demo')

# test multiple collect_fn
dataset.add_collect_fn(demo_collect_fn, name='demo')
dataset.add_collect_fn(demo_collect_fn, name='demo')
# test deletion
dataset.delete_collect_fn()
dataset.delete_collect_fn()
self.assertTrue(dataset.collector.is_empty())

def test_demo(self):
import torch

data = DataSet({
'x1': [[0, 1],
[2]],
'x2': [[3],
[2, 4, 5]
],
'y': [0, 1]
})
data.set_target('y')

# Every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2 is the index of
# the instance in the dataset and instance1/instance2 is the data drawn for this batch, containing all fields.
def concat_collect_fn(ins_list):
x1 = [ins['x1'] for ind,ins in ins_list]
x2 = [ins['x2'] for ind,ins in ins_list]
xs = []
for i in range(len(ins_list)):
xs.append(torch.LongTensor(x1[i] + x2[i]))
# You have to pad and convert to tensor yourself, but there is no need to move the data to the gpu
arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
b_x = {'x': arr}
b_y = {}
# The return value must be two dicts: the values of the first dict are treated as input, those of the second as
# target. If a name clashes with an existing input or target, the returned value takes precedence.
return b_x, b_y

data.add_collect_fn(concat_collect_fn)

for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
print("batch_x:", batch_x)
print("batch_y:", batch_y)
# batch_x: {'x': tensor([[0, 1, 3, 0],
# [2, 2, 4, 5]])}
# batch_y: {'y': array([0, 1])}

# If batching needs some parameters, implement the collect_fn as a class
class ConCollectFn:
def __init__(self, max_len=3):
self.max_len = max_len
def __call__(self, ins_list):
x1 = [ins['x1'] for ind, ins in ins_list]
x2 = [ins['x2'] for ind, ins in ins_list]
xs = []
for i in range(len(ins_list)):
xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
b_x = {'x': arr}
b_y = {}
return b_x, b_y
data.delete_collect_fn()  # delete the previous collect_fn
data.add_collect_fn(ConCollectFn(max_len=3))
for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
print("batch_x:", batch_x)
print("batch_y:", batch_y)
# batch_x: {'x': tensor([[0, 1, 3],
# [2, 2, 4]])}
# batch_y: {'y': array([0, 1])}

def testTensorLoaderIter(self):
class FakeData:

