
Rework the implementation of batching that involves multiple fields

tags/v0.5.5
yh_cc committed 5 years ago · parent 7f01971321
9 changed files with 354 additions and 205 deletions
  1. +1   -2    fastNLP/core/__init__.py
  2. +45  -41   fastNLP/core/batch.py
  3. +76  -116  fastNLP/core/collect_fn.py
  4. +90  -22   fastNLP/core/dataset.py
  5. +14  -7    fastNLP/core/field.py
  6. +1   -1    fastNLP/core/tester.py
  7. +1   -1    fastNLP/core/trainer.py
  8. +8   -7    fastNLP/core/utils.py
  9. +118 -8    test/core/test_batch.py

+1 -2  fastNLP/core/__init__.py

@@ -21,7 +21,6 @@ __all__ = [
     "AutoPadder",
     "EngChar2DPadder",

-    "CollectFn",
     "ConcatCollectFn",
     "Vocabulary",
@@ -97,4 +96,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
-from .collect_fn import CollectFn, ConcatCollectFn
+from .collect_fn import ConcatCollectFn

+45 -41  fastNLP/core/batch.py

@@ -9,17 +9,16 @@ __all__ = [
 ]

 import atexit
-from numbers import Number
+import abc

+from numbers import Number
 import numpy as np
 import torch
 import torch.utils.data
 from collections import defaultdict

-from ._logger import logger
 from .dataset import DataSet
 from .sampler import SequentialSampler
-from .field import _get_ele_type_and_dim

 _python_is_exit = False

@@ -33,6 +32,9 @@ atexit.register(_set_python_is_exit)


 class DataSetGetter:
+    """
+    Passed to torch.utils.data.DataLoader to fetch data; the DataLoader passes in an int idx to fetch each sample
+    (by calling __getitem__() here).
+    """
     def __init__(self, dataset: DataSet, as_numpy=False):
         self.dataset = dataset
         self.as_numpy = as_numpy
@@ -56,7 +58,6 @@ class DataSetGetter:
         :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
         :return:
         """
-        # TODO support defining a collate_fn on the DataSet itself, since fields sometimes need to be fused, e.g. for BERT
         indices = []
         sin_x, sin_y = defaultdict(list), defaultdict(list)
         for idx, ins in ins_list:
@@ -67,24 +68,6 @@ class DataSetGetter:
             if n in self.y_names:
                 sin_y[n].append(v)

-        def may_to_tensor(data):
-            dtype, dim = _get_ele_type_and_dim(data)
-            if not self.as_numpy:
-                try:
-                    data, flag = _to_tensor(data, dtype)
-                except TypeError as e:
-                    logger.error(f"Field {n} cannot be converted to torch.tensor.")
-                    raise e
-            return data
-
         def pad(batch_dict):
             result = {}
             for n, vlist in batch_dict.items():
@@ -98,25 +81,13 @@ class DataSetGetter:
         sin_x = pad(sin_x)
         sin_y = pad(sin_y)

-        bx, by = self.dataset._collect_batch(ins_list)
-
-        def convert_tensor(batch_dict):
-            for n, v in batch_dict.items():
-                batch_dict[n] = may_to_tensor(v)
-
-        # collect_fn replaces single field
-        sin_x.update(bx)
-        sin_y.update(by)
-
-        convert_tensor(sin_x)
-        convert_tensor(sin_y)
+        if not self.dataset.collector.is_empty():
+            bx, by = self.dataset._collect_batch(ins_list)
+            sin_x.update(bx)
+            sin_y.update(by)

         return (indices, sin_x, sin_y)

-    def set_idx_list(self, idx_list):
-        if len(idx_list) != len(self.idx_list):
-            raise ValueError
-        self.idx_list = idx_list
-
     def __getattr__(self, item):
         if hasattr(self.dataset, item):
             return getattr(self.dataset, item)
@@ -125,6 +96,10 @@ class DataSetGetter:


 class SamplerAdapter(torch.utils.data.Sampler):
+    """
+    Passed into torch.utils.data.DataLoader; the DataLoader calls __iter__() to fetch indices (one int at a time).
+    """
     def __init__(self, sampler, dataset):
         super().__init__(dataset)
         self.sampler = sampler
@@ -138,6 +113,11 @@ class SamplerAdapter(torch.utils.data.Sampler):


 class BatchIter:
+    """
+    The class Trainer uses to iterate over data. Subclass it and implement the get_num_batches(), get_batch_indices(),
+    dataset(), num_batches() and __iter__() methods.
+    """
     def __init__(self, dataset, batch_size=1, sampler=None,
                  num_workers=0, pin_memory=False, drop_last=False,
                  timeout=0, worker_init_fn=None, collate_fn=None):
@@ -145,6 +125,8 @@ class BatchIter:
             self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
         else:
             self.sampler = sampler
+
+        # the DataLoader's collate_fn receives a List[] whose elements are what dataset[index] returns
         if collate_fn is None:
             # collate_fn=None cannot be set in pytorch <= 1.1
             self.dataiter = torch.utils.data.DataLoader(
@@ -160,17 +142,25 @@ class BatchIter:
                 timeout=timeout, worker_init_fn=worker_init_fn)

         # trust the sampler's length: with a DistributedSampler, each process only sees part of the data
-        self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
+        self._num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
         self.batch_size = batch_size
         self.cur_batch_indices = None

+    @property
+    def num_batches(self):
+        return self._num_batches
+
+    @num_batches.setter
+    def num_batches(self, value):
+        self._num_batches = value
+
     def init_iter(self):
         pass

     @staticmethod
     def get_num_batches(num_samples, batch_size, drop_last):
         """
-        Compute the number of batches.
+        Compute the number of batches. Used by the front end to display progress.

         :param int num_samples:
         :param int batch_size:
@@ -184,7 +174,7 @@ class BatchIter:

     def get_batch_indices(self):
         """
-        Get the indices of the batch that was already output.
+        Get the indices of the most recently output batch, so the data in the current batch can be traced back.

         :return:
         """
@@ -195,8 +185,22 @@ class BatchIter:

     @property
     def dataset(self):
+        """
+        Get the dataset currently being iterated.
+
+        :return:
+        """
         return self.dataiter.dataset

+    @abc.abstractmethod
+    def __iter__(self):
+        """
+        Performs the actual iteration over the data; must yield two dicts, the first treated as input and the
+        second as target.
+
+        :return:
+        """
+        raise NotImplementedError
+

 class DataSetIter(BatchIter):
     """

+76 -116  fastNLP/core/collect_fn.py

@@ -4,7 +4,8 @@ from builtins import sorted
 import torch
 import numpy as np
 from .field import _get_ele_type_and_dim
-from collections import defaultdict
+from .utils import logger
+from copy import deepcopy


 def _check_type(batch_dict, fields):
@@ -36,127 +37,89 @@ def batching(samples, max_len=0, padding_val=0):


 class Collector:
+    """
+    Helper class for DataSet to manage its collect_fn functions.
+    """
     def __init__(self):
-        self.fns = {}
-        self.input2fn = defaultdict(list)
-        self.output2fn = defaultdict(list)
-        self.fn2input = {}
-        self.fn2output = {}
-
-    def add_fn(self, fn, inputs, outputs, is_input, is_target):
-        for name in outputs:
-            if name in self.output2fn:
-                raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
-
-        if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs():
-            raise ValueError(
-                "Incorrect num of inputs, should be {} not {}".format(
-                    fn.num_inputs(), len(inputs)
-                ))
-
-        if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs():
-            raise ValueError("Incorrect num of outputs, should be {} not {}".format(
-                fn.num_outputs(), len(outputs)))
-
-        self.fns[fn] = {'is_input': is_input, 'is_target': is_target}
-        for i, field in enumerate(inputs):
-            self.input2fn[field].append((fn, i))
-        for i, name in enumerate(outputs):
-            self.output2fn[name].append((fn, i))
-
-    def _rebuild_fn2io(self):
-        def transpose(name2fn):
-            fn2names = defaultdict(list)
-            for name, vlist in name2fn.items():
-                for fn, i in vlist:
-                    fn2names[fn].append((name, i))
-            for fn, vlist in fn2names.items():
-                vlist = sorted(vlist, key=lambda x: x[1])
-                fn2names[fn] = [name for name, i in vlist]
-            return fn2names
-
-        self.fn2input = transpose(self.input2fn)
-        self.fn2output = transpose(self.output2fn)
-
-    def _clear_fn2io(self):
-        self.fn2input.clear()
-        self.fn2output.clear()
+        self.collect_fns = {}
+
+    def add_fn(self, fn, name=None):
+        """
+        Add a collect_fn to the collector.
+
+        :param callable fn:
+        :param str,int name:
+        :return:
+        """
+        if name in self.collect_fns:
+            logger.warn(f"collect_fn:{name} will be overwritten.")
+        if name is None:
+            name = len(self.collect_fns)
+        self.collect_fns[name] = fn
+
+    def is_empty(self):
+        """
+        Return whether this collector holds any collect_fn.
+
+        :return:
+        """
+        return len(self.collect_fns) == 0
+
+    def delete_fn(self, name=None):
+        """
+        Delete a collect_fn.
+
+        :param str,int name: if None, delete the most recently added collect_fn
+        :return:
+        """
+        if not self.is_empty():
+            if name in self.collect_fns:
+                self.collect_fns.pop(name)
+            elif name is None:
+                last_key = list(self.collect_fns.keys())[-1]
+                self.collect_fns.pop(last_key)

     def collect_batch(self, ins_list):
-        if len(ins_list) == 0:
-            return {}, {}
-
-        if len(self.fn2output) == 0:
-            self._rebuild_fn2io()
-
-        bx = {}
-        by = {}
-        for fn, attr in self.fns.items():
-            inputs = self.fn2input.get(fn, None)
-            outputs = self.fn2output.get(fn, None)
-            res = fn.collect(ins_list, inputs, outputs)
-            if attr.get('is_input', False):
-                bx.update(res)
-            if attr.get('is_target', False):
-                by.update(res)
+        bx, by = {}, {}
+        for name, fn in self.collect_fns.items():
+            try:
+                batch_x, batch_y = fn(ins_list)
+            except BaseException as e:
+                logger.error(f"Exception:`{e}` happens when call collect_fn:`{name}`.")
+                raise e
+            bx.update(batch_x)
+            by.update(batch_y)
         return bx, by

-    def rename_field(self, old_f, new_f):
-        if new_f in self.input2fn:
-            # name conflict
-            raise ValueError
-        if old_f not in self.input2fn:
-            # the renamed field does not affect collectors
-            return
-        self.input2fn[new_f] = self.input2fn[old_f]
-        self._clear_fn2io()
-
-    def drop_field(self, f):
-        if f in self.input2fn:
-            raise ValueError
-
-    def outputs(self):
-        return self.output2fn.keys()
-
     def copy_from(self, col):
         assert isinstance(col, Collector)
-        self.fns = col.fns.copy()
-        self.input2fn = col.input2fn.copy()
-        self.output2fn = col.output2fn.copy()
-        self._clear_fn2io()
-
-
-class CollectFn:
-    def __init__(self):
-        self.fields = []
-
-    def collect(self, ins_list, inputs, outputs):
-        raise NotImplementedError
-
-    def num_inputs(self):
-        return 0
-
-    def num_outputs(self):
-        return 0
-
-    @staticmethod
-    def get_batch_size(batch_dict):
-        if len(batch_dict) == 0:
-            return 0
-        return len(next(iter(batch_dict.values())))
+        new_col = Collector()
+        new_col.collect_fns = deepcopy(col.collect_fns)
+        return new_col


-class ConcatCollectFn(CollectFn):
+class ConcatCollectFn:
     """
-    A concatenation Fn: the given fields are concatenated in order, then padded, to produce the data. All fields must share the same dim.
+    A concatenation collect_fn: the given fields are concatenated in order, then padded, to produce the data.

+    :param List[str] inputs: the fields whose data is concatenated; currently only 1d fields are supported
+    :param str output: the name of the field produced by the concatenation
     :param pad_val: the value used for padding
     :param max_len: the maximum length after concatenation
+    :param is_input: whether the generated output is set as input
+    :param is_target: whether the generated output is set as target
     """

-    def __init__(self, pad_val=0, max_len=0):
+    def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False):
         super().__init__()
+        assert isinstance(inputs, list)
+        self.inputs = inputs
+        self.output = output
         self.pad_val = pad_val
         self.max_len = max_len
+        self.is_input = is_input
+        self.is_target = is_target

     @staticmethod
     def _to_numpy(seq):
@@ -165,21 +128,18 @@ class ConcatCollectFn(CollectFn):
         else:
             return np.array(seq)

-    def collect(self, ins_list, inputs, outputs):
+    def __call__(self, ins_list):
         samples = []
         for i, ins in ins_list:
             sample = []
-            for i in inputs:
-                sample.append(self._to_numpy(ins[i]))
+            for input_name in self.inputs:
+                sample.append(self._to_numpy(ins[input_name]))
             samples.append(np.concatenate(sample, axis=0))
-        seq_len = [s.shape[0] for s in samples]
         batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
-        o1, o2 = outputs
-        return {o1: batch, o2: seq_len}
-
-    def num_inputs(self):
-        return 0
-
-    def num_outputs(self):
-        # (concat_words, seq_len)
-        return 2
+        b_x, b_y = {}, {}
+        if self.is_input:
+            b_x[self.output] = batch
+        if self.is_target:
+            b_y[self.output] = batch
+
+        return b_x, b_y
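
For reference, a minimal usage sketch of the reworked ConcatCollectFn, assuming fastNLP at this commit (the field names below are made up):

.. code-block:: python

    from fastNLP import DataSet, DataSetIter
    from fastNLP.core.collect_fn import ConcatCollectFn

    ds = DataSet({'words1': [[1, 2], [3]], 'words2': [[4], [5, 6]]})
    ds.set_input('words1', 'words2')
    # concatenate words1 and words2 per instance; expose the result as the input field 'words'
    ds.add_collect_fn(ConcatCollectFn(inputs=['words1', 'words2'], output='words'), name='concat')
    for batch_x, batch_y in DataSetIter(ds, batch_size=2):
        print(batch_x['words'])  # each row is the concatenation, padded with pad_val=0
    ds.delete_collect_fn(name='concat')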

+90 -22  fastNLP/core/dataset.py

@@ -281,6 +281,75 @@
     # pad value can also be set
     dataset.set_pad_val('chars', -1)

+3.3 Building a new field from several fields of a DataSet
+---------------------------------------------------------
+
+    When a batch is built, a DataSet by default only sees each field on its own, but some training setups need more:
+    (1) two fields must be concatenated into one; (2) negative sampling must happen inside the batch. Such cases
+    require using several fields at once when batching, and DataSet's add_collect_fn() supports registering a custom
+    collect_fn over multiple fields. The example below concatenates two fields into one:
+
+    .. code-block::
+
+        from fastNLP import DataSet, DataSetIter, SequentialSampler
+        import torch
+
+        data = DataSet({
+            'x1': [[0, 1],
+                   [2]],
+            'x2': [[3],
+                   [2, 4, 5]],
+            'y': [0, 1]
+        })
+        data.set_target('y')
+
+        # Every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2 is the
+        # index of that instance in the dataset and instance1/instance2 is the data drawn for this batch, with all fields.
+        def concat_collect_fn(ins_list):
+            x1 = [ins['x1'] for ind, ins in ins_list]
+            x2 = [ins['x2'] for ind, ins in ins_list]
+            xs = []
+            for i in range(len(ins_list)):
+                xs.append(torch.LongTensor(x1[i] + x2[i]))
+            # pad and convert to tensor yourself; there is no need to move it to the gpu
+            arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
+            b_x = {'x': arr}
+            b_y = {}
+            # the return value must be two dicts: values in the first are treated as input, values in the second as
+            # target. If a name clashes with an existing input or target, the returned value wins.
+            return b_x, b_y
+
+        data.add_collect_fn(concat_collect_fn)
+
+        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+            print("batch_x:", batch_x)
+            print("batch_y:", batch_y)
+        # batch_x: {'x': tensor([[0, 1, 3, 0],
+        #                        [2, 2, 4, 5]])}
+        # batch_y: {'y': array([0, 1])}
+
+        # if batching needs extra parameters, implement the collect_fn as a class
+        class ConCollectFn:
+            def __init__(self, max_len=3):
+                self.max_len = max_len
+
+            def __call__(self, ins_list):  # implement the class's __call__
+                x1 = [ins['x1'] for ind, ins in ins_list]
+                x2 = [ins['x2'] for ind, ins in ins_list]
+                xs = []
+                for i in range(len(ins_list)):
+                    xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
+                arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
+                b_x = {'x': arr}
+                b_y = {}
+                return b_x, b_y
+
+        data.delete_collect_fn()  # delete the previous collect_fn
+        data.add_collect_fn(ConCollectFn(max_len=3))
+        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+            print("batch_x:", batch_x)
+            print("batch_y:", batch_y)
+        # batch_x: {'x': tensor([[0, 1, 3],
+        #                        [2, 2, 4]])}
+        # batch_y: {'y': array([0, 1])}
+

 """
 __all__ = [
@@ -300,7 +369,6 @@ from .field import AutoPadder
 from .field import FieldArray
 from .field import SetInputOrTargetException
 from .instance import Instance
-from .utils import _get_func_signature
 from .utils import pretty_table_printer
 from .collect_fn import Collector

@@ -394,6 +462,7 @@ class DataSet(object):
             for field in self.field_arrays.values():
                 data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder,
                                    is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
+            data_set.collector = self.collector.copy_from(self.collector)
             return data_set
         elif isinstance(idx, str):
             if idx not in self:
@@ -407,6 +476,7 @@ class DataSet(object):
                 dataset.append(instance)
             for field_name, field in self.field_arrays.items():
                 dataset.field_arrays[field_name].to(field)
+            dataset.collector = self.collector.copy_from(self.collector)
             return dataset
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -575,7 +645,6 @@ class DataSet(object):
         :param str field_name: the name of the field to delete.
         """
         self.field_arrays.pop(field_name)
-        self.collector.drop_field(field_name)
         return self

     def copy_field(self, field_name, new_field_name):
@@ -648,7 +717,6 @@ class DataSet(object):
         if field_name in self.field_arrays:
             self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
             self.field_arrays[new_field_name].name = new_field_name
-            self.collector.rename_field(field_name, new_field_name)
         else:
             raise KeyError("DataSet has no field named {}.".format(field_name))
         return self
@@ -1040,30 +1108,30 @@ class DataSet(object):
         assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d

-    def add_collect_fn(self, fn, inputs, outputs, is_input, is_target):
+    def add_collect_fn(self, fn, name=None):
         """
-        Add a CollectFn that produces data in the batch from several fields.
+        Add a collect_fn. A collect_fn can generate extra data on the fly while a batch is being built (effective when
+        DataSetIter is used as the iterator, which is the default case). Several collect_fns can be added in turn; for
+        identical keys, a later collect_fn's result overrides an earlier one's.

-        :param CollectFn fn: defines how the data is produced
-        :param list inputs: the fields used to produce the data, in order
-        :param list outputs: the names of the generated data in the batch
-        :param bool is_input: whether the result appears in the input batch; otherwise it appears in the target batch
-        :param bool is_target:
+        :param callable fn: a callable taking List[(ind1, instance1), (ind2, instance2)] (all indices and instances
+            selected for a batch), where ind1/ind2 is the instance's index in the dataset and instance1/instance2 is
+            the data drawn for this batch, containing all fields. It must return two dicts: values in the first dict
+            are treated as input, values in the second as target; at most one of the two dicts may be empty. If a
+            returned dict contains the name of a field that is set as input or target, it overrides the dataset's
+            field. fastNLP does not pad the collect_fn's results or convert them to tensors, so do both inside the
+            collect_fn (there is no need to move tensors to the gpu: pytorch tensors are moved to the proper gpu
+            automatically). Do not modify the data passed into the collect_fn, or unexpected problems may follow.
+        :param str,int name: the collect_fn's name; if not given, an auto-incremented number is used as the key. A
+            repeated name overrides the previous collect_fn.
         """
-        def check_fields(fields):
-            for f in fields:
-                if f not in self.field_arrays:
-                    raise ValueError(f)
-
-        def check_name(names):
-            for name in names:
-                if name in self.field_arrays:
-                    logger.warning('name of collect_fn will cover the field name in dataset')
-
-        check_fields(inputs)
-        check_name(outputs)
-
-        self.collector.add_fn(fn, inputs, outputs, is_input, is_target)
+        assert callable(fn), "You must pass in a callable object."
+        self.collector.add_fn(fn, name=name)
+
+    def delete_collect_fn(self, name=None):
+        """
+        Delete a collect_fn.
+
+        :param str,int name: if None, delete the most recently added collect_fn
+        :return:
+        """
+        self.collector.delete_fn(name)

     def _collect_batch(self, ins_list):
         return self.collector.collect_batch(ins_list)
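
Since later collect_fns override earlier ones on identical keys, registration order matters. A tiny sketch of that rule, assuming fastNLP at this commit (the key 'flag' is made up, and the private `_collect_batch` is called directly only for illustration):

.. code-block:: python

    from fastNLP import DataSet

    data = DataSet({'x': [[1], [2]]})
    data.add_collect_fn(lambda ins_list: ({'flag': 0}, {}), name='first')
    data.add_collect_fn(lambda ins_list: ({'flag': 1}, {}), name='second')

    ins_list = [(i, data[i]) for i in range(len(data))]
    batch_x, batch_y = data._collect_batch(ins_list)
    print(batch_x['flag'])  # 1: 'second' runs after 'first' and overrides the key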

+14 -7  fastNLP/core/field.py

@@ -191,24 +191,31 @@ class FieldArray:
     def get(self, indices, pad=True):
         """
         Return the content at the given indices.

         :param int,List[int] indices: the indices whose content to fetch.
-        :param bool pad: whether to pad the returned result. Only effective when indices is List[int]
-        :return: the content at the given indices, either a single value or a List
+        :param bool pad: whether to pad the returned result. Only effective when (1) indices is List[int]; (2) padder
+            is not None; (3) the field is set as input or target
+        :return: the content at the given indices, either a single value or an ndarray
         """
         if isinstance(indices, int):
             return self.content[indices]
-        if self.is_input is False and self.is_target is False:
-            raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name))

         contents = [self.content[i] for i in indices]
         if self.padder is None or pad is False:
             return np.array(contents)
-        else:
+        elif self.is_input or self.is_target:
             return self.pad(contents)
+        else:
+            return np.array(contents)

     def pad(self, contents):
+        """
+        Pad the given list of contents with this FieldArray's padder; the contents must have been taken from this
+        FieldArray.
+
+        :param list contents:
+        :return:
+        """
         return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)

     def set_padder(self, padder):
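
A small check of the relaxed FieldArray.get() behavior, assuming fastNLP at this commit: fields that are neither input nor target no longer raise, they simply skip padding.

.. code-block:: python

    from fastNLP import DataSet

    ds = DataSet({'words': [[1, 2], [3, 4, 5]]})
    fa = ds.get_field('words')
    print(fa.get([0, 1]))  # no RuntimeError anymore; not input/target, so the raw contents come back unpadded

    ds.set_input('words')
    print(fa.get([0, 1]))  # padded by the default AutoPadder: [[1 2 0] [3 4 5]]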


+1 -1  fastNLP/core/tester.py

@@ -71,7 +71,7 @@ class Tester(object):
     def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
         """
-        :param ~fastNLP.DataSet data: the dataset to test on
+        :param ~fastNLP.DataSet,~fastNLP.BatchIter data: the dataset to test on
         :param torch.nn.Module model: the model to use
         :param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: the metrics used during testing
         :param int batch_size: the batch_size used during evaluation.


+1 -1  fastNLP/core/trainer.py

@@ -375,7 +375,7 @@ class Trainer(object):
                  callbacks=None, check_code_level=0, **kwargs):
         """
-        :param train_data: the training set, of type :class:`~fastNLP.DataSet`
+        :param train_data: the training set, of type :class:`~fastNLP.DataSet` or a subclass of :class:`~fastNLP.BatchIter`
         :param nn.modules model: the model to train
         :param optimizer: a `torch.optim.Optimizer`. If None, the Trainer uses the default Adam(model.parameters(), lr=4e-3) optimizer
         :param int batch_size: the batch size used for both training and validation.


+8 -7  fastNLP/core/utils.py

@@ -624,9 +624,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         if check_res.unused:
             _tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}."
             if _tmp:
-                _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
+                _tmp += f' Or provide `{_miss}` in DataSet or the output of {prev_func_signature}. '
             else:
-                _tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.'
+                _tmp = f'Provide `{_miss}` in DataSet or the output of {prev_func_signature}.'
+            if not dataset.collector.is_empty():
+                _tmp += f'Or you need to add `{_miss}` in the output of your collect_fn. '
             suggestions.append(_tmp)

         if check_res.duplicated:
@@ -683,12 +685,11 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
             else:
                 _miss_out_dataset.append(_miss)
         if _miss_in_dataset:
-            suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
+            suggestions.append(f"You might need to set `{_miss_in_dataset}` as input. ")
         if _miss_out_dataset:
-            _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
-            # if check_res.unused:
-            #     _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
-            #         f"rename the field in `unused field:`."
+            _tmp = f"You need to provide `{_miss_out_dataset}` in DataSet and set it as input. "
+            if not dataset.collector.is_empty():
+                _tmp += f'Or you need to add `{_miss_out_dataset}` in the output of your collect_fn. '
             suggestions.append(_tmp)

         if check_res.unused:


+118 -8  test/core/test_batch.py

@@ -158,20 +158,130 @@ class TestCase1(unittest.TestCase):
         dataset.set_input('1','2')
         dataset.set_target('0','3')

-        fn = ConcatCollectFn()
-        dataset.add_collect_fn(fn, inputs=['1', '2'],
-                               outputs=['12', 'seq_len'],
-                               is_input=True, is_target=False)
+        fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
+        dataset.add_collect_fn(fn, name='demo')
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
         for batch_x, batch_y in batch:
             for i in range(batch_size):
                 # print(i)
                 self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
-                self.assertEqual(
-                    batch_x['seq_len'][i],
-                    (batch_x['1'][i]!=0).sum() + (batch_x['2'][i]!=0).sum())
+        dataset.delete_collect_fn(name='demo')
+
+        # test the case where the source fields are not input
+        dataset.set_input('1', '2', flag=False)
+        fn = ConcatCollectFn(inputs=['1', '2'], output='12', pad_val=0, max_len=0, is_input=True, is_target=False)
+        dataset.add_collect_fn(fn, name='demo')
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                self.assertTrue('12' in batch_x)
+        dataset.delete_collect_fn(name='demo')
+        dataset.set_input('1', '2', flag=True)
+
+        # test overriding another field
+        fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
+        dataset.add_collect_fn(fn, name='demo')
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                # print(i)
+                self.assertEqual(batch_y['3'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
+        dataset.delete_collect_fn(name='demo')
+
+        # test the case where the fields are neither input nor target
+        dataset.set_input('1', '2', flag=False)
+        fn = ConcatCollectFn(inputs=['1', '2'], output='3', pad_val=0, max_len=0, is_input=True, is_target=True)
+        dataset.add_collect_fn(fn, name='demo')
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                # print(i)
+                self.assertTrue('3' in batch_x)
+                self.assertTrue('3' in batch_y)
+        dataset.delete_collect_fn(name='demo')
+
+        # test adding an illegal fn
+        with self.assertRaises(AssertionError):
+            dataset.add_collect_fn(1)
+
+        # test a collect_fn that returns only one value
+        def demo_collect_fn(ins_list):
+            return {'3': 1}
+        dataset.add_collect_fn(demo_collect_fn, name='demo')
+        with self.assertRaises(BaseException):
+            batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+            for batch_x, batch_y in batch:
+                pass
+        dataset.delete_collect_fn(name='demo')
+
+        # test multiple collect_fns
+        dataset.add_collect_fn(demo_collect_fn, name='demo')
+        dataset.add_collect_fn(demo_collect_fn, name='demo')
+        # test deletion
+        dataset.delete_collect_fn()
+        dataset.delete_collect_fn()
+        self.assertTrue(dataset.collector.is_empty())
+
+    def test_demo(self):
+        import torch
+
+        data = DataSet({
+            'x1': [[0, 1],
+                   [2]],
+            'x2': [[3],
+                   [2, 4, 5]],
+            'y': [0, 1]
+        })
+        data.set_target('y')
+
+        # every collect_fn receives list[(ind1, instance1), (ind2, instance2), ...] as input, where ind1/ind2 is the
+        # index of that instance in the dataset and instance1/instance2 is the data drawn for this batch, with all fields.
+        def concat_collect_fn(ins_list):
+            x1 = [ins['x1'] for ind, ins in ins_list]
+            x2 = [ins['x2'] for ind, ins in ins_list]
+            xs = []
+            for i in range(len(ins_list)):
+                xs.append(torch.LongTensor(x1[i] + x2[i]))
+            # pad and convert to tensor yourself; there is no need to move it to the gpu
+            arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
+            b_x = {'x': arr}
+            b_y = {}
+            # the return value must be two dicts: values in the first are treated as input, values in the second as
+            # target. If a name clashes with an existing input or target, the returned value wins.
+            return b_x, b_y
+
+        data.add_collect_fn(concat_collect_fn)
+
+        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+            print("batch_x:", batch_x)
+            print("batch_y:", batch_y)
+            # batch_x: {'x': tensor([[0, 1, 3, 0],
+            #                        [2, 2, 4, 5]])}
+            # batch_y: {'y': array([0, 1])}
+
+        # if batching needs extra parameters, implement the collect_fn as a class
+        class ConCollectFn:
+            def __init__(self, max_len=3):
+                self.max_len = max_len
+
+            def __call__(self, ins_list):
+                x1 = [ins['x1'] for ind, ins in ins_list]
+                x2 = [ins['x2'] for ind, ins in ins_list]
+                xs = []
+                for i in range(len(ins_list)):
+                    xs.append(torch.LongTensor(x1[i] + x2[i])[:self.max_len])
+                arr = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
+                b_x = {'x': arr}
+                b_y = {}
+                return b_x, b_y
+
+        data.delete_collect_fn()  # delete the previous collect_fn
+        data.add_collect_fn(ConCollectFn(max_len=3))
+        for batch_x, batch_y in DataSetIter(data, sampler=SequentialSampler(), batch_size=2):
+            print("batch_x:", batch_x)
+            print("batch_y:", batch_y)
+            # batch_x: {'x': tensor([[0, 1, 3],
+            #                        [2, 2, 4]])}
+            # batch_y: {'y': array([0, 1])}

     def testTensorLoaderIter(self):
         class FakeData:
