diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py
index 3cb6aa88..077a4d95 100644
--- a/fastNLP/__init__.py
+++ b/fastNLP/__init__.py
@@ -44,6 +44,9 @@ __all__ = [
     "AutoPadder",
     "EngChar2DPadder",
 
+    "CollectFn",
+    "ConcatCollectFn",
+
     "MetricBase",
     "AccuracyMetric",
     "SpanFPreRecMetric",
diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index 18cdcac4..bda9c11e 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -20,6 +20,9 @@ __all__ = [
     "Padder",
     "AutoPadder",
     "EngChar2DPadder",
+
+    "CollectFn",
+    "ConcatCollectFn",
 
     "Vocabulary",
 
@@ -94,3 +97,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
+from .collect_fn import CollectFn, ConcatCollectFn
diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index f2e34c52..84ad0f62 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -65,25 +65,49 @@ class DataSetGetter:
             for n, v in y.items():
                 batch_y[n].append(v)
 
+        def may_to_tensor(data, field_name):
+            if not self.as_numpy:
+                try:
+                    data, flag = _to_tensor(data, data.dtype)
+                except TypeError as e:
+                    logger.error(f"Field {field_name} cannot be converted to torch.tensor.")
+                    raise e
+            return data
+
+        def pad_collect(batch_dict):
+            batch_x, batch_y = self.dataset._collect_batch(batch_dict)
+            for b in [batch_x, batch_y]:
+                for n in b.keys():
+                    b[n] = may_to_tensor(b[n], n)
+            return batch_x, batch_y
+
         def pad_batch(batch_dict, field_array):
+            result = {}
             for n, vlist in batch_dict.items():
                 f = field_array[n]
                 if f.padder is None:
-                    batch_dict[n] = np.array(vlist)
+                    result[n] = np.array(vlist)
                 else:
                     data = f.pad(vlist)
-                    if not self.as_numpy:
-                        try:
-                            data, flag = _to_tensor(data, f.dtype)
-                        except TypeError as e:
-                            logger.error(f"Field {n} cannot be converted to torch.tensor.")
-                            raise e
-                    batch_dict[n] = data
-            return batch_dict
+                    result[n] = may_to_tensor(data, n)
+            return result
+
+        # pad each field with its own padder
+        pad_batch_x = pad_batch(batch_x, self.inputs)
+        pad_batch_y = pad_batch(batch_y, self.targets)
+
+        # produce extra entries via the collect_fns registered on the dataset
+        batch_dict = batch_x.copy()
+        batch_dict.update(batch_y)
+        pad_dict_x, pad_dict_y = pad_collect(batch_dict)
+
+        # merge both; a collect_fn entry overrides a field of the same name
+        pad_batch_x.update(pad_dict_x)
+        pad_batch_y.update(pad_dict_y)
 
         return (indices,
-                pad_batch(batch_x, self.inputs),
-                pad_batch(batch_y, self.targets))
+                pad_batch_x,
+                pad_batch_y)
 
     def set_idx_list(self, idx_list):
         if len(idx_list) != len(self.idx_list):
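For clarity, a minimal sketch of the merge order implemented above, using plain dicts with made-up field names and shapes (not part of the patch): fields are first padded by their own padders, then the collect_fn outputs are merged on top, so an entry registered under an existing field name takes precedence.

import numpy as np

pad_batch_x = {'1': np.zeros((4, 10)), '2': np.zeros((4, 8))}  # padder outputs
pad_dict_x = {'12': np.zeros((4, 18))}                         # collect_fn output
pad_batch_x.update(pad_dict_x)                                 # later update wins on name clashes
assert sorted(pad_batch_x) == ['1', '12', '2']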
diff --git a/fastNLP/core/collect_fn.py b/fastNLP/core/collect_fn.py
new file mode 100644
index 00000000..6d56151e
--- /dev/null
+++ b/fastNLP/core/collect_fn.py
@@ -0,0 +1,118 @@
+import torch
+import numpy as np
+from .field import _get_ele_type_and_dim
+
+
+def _check_type(batch_dict, fields):
+    if len(fields) == 0:
+        raise RuntimeError("CollectFn needs at least one field to collect from.")
+    types = []
+    dims = []
+    for f in fields:
+        t, d = _get_ele_type_and_dim(batch_dict[f])
+        types.append(t)
+        dims.append(d)
+    diff_types = set(types)
+    diff_dims = set(dims)
+    if len(diff_types) > 1 or len(diff_dims) > 1:
+        raise ValueError("Fields {} must share one dtype and one dim, got types {} and dims {}.".format(
+            fields, diff_types, diff_dims))
+    return types[0]
+
+
+def batching(samples, max_len=0, padding_val=0):
+    if len(samples) == 0:
+        return samples
+    if max_len <= 0:
+        max_len = max(s.shape[0] for s in samples)
+    # keep the samples' dtype, so float data is not truncated by an int batch array
+    batch = np.full((len(samples), max_len), fill_value=padding_val, dtype=samples[0].dtype)
+    for i, s in enumerate(samples):
+        slen = min(s.shape[0], max_len)
+        batch[i][:slen] = s[:slen]
+    return batch
+
+
+class Collector:
+    def __init__(self):
+        self.fns = []
+        self.names = []
+        self.fields_list = []
+        self.is_input = []
+
+    def add_fn(self, fn, name, fields, is_input):
+        if name in self.names:
+            raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
+        if fn.num_fields() > 0 and len(fields) != fn.num_fields():
+            raise ValueError(
+                "Incorrect num of fields, should be {} not {}".format(
+                    fn.num_fields(), len(fields)
+                ))
+
+        self.fns.append(fn)
+        self.names.append(name)
+        self.fields_list.append(fields)
+        self.is_input.append(is_input)
+
+    def collect_batch(self, batch_dict):
+        if len(batch_dict) == 0:
+            return {}, {}
+        batch_x, batch_y = {}, {}
+        for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input):
+            batch = fn.collect(batch_dict, fields)
+            if is_input:
+                batch_x[name] = batch
+            else:
+                batch_y[name] = batch
+        return batch_x, batch_y
+
+
+class CollectFn:
+    def __init__(self):
+        self.fields = []
+
+    def collect(self, batch_dict, fields):
+        raise NotImplementedError
+
+    def num_fields(self):
+        # 0 means any number of fields is accepted
+        return 0
+
+    @staticmethod
+    def get_batch_size(batch_dict):
+        if len(batch_dict) == 0:
+            return 0
+        return len(next(iter(batch_dict.values())))
+
+
+class ConcatCollectFn(CollectFn):
+    """
+    Concatenates the given fields of every sample in order, then pads the
+    concatenated sequences into one batch. All fields must share the same
+    dtype and the same number of dimensions.
+
+    :param pad_val: value used for padding
+    :param max_len: maximum length after concatenation; 0 means pad to the longest sample
+    """
+
+    def __init__(self, pad_val=0, max_len=0):
+        super().__init__()
+        self.pad_val = pad_val
+        self.max_len = max_len
+
+    def collect(self, batch_dict, fields):
+        samples = []
+        dtype = _check_type(batch_dict, fields)
+        batch_size = self.get_batch_size(batch_dict)
+        for i in range(batch_size):
+            sample = []
+            for n in fields:
+                seq = batch_dict[n][i]
+                if str(dtype).startswith('torch'):
+                    seq = seq.numpy()
+                else:
+                    seq = np.array(seq, dtype=dtype)
+                sample.append(seq)
+            samples.append(np.concatenate(sample, axis=0))
+        batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
+        if str(dtype).startswith('torch'):
+            batch = torch.tensor(batch, dtype=dtype)
+        return batch
+
+    def num_fields(self):
+        return 0
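To make the extension point concrete, here is a minimal sketch of a custom CollectFn written against the API above. The class name SumCollectFn and its elementwise-sum behaviour are illustrative, not part of the patch; it assumes the two fields hold equal-length numeric sequences.

import numpy as np
from fastNLP import CollectFn


class SumCollectFn(CollectFn):
    """Produces, per sample, the elementwise sum of two equal-length fields."""

    def collect(self, batch_dict, fields):
        a, b = (batch_dict[f] for f in fields)
        # one summed sequence per sample in the batch
        return np.array([np.add(x, y) for x, y in zip(a, b)])

    def num_fields(self):
        # Collector.add_fn rejects any number of fields other than two
        return 2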
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 0a24ab22..7d77344d 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -302,6 +302,7 @@ from .field import SetInputOrTargetException
 from .instance import Instance
 from .utils import _get_func_signature
 from .utils import pretty_table_printer
+from .collect_fn import Collector
 
 
 class DataSet(object):
@@ -331,6 +332,7 @@ class DataSet(object):
         else:
             raise ValueError("data only be dict or list type.")
+        self.collector = Collector()
 
     def __contains__(self, item):
         return item in self.field_arrays
@@ -954,3 +956,29 @@ class DataSet(object):
         d = pickle.load(f)
         assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d
+
+    def add_collect_fn(self, fn, name, fields, is_input=True):
+        """
+        Add a CollectFn that builds one batch entry out of several fields.
+
+        :param CollectFn fn: defines how the entry is produced
+        :param str name: name of the produced entry in the batch
+        :param list fields: the fields the entry is produced from, in order
+        :param bool is_input: if True the entry goes into the input batch, otherwise into the target batch
+        """
+        def check_fields(fields):
+            for f in fields:
+                if f not in self.field_arrays:
+                    raise ValueError("Field {} does not exist in the DataSet.".format(f))
+
+        def check_name(name):
+            if name in self.field_arrays:
+                logger.warning('The name of this collect_fn shadows the field of the same name in the DataSet.')
+
+        check_fields(fields)
+        check_name(name)
+
+        self.collector.add_fn(fn, name, fields, is_input)
+
+    def _collect_batch(self, batch_dict):
+        return self.collector.collect_batch(batch_dict)
diff --git a/test/core/test_batch.py b/test/core/test_batch.py
index d9898bc7..fde2e5a5 100644
--- a/test/core/test_batch.py
+++ b/test/core/test_batch.py
@@ -7,6 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SequentialSampler
+from fastNLP import ConcatCollectFn
 
 
 def generate_fake_dataset(num_samples=1000):
@@ -150,6 +151,21 @@ class TestCase1(unittest.TestCase):
         for batch_x, batch_y in batch:
             pass
 
+    def test_collect_fn(self):
+        batch_size = 32
+        num_samples = 1000
+        dataset = generate_fake_dataset(num_samples)
+        dataset.set_input('1', '2')
+
+        fn = ConcatCollectFn()
+        dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True)
+
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                # '12' must contain exactly the elements of '1' followed by those of '2'
+                self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
+
     def testTensorLoaderIter(self):
         class FakeData:
             def __init__(self, return_dict=True):
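Finally, an end-to-end sketch of the new API from user code. It mirrors test_collect_fn above; the tiny dataset and the field names x1, x2, x are made up for illustration.

from fastNLP import ConcatCollectFn, DataSet, DataSetIter, SequentialSampler

ds = DataSet({'x1': [[0, 1], [2]], 'x2': [[3], [4, 5]]})
ds.set_input('x1', 'x2')  # collect_fn fields must be input or target fields
# 'x' will appear in batch_x as the concatenation of x1[i] and x2[i],
# padded to the longest concatenated sample in the batch
ds.add_collect_fn(ConcatCollectFn(pad_val=0), 'x', fields=['x1', 'x2'], is_input=True)

for batch_x, batch_y in DataSetIter(ds, batch_size=2, sampler=SequentialSampler()):
    print(batch_x['x'])  # a tensor whose rows are [0, 1, 3] and [2, 4, 5]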