
[add] collect_fn for multi-field combined inputs

tags/v0.5.5
yunfan committed 5 years ago
commit 127d6d194b
6 changed files with 204 additions and 11 deletions
  1. fastNLP/__init__.py        +3   -0
  2. fastNLP/core/__init__.py   +4   -0
  3. fastNLP/core/batch.py      +35  -11
  4. fastNLP/core/collect_fn.py +118 -0
  5. fastNLP/core/dataset.py    +28  -0
  6. test/core/test_batch.py    +16  -0
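
In short, the commit lets a DataSet register collect functions that build extra batch entries out of several fields at once. A minimal usage sketch of the new API (the dataset contents and field names here are illustrative, not from the commit; the definitions follow in the diffs below):

    from fastNLP import DataSet, ConcatCollectFn

    ds = DataSet({'premise': [[1, 2], [3]], 'hypothesis': [[4], [5, 6]]})
    ds.set_input('premise', 'hypothesis')

    # every batch_x produced by DataSetIter now also carries 'pair': the two
    # fields concatenated per instance, then padded with 0 to a common length
    ds.add_collect_fn(ConcatCollectFn(pad_val=0), 'pair',
                      fields=['premise', 'hypothesis'], is_input=True)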

fastNLP/__init__.py (+3, -0)

@@ -44,6 +44,9 @@ __all__ = [
     "AutoPadder",
     "EngChar2DPadder",

+    "CollectFn",
+    "ConcatCollectFn",
+
     "MetricBase",
     "AccuracyMetric",
     "SpanFPreRecMetric",


fastNLP/core/__init__.py (+4, -0)

@@ -20,6 +20,9 @@ __all__ = [
     "Padder",
     "AutoPadder",
     "EngChar2DPadder",

+    "CollectFn",
+    "ConcatCollectFn",
+
     "Vocabulary",
@@ -94,3 +97,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
+from .collect_fn import CollectFn, ConcatCollectFn

fastNLP/core/batch.py (+35, -11)

@@ -65,25 +65,49 @@ class DataSetGetter:
             for n, v in y.items():
                 batch_y[n].append(v)

+        def may_to_tensor(data):
+            if not self.as_numpy:
+                try:
+                    data, flag = _to_tensor(data, data.dtype)
+                except TypeError as e:
+                    logger.error(f"Field {n} cannot be converted to torch.tensor.")
+                    raise e
+            return data
+
+        def pad_collect(batch_dict):
+            batch_x, batch_y = self.dataset._collect_batch(batch_dict)
+            for b in [batch_x, batch_y]:
+                for n in b.keys():
+                    b[n] = may_to_tensor(b[n])
+            return batch_x, batch_y
+
         def pad_batch(batch_dict, field_array):
+            result = {}
             for n, vlist in batch_dict.items():
                 f = field_array[n]
                 if f.padder is None:
-                    batch_dict[n] = np.array(vlist)
+                    result[n] = np.array(vlist)
                 else:
                     data = f.pad(vlist)
-                    if not self.as_numpy:
-                        try:
-                            data, flag = _to_tensor(data, f.dtype)
-                        except TypeError as e:
-                            logger.error(f"Field {n} cannot be converted to torch.tensor.")
-                            raise e
-                    batch_dict[n] = data
-            return batch_dict
+                    result[n] = may_to_tensor(data)
+            return result
+
+        # do padding on field_array
+        pad_batch_x = pad_batch(batch_x, self.inputs)
+        pad_batch_y = pad_batch(batch_y, self.targets)
+
+        # do padding on dataset collect_fn
+        batch_dict = batch_x.copy()
+        batch_dict.update(batch_y)
+        pad_dict_x, pad_dict_y = pad_collect(batch_dict)
+
+        # group together
+        pad_batch_x.update(pad_dict_x)
+        pad_batch_y.update(pad_dict_y)

         return (indices,
-                pad_batch(batch_x, self.inputs),
-                pad_batch(batch_y, self.targets))
+                pad_batch_x,
+                pad_batch_y)

     def set_idx_list(self, idx_list):
         if len(idx_list) != len(self.idx_list):
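
Note the merge order in this hunk: the per-field padded results are built first, then the collect_fn outputs are update()d on top, so a collect_fn whose name matches an existing field silently replaces that field's padded value (dataset.py below logs a warning for exactly this case). A toy illustration of the dict semantics, not fastNLP code:

    pad_batch_x = {'words': [[1, 2, 0]], 'seq_len': [2]}  # from pad_batch
    pad_dict_x = {'seq_len': [3]}                         # from a collect_fn also named 'seq_len'

    pad_batch_x.update(pad_dict_x)
    print(pad_batch_x['seq_len'])  # [3] -- the collect_fn output wins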


fastNLP/core/collect_fn.py (new file, +118, -0)

@@ -0,0 +1,118 @@
import torch
import numpy as np
from .field import _get_ele_type_and_dim


def _check_type(batch_dict, fields):
    if len(fields) == 0:
        raise RuntimeError
    types = []
    dims = []
    for f in fields:
        t, d = _get_ele_type_and_dim(batch_dict[f])
        types.append(t)
        dims.append(d)
    diff_types = set(types)
    diff_dims = set(dims)
    if len(diff_types) > 1 or len(diff_dims) > 1:
        raise ValueError
    return types[0]


def batching(samples, max_len=0, padding_val=0):
    if len(samples) == 0:
        return samples
    if max_len <= 0:
        max_len = max(s.shape[0] for s in samples)
    batch = np.full((len(samples), max_len), fill_value=padding_val)
    for i, s in enumerate(samples):
        slen = min(s.shape[0], max_len)
        batch[i][:slen] = s[:slen]
    return batch


class Collector:
    def __init__(self):
        self.fns = []
        self.names = []
        self.fields_list = []
        self.is_input = []

    def add_fn(self, fn, name, fields, is_input):
        if name in self.names:
            raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
        if fn.num_fields() > 0 and len(fields) != fn.num_fields():
            raise ValueError(
                "Incorrect num of fields, should be {} not {}".format(
                    fn.num_fields(), len(fields)
                ))

        self.fns.append(fn)
        self.names.append(name)
        self.fields_list.append(fields)
        self.is_input.append(is_input)

    def collect_batch(self, batch_dict):
        if len(batch_dict) == 0:
            return {}, {}
        batch_x, batch_y = {}, {}
        for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input):
            batch = fn.collect(batch_dict, fields)
            if is_input:
                batch_x[name] = batch
            else:
                batch_y[name] = batch
        return batch_x, batch_y


class CollectFn:
    def __init__(self):
        self.fields = []

    def collect(self, batch_dict, fields):
        raise NotImplementedError

    def num_fields(self):
        return 0

    @staticmethod
    def get_batch_size(batch_dict):
        if len(batch_dict) == 0:
            return 0
        return len(next(iter(batch_dict.values())))


class ConcatCollectFn(CollectFn):
    """
    Concatenation CollectFn: concatenates the given fields in order, then pads
    the result to produce the batch data. All fields must have the same dim.

    :param pad_val: value used for padding
    :param max_len: maximum length after concatenation; if <= 0, the longest
        concatenated sample determines the length
    """

    def __init__(self, pad_val=0, max_len=0):
        super().__init__()
        self.pad_val = pad_val
        self.max_len = max_len

    def collect(self, batch_dict, fields):
        samples = []
        dtype = _check_type(batch_dict, fields)
        batch_size = self.get_batch_size(batch_dict)
        for i in range(batch_size):
            sample = []
            for n in fields:
                seq = batch_dict[n][i]
                if str(dtype).startswith('torch'):
                    seq = seq.numpy()
                else:
                    seq = np.array(seq, dtype=dtype)
                sample.append(seq)
            samples.append(np.concatenate(sample, axis=0))
        batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
        if str(dtype).startswith('torch'):
            batch = torch.tensor(batch, dtype=dtype)
        return batch

    def num_fields(self):
        return 0
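
num_fields is the validation hook here: when it returns a positive number, Collector.add_fn checks that exactly that many fields were passed; returning 0, as both classes above do, skips the check. A sketch of a custom subclass (SumCollectFn is hypothetical, not part of this commit):

    import numpy as np
    from fastNLP.core.collect_fn import CollectFn

    class SumCollectFn(CollectFn):
        """Adds two numeric fields element-wise into a single batch entry."""

        def collect(self, batch_dict, fields):
            f1, f2 = fields
            return np.array([a + b for a, b in zip(batch_dict[f1], batch_dict[f2])])

        def num_fields(self):
            # exactly two fields are required; Collector.add_fn enforces this
            return 2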

fastNLP/core/dataset.py (+28, -0)

@@ -302,6 +302,7 @@ from .field import SetInputOrTargetException
 from .instance import Instance
 from .utils import _get_func_signature
 from .utils import pretty_table_printer
+from .collect_fn import Collector


 class DataSet(object):
@@ -331,6 +332,7 @@ class DataSet(object):
         else:
             raise ValueError("data only be dict or list type.")
+        self.collector = Collector()

     def __contains__(self, item):
         return item in self.field_arrays
@@ -954,3 +956,29 @@
             d = pickle.load(f)
             assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d
+
+    def add_collect_fn(self, fn, name, fields, is_input=True):
+        """
+        Add a CollectFn that uses multiple fields to produce data for the batch.
+
+        :param CollectFn fn: defines how the data is produced
+        :param str name: name of the produced data in the batch
+        :param list fields: the fields used to produce the data, in order
+        :param bool is_input: whether the result goes into the input batch; if False, it goes into the target batch
+        """
+        def check_fields(fields):
+            for f in fields:
+                if f not in self.field_arrays:
+                    raise ValueError(f)
+
+        def check_name(name):
+            if name in self.field_arrays:
+                logger.warning('name of collect_fn will cover the field name in dataset')
+
+        check_fields(fields)
+        check_name(name)
+
+        self.collector.add_fn(fn, name, fields, is_input)
+
+    def _collect_batch(self, batch_dict):
+        return self.collector.collect_batch(batch_dict)
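
For reference, a hypothetical walk-through of the Collector that _collect_batch delegates to (field names and values are illustrative; the expected output follows from the ConcatCollectFn semantics above):

    from fastNLP.core.collect_fn import Collector, ConcatCollectFn

    col = Collector()
    col.add_fn(ConcatCollectFn(), name='ab', fields=['a', 'b'], is_input=True)

    batch_dict = {'a': [[1, 2], [3]], 'b': [[4], [5, 6, 7]]}
    batch_x, batch_y = col.collect_batch(batch_dict)
    print(batch_x['ab'])  # [[1 2 4 0] [3 5 6 7]] -- concat per instance, then pad
    print(batch_y)        # {} -- nothing was registered with is_input=False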

test/core/test_batch.py (+16, -0)

@@ -7,6 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SequentialSampler
+from fastNLP import ConcatCollectFn


 def generate_fake_dataset(num_samples=1000):
@@ -150,6 +151,21 @@ class TestCase1(unittest.TestCase):
         for batch_x, batch_y in batch:
             pass

+    def test_collect_fn(self):
+        batch_size = 32
+        num_samples = 1000
+        dataset = generate_fake_dataset(num_samples)
+        dataset.set_input('1', '2')
+        fn = ConcatCollectFn()
+        dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True)
+
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
+
     def testTensorLoaderIter(self):
         class FakeData:
             def __init__(self, return_dict=True):

