
[update] major rework of collect_fn

tags/v0.5.5
yunfan, 4 years ago
commit 423e8e3746
4 changed files with 161 additions and 96 deletions
  1. fastNLP/core/batch.py (+36 -40)
  2. fastNLP/core/collect_fn.py (+103 -43)
  3. fastNLP/core/dataset.py (+13 -11)
  4. test/core/test_batch.py (+9 -2)

fastNLP/core/batch.py (+36 -40)

@@ -14,10 +14,12 @@ from numbers import Number
 import numpy as np
 import torch
 import torch.utils.data
+from collections import defaultdict
 
 from ._logger import logger
 from .dataset import DataSet
 from .sampler import SequentialSampler
+from .field import _get_ele_type_and_dim
 
 _python_is_exit = False

@@ -33,81 +35,75 @@ atexit.register(_set_python_is_exit)
 class DataSetGetter:
     def __init__(self, dataset: DataSet, as_numpy=False):
         self.dataset = dataset
-        self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
-        self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
         self.as_numpy = as_numpy
         self.idx_list = list(range(len(dataset)))
 
+        self.x_names = {n for n, f in dataset.get_all_fields().items() if f.is_input}
+        self.y_names = {n for n, f in dataset.get_all_fields().items() if f.is_target}
+
     def __getitem__(self, idx: int):
         # mapping idx to sampled idx
         idx = self.idx_list[idx]
-        inputs = {n: f.get(idx) for n, f in self.inputs.items()}
-        targets = {n: f.get(idx) for n, f in self.targets.items()}
-        return idx, inputs, targets
+        ins = self.dataset[idx]
+        return idx, ins
 
     def __len__(self):
         return len(self.dataset)
 
-    def collate_fn(self, batch: list):
+    def collate_fn(self, ins_list: list):
         """
 
         :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
         :return:
         """
-        # TODO support defining collate_fn in DataSet, since different fields may need to be fused, e.g. in the BERT scenario
-        batch_x = {n: [] for n in self.inputs.keys()}
-        batch_y = {n: [] for n in self.targets.keys()}
         indices = []
-        for idx, x, y in batch:
+        sin_x, sin_y = defaultdict(list), defaultdict(list)
+        for idx, ins in ins_list:
             indices.append(idx)
-            for n, v in x.items():
-                batch_x[n].append(v)
-            for n, v in y.items():
-                batch_y[n].append(v)
+            for n, v in ins.items():
+                if n in self.x_names:
+                    sin_x[n].append(v)
+                if n in self.y_names:
+                    sin_y[n].append(v)
 
         def may_to_tensor(data):
+            dtype, dim = _get_ele_type_and_dim(data)
+            print(dtype, type(dtype))
             if not self.as_numpy:
                 try:
-                    data, flag = _to_tensor(data, data.dtype)
+                    data, flag = _to_tensor(data, dtype)
                 except TypeError as e:
                     logger.error(f"Field {n} cannot be converted to torch.tensor.")
                     raise e
             return data
 
-        def pad_collect(batch_dict):
-            batch_x, batch_y = self.dataset._collect_batch(batch_dict)
-            for b in [batch_x, batch_y]:
-                for n in b.keys():
-                    b[n] = may_to_tensor(b[n])
-            return batch_x, batch_y
-
-        def pad_batch(batch_dict, field_array):
+        def pad(batch_dict):
             result = {}
             for n, vlist in batch_dict.items():
-                f = field_array[n]
+                f = self.dataset.field_arrays[n]
                 if f.padder is None:
                     result[n] = np.array(vlist)
                 else:
-                    data = f.pad(vlist)
-                    result[n] = may_to_tensor(data)
+                    result[n] = f.pad(vlist)
             return result
 
-        # do padding on field_array
-        pad_batch_x = pad_batch(batch_x, self.inputs)
-        pad_batch_y = pad_batch(batch_y, self.targets)
+        sin_x = pad(sin_x)
+        sin_y = pad(sin_y)
 
+        bx, by = self.dataset._collect_batch(ins_list)
+        def convert_tensor(batch_dict):
+            for n, v in batch_dict.items():
+                batch_dict[n] = may_to_tensor(v)
 
-        # do padding on dataset collect_fn
-        batch_dict = batch_x.copy()
-        batch_dict.update(batch_y)
-        pad_dict_x, pad_dict_y = pad_collect(batch_dict)
+        # collect_fn output replaces the single field of the same name
+        sin_x.update(bx)
+        sin_y.update(by)
 
-        # group together
-        pad_batch_x.update(pad_dict_x)
-        pad_batch_y.update(pad_dict_y)
+        convert_tensor(sin_x)
+        convert_tensor(sin_y)
 
-        return (indices,
-                pad_batch_x,
-                pad_batch_y)
+        return (indices, sin_x, sin_y)
 
     def set_idx_list(self, idx_list):
         if len(idx_list) != len(self.idx_list):
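
Editor's note: the rewritten collate_fn assembles a batch in three steps: pad each input/target field, overlay the outputs of the dataset's collect fns (so an output sharing a name with a field replaces that field, per the comment in the hunk), then tensorize. A toy illustration of the overlay step, with hypothetical values:

    # sin_x: per-field padded values; bx: values produced by collect fns
    sin_x = {'words': [[3, 5, 0]], 'seq_len': [2]}
    bx = {'seq_len': [3]}  # a collect fn also emits 'seq_len'
    sin_x.update(bx)       # collect_fn output replaces the field
    assert sin_x == {'words': [[3, 5, 0]], 'seq_len': [3]}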
@@ -297,9 +293,9 @@ def _to_tensor(batch, field_dtype):
     if field_dtype is not None and isinstance(field_dtype, type) \
             and issubclass(field_dtype, Number) \
             and not isinstance(batch, torch.Tensor):
-        if issubclass(batch.dtype.type, np.floating):
+        if issubclass(field_dtype, np.floating):
             new_batch = torch.as_tensor(batch).float()  # defaults to float32
-        elif issubclass(batch.dtype.type, np.integer):
+        elif issubclass(field_dtype, np.integer):
             new_batch = torch.as_tensor(batch).long()  # reuses the underlying memory, avoiding a copy
         else:
             new_batch = torch.as_tensor(batch)
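
Editor's note: the hunk above fixes a dtype-dispatch bug: at this point batch may still be a plain Python list, so batch.dtype.type would raise AttributeError; the declared field_dtype (now obtained via _get_ele_type_and_dim in may_to_tensor) is used instead. A standalone sketch of the corrected behavior, assuming the same (data, flag) return convention; names here are illustrative, not the full fastNLP function:

    from numbers import Number
    import numpy as np
    import torch

    def to_tensor_sketch(batch, field_dtype):
        # Tensorize numeric batches according to the field's declared element
        # dtype; non-numeric batches are returned unchanged with flag=False.
        if field_dtype is not None and isinstance(field_dtype, type) \
                and issubclass(field_dtype, Number) \
                and not isinstance(batch, torch.Tensor):
            if issubclass(field_dtype, np.floating):
                return torch.as_tensor(batch).float(), True  # float32 by default
            if issubclass(field_dtype, np.integer):
                return torch.as_tensor(batch).long(), True
            return torch.as_tensor(batch), True
        return batch, False

    t, ok = to_tensor_sketch([[1, 2], [3, 0]], np.int64)  # int64 tensor, True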


fastNLP/core/collect_fn.py (+103 -43)

@@ -1,6 +1,9 @@
+from builtins import sorted
+
 import torch
 import numpy as np
 from .field import _get_ele_type_and_dim
+from collections import defaultdict
 
 
 def _check_type(batch_dict, fields):
@@ -33,46 +36,99 @@ def batching(samples, max_len=0, padding_val=0):
 
 class Collector:
     def __init__(self):
-        self.fns = []
-        self.names = []
-        self.fields_list = []
-        self.is_input = []
-
-    def add_fn(self, fn, name, fields, is_input):
-        if name in self.names:
-            raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
-        if fn.num_fields() > 0 and len(fields) != fn.num_fields():
+        self.fns = {}
+        self.input2fn = defaultdict(list)
+        self.output2fn = defaultdict(list)
+        self.fn2input = {}
+        self.fn2output = {}
+
+    def add_fn(self, fn, inputs, outputs, is_input, is_target):
+        for name in outputs:
+            if name in self.output2fn:
+                raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
+
+        if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs():
             raise ValueError(
-                "Incorrect num of fields, should be {} not {}".format(
-                    fn.num_fields(), len(fields)
+                "Incorrect num of inputs, should be {} not {}".format(
+                    fn.num_inputs(), len(inputs)
                 ))
 
-        self.fns.append(fn)
-        self.names.append(name)
-        self.fields_list.append(fields)
-        self.is_input.append(is_input)
-
-    def collect_batch(self, batch_dict):
-        if len(batch_dict) == 0:
+        if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs():
+            raise ValueError("Incorrect num of outputs, should be {} not {}".format(
+                fn.num_outputs(), len(outputs)))
+
+        self.fns[fn] = {'is_input': is_input, 'is_target': is_target}
+        for i, field in enumerate(inputs):
+            self.input2fn[field].append((fn, i))
+        for i, name in enumerate(outputs):
+            self.output2fn[name].append((fn, i))
+
+    def _rebuild_fn2io(self):
+        def transpose(name2fn):
+            fn2names = defaultdict(list)
+            for name, vlist in name2fn.items():
+                for fn, i in vlist:
+                    fn2names[fn].append((name, i))
+            for fn, vlist in fn2names.items():
+                vlist = sorted(vlist, key=lambda x: x[1])
+                fn2names[fn] = [name for name, i in vlist]
+            return fn2names
+
+        self.fn2input = transpose(self.input2fn)
+        self.fn2output = transpose(self.output2fn)
+
+    def _clear_fn2io(self):
+        self.fn2input.clear()
+        self.fn2output.clear()
+
+    def collect_batch(self, ins_list):
+        if len(ins_list) == 0:
             return {}, {}
-        batch_x, batch_y = {}, {}
-        for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input):
-            batch = fn.collect(batch_dict, fields)
-            if is_input:
-                batch_x[name] = batch
-            else:
-                batch_y[name] = batch
-        return batch_x, batch_y
+
+        if len(self.fn2output) == 0:
+            self._rebuild_fn2io()
+
+        bx = {}
+        by = {}
+        for fn, attr in self.fns.items():
+            inputs = self.fn2input.get(fn, None)
+            outputs = self.fn2output.get(fn, None)
+            res = fn.collect(ins_list, inputs, outputs)
+            if attr.get('is_input', False):
+                bx.update(res)
+            if attr.get('is_target', False):
+                by.update(res)
+        return bx, by
+
+    def rename_field(self, old_f, new_f):
+        if new_f in self.input2fn:
+            # name conflict
+            raise ValueError
+        if old_f not in self.input2fn:
+            # the renamed field does not affect collectors
+            return
+        self.input2fn[new_f] = self.input2fn[old_f]
+        self._clear_fn2io()
+
+    def drop_field(self, f):
+        if f in self.input2fn:
+            raise ValueError
+
+    def outputs(self):
+        return self.output2fn.keys()
 
 
 class CollectFn:
     def __init__(self):
         self.fields = []
 
-    def collect(self, batch_dict, fields):
+    def collect(self, ins_list, inputs, outputs):
         raise NotImplementedError
 
-    def num_fields(self):
+    def num_inputs(self):
         return 0
 
+    def num_outputs(self):
+        return 0
+
     @staticmethod
@@ -95,24 +151,28 @@ class ConcatCollectFn(CollectFn):
         self.pad_val = pad_val
         self.max_len = max_len
 
-    def collect(self, batch_dict, fields):
+    @staticmethod
+    def _to_numpy(seq):
+        if torch.is_tensor(seq):
+            return seq.numpy()
+        else:
+            return np.array(seq)
+
+    def collect(self, ins_list, inputs, outputs):
         samples = []
-        dtype = _check_type(batch_dict, fields)
-        batch_size = self.get_batch_size(batch_dict)
-        for i in range(batch_size):
+        for i, ins in ins_list:
             sample = []
-            for n in fields:
-                seq = batch_dict[n][i]
-                if str(dtype).startswith('torch'):
-                    seq = seq.numpy()
-                else:
-                    seq = np.array(seq, dtype=dtype)
-                sample.append(seq)
+            for i in inputs:
+                sample.append(self._to_numpy(ins[i]))
             samples.append(np.concatenate(sample, axis=0))
+        seq_len = [s.shape[0] for s in samples]
         batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
-        if str(dtype).startswith('torch'):
-            batch = torch.tensor(batch, dtype=dtype)
-        return batch
+        o1, o2 = outputs
+        return {o1: batch, o2: seq_len}
 
-    def num_fields(self):
+    def num_inputs(self):
         return 0
 
+    def num_outputs(self):
+        # (concat_words, seq_len)
+        return 2
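
Editor's note: the Collector now keeps name-indexed maps (input2fn and output2fn, each name -> [(fn, position), ...]) and lazily transposes them into per-fn ordered name lists before collecting. A standalone sketch of the transpose step in _rebuild_fn2io above, with hypothetical data:

    from collections import defaultdict

    def transpose(name2fn):
        # name -> [(fn, position), ...]  becomes  fn -> [names sorted by position]
        fn2names = defaultdict(list)
        for name, vlist in name2fn.items():
            for fn, i in vlist:
                fn2names[fn].append((name, i))
        for fn, vlist in fn2names.items():
            fn2names[fn] = [name for name, _ in sorted(vlist, key=lambda x: x[1])]
        return fn2names

    class Fn:  # stand-in for a CollectFn instance
        pass

    fn = Fn()
    output2fn = {'seq_len': [(fn, 1)], 'words': [(fn, 0)]}
    assert transpose(output2fn)[fn] == ['words', 'seq_len']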

fastNLP/core/dataset.py (+13 -11)

@@ -957,28 +957,30 @@ class DataSet(object):
         assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d
 
-    def add_collect_fn(self, fn, name, fields, is_input=True):
+    def add_collect_fn(self, fn, inputs, outputs, is_input, is_target):
         """
         Add a CollectFn that uses multiple fields to produce data in the batch
 
         :param CollectFn fn: defines how the data is produced
-        :param str name: name of the produced data in the batch
-        :param list fields: the fields used to produce the data, ordered
+        :param list inputs: the fields used to produce the data, ordered
+        :param list outputs: names of the produced data in the batch
         :param bool is_input: whether the result appears in the input batch; if not, in the target batch
+        :param bool is_target: whether the result appears in the target batch
         """
         def check_fields(fields):
             for f in fields:
                 if f not in self.field_arrays:
                     raise ValueError(f)
 
-        def check_name(name):
-            if name in self.field_arrays:
-                logger.warning('name of collect_fn will cover the field name in dataset')
+        def check_name(names):
+            for name in names:
+                if name in self.field_arrays:
+                    logger.warning('name of collect_fn will cover the field name in dataset')
 
-        check_fields(fields)
-        check_name(name)
+        check_fields(inputs)
+        check_name(outputs)
 
-        self.collector.add_fn(fn, name, fields, is_input)
+        self.collector.add_fn(fn, inputs, outputs, is_input, is_target)
 
-    def _collect_batch(self, batch_dict):
-        return self.collector.collect_batch(batch_dict)
+    def _collect_batch(self, ins_list):
+        return self.collector.collect_batch(ins_list)
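
Editor's note: with the new signature, a single CollectFn can emit several named outputs that land in batch_x and/or batch_y. A minimal usage sketch, mirroring the updated test below; the toy dataset here is hypothetical:

    from fastNLP import DataSet
    from fastNLP.core.collect_fn import ConcatCollectFn

    ds = DataSet({'1': [[1, 2]], '2': [[3, 4, 5]]})
    ds.set_input('1', '2')
    # concatenate fields '1' and '2'; expose the result as '12'
    # and its length as 'seq_len', both in batch_x
    ds.add_collect_fn(ConcatCollectFn(), inputs=['1', '2'],
                      outputs=['12', 'seq_len'],
                      is_input=True, is_target=False)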

test/core/test_batch.py (+9 -2)

@@ -26,7 +26,7 @@ def generate_fake_dataset(num_samples=1000):
         data = []
         lengths = np.random.randint(min_len, max_len, size=(num_samples))
         for length in lengths:
-            data.append(np.random.randint(100, size=length))
+            data.append(np.random.randint(1, 100, size=length))
         data_dict[str(i)] = data
     dataset = DataSet(data_dict)
@@ -156,14 +156,21 @@ class TestCase1(unittest.TestCase):
         num_samples = 1000
         dataset = generate_fake_dataset(num_samples)
         dataset.set_input('1', '2')
         dataset.set_target('0', '3')
 
         fn = ConcatCollectFn()
-        dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True)
+        dataset.add_collect_fn(fn, inputs=['1', '2'],
+                               outputs=['12', 'seq_len'],
+                               is_input=True, is_target=False)
 
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
         for batch_x, batch_y in batch:
             for i in range(batch_size):
+                # print(i)
                 self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
+                self.assertEqual(
+                    batch_x['seq_len'][i],
+                    (batch_x['1'][i] != 0).sum() + (batch_x['2'][i] != 0).sum())
 
 
     def testTensorLoaderIter(self):

