
[add] collect_fn for multi field combined inputs

yunfan committed 5 years ago · commit 127d6d194b · tag v0.5.5
6 changed files with 204 additions and 11 deletions:

  1. fastNLP/__init__.py         (+3, -0)
  2. fastNLP/core/__init__.py    (+4, -0)
  3. fastNLP/core/batch.py       (+35, -11)
  4. fastNLP/core/collect_fn.py  (+118, -0)
  5. fastNLP/core/dataset.py     (+28, -0)
  6. test/core/test_batch.py     (+16, -0)

fastNLP/__init__.py  (+3, -0)

@@ -44,6 +44,9 @@ __all__ = [
     "AutoPadder",
     "EngChar2DPadder",
 
+    "CollectFn",
+    "ConcatCollectFn",
+
     "MetricBase",
     "AccuracyMetric",
     "SpanFPreRecMetric",


fastNLP/core/__init__.py  (+4, -0)

@@ -20,6 +20,9 @@ __all__ = [
     "Padder",
     "AutoPadder",
     "EngChar2DPadder",
+
+    "CollectFn",
+    "ConcatCollectFn",
     "Vocabulary",
@@ -94,3 +97,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
+from .collect_fn import CollectFn, ConcatCollectFn

fastNLP/core/batch.py  (+35, -11)

@@ -65,25 +65,49 @@ class DataSetGetter:
             for n, v in y.items():
                 batch_y[n].append(v)
 
+        def may_to_tensor(data):
+            if not self.as_numpy:
+                try:
+                    data, flag = _to_tensor(data, data.dtype)
+                except TypeError as e:
+                    logger.error(f"Field {n} cannot be converted to torch.tensor.")
+                    raise e
+            return data
+
+        def pad_collect(batch_dict):
+            batch_x, batch_y = self.dataset._collect_batch(batch_dict)
+            for b in [batch_x, batch_y]:
+                for n in b.keys():
+                    b[n] = may_to_tensor(b[n])
+            return batch_x, batch_y
+
         def pad_batch(batch_dict, field_array):
+            result = {}
             for n, vlist in batch_dict.items():
                 f = field_array[n]
                 if f.padder is None:
-                    batch_dict[n] = np.array(vlist)
+                    result[n] = np.array(vlist)
                 else:
                     data = f.pad(vlist)
-                    if not self.as_numpy:
-                        try:
-                            data, flag = _to_tensor(data, f.dtype)
-                        except TypeError as e:
-                            logger.error(f"Field {n} cannot be converted to torch.tensor.")
-                            raise e
-                    batch_dict[n] = data
-            return batch_dict
+                    result[n] = may_to_tensor(data)
+            return result
+
+        # pad each field with its own padder
+        pad_batch_x = pad_batch(batch_x, self.inputs)
+        pad_batch_y = pad_batch(batch_y, self.targets)
+
+        # build combined entries via the dataset's collect_fns
+        batch_dict = batch_x.copy()
+        batch_dict.update(batch_y)
+        pad_dict_x, pad_dict_y = pad_collect(batch_dict)
+
+        # merge the two; collect_fn entries are applied last
+        pad_batch_x.update(pad_dict_x)
+        pad_batch_y.update(pad_dict_y)
 
         return (indices,
-                pad_batch(batch_x, self.inputs),
-                pad_batch(batch_y, self.targets))
+                pad_batch_x,
+                pad_batch_y)
 
     def set_idx_list(self, idx_list):
         if len(idx_list) != len(self.idx_list):

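In short, this hunk pads every field with its own padder first, then merges in whatever the dataset's collect_fns produce. A standalone sketch of that merge step, with toy values and hypothetical names ('words', 'pair'):

    # both sides are plain dicts; dict.update applies the collector
    # entries last, so a collect_fn named like an existing field
    # overrides that field in the final batch
    pad_batch_x = {'words': [[1, 2, 0], [3, 4, 5]]}      # from FieldArray padders
    pad_dict_x = {'pair': [[1, 2, 9, 0], [3, 4, 5, 9]]}  # from dataset._collect_batch
    pad_batch_x.update(pad_dict_x)
    assert sorted(pad_batch_x) == ['pair', 'words']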

fastNLP/core/collect_fn.py  (+118, -0)

@@ -0,0 +1,118 @@
+import torch
+import numpy as np
+from .field import _get_ele_type_and_dim
+
+
+def _check_type(batch_dict, fields):
+    if len(fields) == 0:
+        raise RuntimeError("no fields given to collect")
+    types = []
+    dims = []
+    for f in fields:
+        t, d = _get_ele_type_and_dim(batch_dict[f])
+        types.append(t)
+        dims.append(d)
+    diff_types = set(types)
+    diff_dims = set(dims)
+    if len(diff_types) > 1 or len(diff_dims) > 1:
+        raise ValueError("fields must share the same type and dim")
+    return types[0]
+
+
+def batching(samples, max_len=0, padding_val=0):
+    if len(samples) == 0:
+        return samples
+    if max_len <= 0:
+        max_len = max(s.shape[0] for s in samples)
+    batch = np.full((len(samples), max_len), fill_value=padding_val)
+    for i, s in enumerate(samples):
+        slen = min(s.shape[0], max_len)
+        batch[i][:slen] = s[:slen]
+    return batch
+
+
+class Collector:
+    def __init__(self):
+        self.fns = []
+        self.names = []
+        self.fields_list = []
+        self.is_input = []
+
+    def add_fn(self, fn, name, fields, is_input):
+        if name in self.names:
+            raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
+        if fn.num_fields() > 0 and len(fields) != fn.num_fields():
+            raise ValueError(
+                "Incorrect num of fields, should be {} not {}".format(
+                    fn.num_fields(), len(fields)
+                ))
+
+        self.fns.append(fn)
+        self.names.append(name)
+        self.fields_list.append(fields)
+        self.is_input.append(is_input)
+
+    def collect_batch(self, batch_dict):
+        if len(batch_dict) == 0:
+            return {}, {}
+        batch_x, batch_y = {}, {}
+        for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input):
+            batch = fn.collect(batch_dict, fields)
+            if is_input:
+                batch_x[name] = batch
+            else:
+                batch_y[name] = batch
+        return batch_x, batch_y
+
+
+class CollectFn:
+    def __init__(self):
+        self.fields = []
+
+    def collect(self, batch_dict, fields):
+        raise NotImplementedError
+
+    def num_fields(self):
+        return 0  # 0 means the fn accepts any number of fields
+
+    @staticmethod
+    def get_batch_size(batch_dict):
+        if len(batch_dict) == 0:
+            return 0
+        return len(next(iter(batch_dict.values())))
+
+
+class ConcatCollectFn(CollectFn):
+    """
+    Concatenates the given fields in order, then pads the result into one batch. All fields must share the same dim.
+
+    :param pad_val: value used for padding
+    :param max_len: maximum length after concatenation (0 means no limit)
+    """
+
+    def __init__(self, pad_val=0, max_len=0):
+        super().__init__()
+        self.pad_val = pad_val
+        self.max_len = max_len
+
+    def collect(self, batch_dict, fields):
+        samples = []
+        dtype = _check_type(batch_dict, fields)
+        batch_size = self.get_batch_size(batch_dict)
+        for i in range(batch_size):
+            sample = []
+            for n in fields:
+                seq = batch_dict[n][i]
+                if str(dtype).startswith('torch'):
+                    seq = seq.numpy()
+                else:
+                    seq = np.array(seq, dtype=dtype)
+                sample.append(seq)
+            samples.append(np.concatenate(sample, axis=0))
+        batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
+        if str(dtype).startswith('torch'):
+            batch = torch.tensor(batch, dtype=dtype)
+        return batch
+
+    def num_fields(self):
+        return 0  # any number of fields may be concatenated

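A minimal sketch of ConcatCollectFn used directly. The batch_dict convention (field name mapped to a list of per-sample values) mirrors what DataSetGetter assembles in batch.py; the field names 'words1'/'words2' are made up for the example:

    import numpy as np
    from fastNLP.core.collect_fn import ConcatCollectFn

    # two list fields, two samples each
    batch_dict = {
        'words1': [[1, 2], [3]],
        'words2': [[4], [5, 6, 7]],
    }
    fn = ConcatCollectFn(pad_val=0)
    batch = fn.collect(batch_dict, fields=['words1', 'words2'])
    print(batch)
    # [[1 2 4 0]     each row is words1[i] + words2[i], padded to the
    #  [3 5 6 7]]    longest concatenated sample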
fastNLP/core/dataset.py  (+28, -0)

@@ -302,6 +302,7 @@ from .field import SetInputOrTargetException
 from .instance import Instance
 from .utils import _get_func_signature
 from .utils import pretty_table_printer
+from .collect_fn import Collector
 
 
 class DataSet(object):
@@ -331,6 +332,7 @@ class DataSet(object):
 
         else:
             raise ValueError("data only be dict or list type.")
+        self.collector = Collector()
 
     def __contains__(self, item):
         return item in self.field_arrays
@@ -954,3 +956,29 @@ class DataSet(object):
             d = pickle.load(f)
             assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d
+
+    def add_collect_fn(self, fn, name, fields, is_input=True):
+        """
+        Add a CollectFn that produces one batch entry from multiple fields.
+
+        :param CollectFn fn: defines how the combined data is produced
+        :param str name: name of the produced entry in the batch
+        :param list fields: fields used to produce the data, in order
+        :param bool is_input: if True the entry goes into the input batch, otherwise into the target batch
+        """
+        def check_fields(fields):
+            for f in fields:
+                if f not in self.field_arrays:
+                    raise ValueError("field {} not found in dataset".format(f))
+
+        def check_name(name):
+            if name in self.field_arrays:
+                logger.warning('name of collect_fn will shadow the field of the same name in the dataset')
+
+        check_fields(fields)
+        check_name(name)
+
+        self.collector.add_fn(fn, name, fields, is_input)
+
+    def _collect_batch(self, batch_dict):
+        return self.collector.collect_batch(batch_dict)

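An end-to-end sketch of the new DataSet API, assuming toy field names 'premise'/'hypothesis' and the hypothetical entry name 'pair':

    from fastNLP import DataSet, Instance, DataSetIter, SequentialSampler, ConcatCollectFn

    ds = DataSet()
    ds.append(Instance(premise=[1, 2, 3], hypothesis=[4, 5]))
    ds.append(Instance(premise=[6], hypothesis=[7, 8]))
    ds.set_input('premise', 'hypothesis')

    # 'pair' is an arbitrary name for the combined entry
    ds.add_collect_fn(ConcatCollectFn(pad_val=0), name='pair',
                      fields=['premise', 'hypothesis'], is_input=True)

    for batch_x, batch_y in DataSetIter(ds, batch_size=2, sampler=SequentialSampler()):
        print(batch_x['pair'])
        # tensor([[1, 2, 3, 4, 5],
        #         [6, 7, 8, 0, 0]])

batch_x also keeps the per-field entries 'premise' and 'hypothesis', each padded by its own padder, alongside the combined 'pair' entry.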
test/core/test_batch.py  (+16, -0)

@@ -7,6 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SequentialSampler
+from fastNLP import ConcatCollectFn
 
 
 def generate_fake_dataset(num_samples=1000):
@@ -150,6 +151,21 @@ class TestCase1(unittest.TestCase):
         for batch_x, batch_y in batch:
             pass
 
+    def test_collect_fn(self):
+        batch_size = 32
+        num_samples = 1000
+        dataset = generate_fake_dataset(num_samples)
+        dataset.set_input('1', '2')
+        fn = ConcatCollectFn()
+        dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True)
+
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                # the concatenated entry must preserve the values of both fields
+                self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
+
+
     def testTensorLoaderIter(self):
         class FakeData:
             def __init__(self, return_dict=True):

