From 423e8e37467d33876d50921070d4d2d3ed76f4c1 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Fri, 13 Mar 2020 21:23:18 +0800
Subject: [PATCH] [update] major rework of collect_fn

---
 fastNLP/core/batch.py      |  82 +++++++++++-----------
 fastNLP/core/collect_fn.py | 144 +++++++++++++++++++++++++++-----------
 fastNLP/core/dataset.py    |  24 +++---
 test/core/test_batch.py    |  11 ++-
 4 files changed, 162 insertions(+), 99 deletions(-)

diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 84ad0f62..7f0c858b 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -14,10 +14,12 @@ from numbers import Number
 import numpy as np
 import torch
 import torch.utils.data
+from collections import defaultdict
 
 from ._logger import logger
 from .dataset import DataSet
 from .sampler import SequentialSampler
+from .field import _get_ele_type_and_dim
 
 _python_is_exit = False
@@ -33,81 +35,75 @@ atexit.register(_set_python_is_exit)
 class DataSetGetter:
     def __init__(self, dataset: DataSet, as_numpy=False):
         self.dataset = dataset
-        self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
-        self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
         self.as_numpy = as_numpy
         self.idx_list = list(range(len(dataset)))
 
+        self.x_names = {n for n, f in dataset.get_all_fields().items() if f.is_input}
+        self.y_names = {n for n, f in dataset.get_all_fields().items() if f.is_target}
+
     def __getitem__(self, idx: int):
         # mapping idx to sampled idx
         idx = self.idx_list[idx]
-        inputs = {n:f.get(idx) for n, f in self.inputs.items()}
-        targets = {n:f.get(idx) for n, f in self.targets.items()}
-        return idx, inputs, targets
+        ins = self.dataset[idx]
+        return idx, ins
 
     def __len__(self):
         return len(self.dataset)
 
-    def collate_fn(self, batch: list):
+    def collate_fn(self, ins_list: list):
         """
 
-        :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
+        :param ins_list: [[idx1, ins1], [idx2, ins2], ...]
         :return:
         """
         # TODO: support defining collate_fn in DataSet, since different fields sometimes need to be fused, e.g. for BERT
-        batch_x = {n:[] for n in self.inputs.keys()}
-        batch_y = {n:[] for n in self.targets.keys()}
         indices = []
-        for idx, x, y in batch:
+        sin_x, sin_y = defaultdict(list), defaultdict(list)
+        for idx, ins in ins_list:
             indices.append(idx)
-            for n, v in x.items():
-                batch_x[n].append(v)
-            for n, v in y.items():
-                batch_y[n].append(v)
+            for n, v in ins.items():
+                if n in self.x_names:
+                    sin_x[n].append(v)
+                if n in self.y_names:
+                    sin_y[n].append(v)
 
-        def may_to_tensor(data):
+        def may_to_tensor(data, name):
+            dtype, dim = _get_ele_type_and_dim(data)
             if not self.as_numpy:
                 try:
-                    data, flag = _to_tensor(data, data.dtype)
+                    data, flag = _to_tensor(data, dtype)
                 except TypeError as e:
-                    logger.error(f"Field {n} cannot be converted to torch.tensor.")
+                    logger.error(f"Field {name} cannot be converted to torch.tensor.")
                     raise e
             return data
 
-        def pad_collect(batch_dict):
-            batch_x, batch_y = self.dataset._collect_batch(batch_dict)
-            for b in [batch_x, batch_y]:
-                for n in b.keys():
-                    b[n] = may_to_tensor(b[n])
-            return batch_x, batch_y
-
-        def pad_batch(batch_dict, field_array):
+        def pad(batch_dict):
             result = {}
             for n, vlist in batch_dict.items():
-                f = field_array[n]
+                f = self.dataset.field_arrays[n]
                 if f.padder is None:
                     result[n] = np.array(vlist)
                 else:
-                    data = f.pad(vlist)
-                    result[n] = may_to_tensor(data)
+                    result[n] = f.pad(vlist)
             return result
 
-        # do padding on field_array
-        pad_batch_x = pad_batch(batch_x, self.inputs)
-        pad_batch_y = pad_batch(batch_y, self.targets)
+        sin_x = pad(sin_x)
+        sin_y = pad(sin_y)
+
+        bx, by = self.dataset._collect_batch(ins_list)
 
+        def convert_tensor(batch_dict):
+            for n, v in batch_dict.items():
+                batch_dict[n] = may_to_tensor(v, n)
 
-        # do padding on dataset collect_fn
-        batch_dict = batch_x.copy()
-        batch_dict.update(batch_y)
-        pad_dict_x, pad_dict_y = pad_collect(batch_dict)
+        # outputs of collect_fn override single fields with the same name
+        sin_x.update(bx)
+        sin_y.update(by)
 
-        # group together
-        pad_batch_x.update(pad_dict_x)
-        pad_batch_y.update(pad_dict_y)
+        convert_tensor(sin_x)
+        convert_tensor(sin_y)
 
-        return (indices,
-                pad_batch_x,
-                pad_batch_y)
+        return (indices, sin_x, sin_y)
 
     def set_idx_list(self, idx_list):
         if len(idx_list) != len(self.idx_list):
@@ -297,9 +293,9 @@ def _to_tensor(batch, field_dtype):
     if field_dtype is not None and isinstance(field_dtype, type)\
             and issubclass(field_dtype, Number) \
             and not isinstance(batch, torch.Tensor):
-        if issubclass(batch.dtype.type, np.floating):
+        if issubclass(field_dtype, np.floating):
             new_batch = torch.as_tensor(batch).float()  # float32 by default
-        elif issubclass(batch.dtype.type, np.integer):
+        elif issubclass(field_dtype, np.integer):
             new_batch = torch.as_tensor(batch).long()  # reuses the memory address, avoids a copy
         else:
             new_batch = torch.as_tensor(batch)
diff --git a/fastNLP/core/collect_fn.py b/fastNLP/core/collect_fn.py
index 6d56151e..7a869c9a 100644
--- a/fastNLP/core/collect_fn.py
+++ b/fastNLP/core/collect_fn.py
@@ -1,6 +1,7 @@
 import torch
 import numpy as np
 from .field import _get_ele_type_and_dim
+from collections import defaultdict
 
 
 def _check_type(batch_dict, fields):
@@ -33,46 +34,99 @@ def batching(samples, max_len=0, padding_val=0):
 
 class Collector:
     def __init__(self):
-        self.fns = []
-        self.names = []
-        self.fields_list = []
-        self.is_input = []
-
-    def add_fn(self, fn, name, fields, is_input):
-        if name in self.names:
-            raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
-        if fn.num_fields() > 0 and len(fields) != fn.num_fields():
+        self.fns = {}
+        self.input2fn = defaultdict(list)
+        self.output2fn = defaultdict(list)
+        self.fn2input = {}
+        self.fn2output = {}
+
+    def add_fn(self, fn, inputs, outputs, is_input, is_target):
+        for name in outputs:
+            if name in self.output2fn:
+                raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
+
+        if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs():
             raise ValueError(
-                "Incorrect num of fields, should be {} not {}".format(
-                    fn.num_fields(), len(fields)
+                "Incorrect num of inputs, should be {} not {}".format(
+                    fn.num_inputs(), len(inputs)
                 ))
-        self.fns.append(fn)
-        self.names.append(name)
-        self.fields_list.append(fields)
-        self.is_input.append(is_input)
-
-    def collect_batch(self, batch_dict):
-        if len(batch_dict) == 0:
+        if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs():
+            raise ValueError("Incorrect num of outputs, should be {} not {}".format(
+                fn.num_outputs(), len(outputs)))
+
+        self.fns[fn] = {'is_input': is_input, 'is_target': is_target}
+        for i, field in enumerate(inputs):
+            self.input2fn[field].append((fn, i))
+        for i, name in enumerate(outputs):
+            self.output2fn[name].append((fn, i))
+
+    def _rebuild_fn2io(self):
+        def transpose(name2fn):
+            fn2names = defaultdict(list)
+            for name, vlist in name2fn.items():
+                for fn, i in vlist:
+                    fn2names[fn].append((name, i))
+            for fn, vlist in fn2names.items():
+                vlist = sorted(vlist, key=lambda x: x[1])
+                fn2names[fn] = [name for name, i in vlist]
+            return fn2names
+
+        self.fn2input = transpose(self.input2fn)
+        self.fn2output = transpose(self.output2fn)
+
+    def _clear_fn2io(self):
+        self.fn2input.clear()
+        self.fn2output.clear()
+
+    def collect_batch(self, ins_list):
+        if len(ins_list) == 0:
             return {}, {}
-        batch_x, batch_y = {}, {}
-        for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input):
-            batch = fn.collect(batch_dict, fields)
-            if is_input:
-                batch_x[name] = batch
-            else:
-                batch_y[name] = batch
-        return batch_x, batch_y
+
+        if len(self.fn2output) == 0:
+            self._rebuild_fn2io()
+
+        bx = {}
+        by = {}
+        for fn, attr in self.fns.items():
+            inputs = self.fn2input.get(fn, None)
+            outputs = self.fn2output.get(fn, None)
+            res = fn.collect(ins_list, inputs, outputs)
+            if attr.get('is_input', False):
+                bx.update(res)
+            if attr.get('is_target', False):
+                by.update(res)
+        return bx, by
+
+    def rename_field(self, old_f, new_f):
+        if new_f in self.input2fn:
+            # name conflict
+            raise ValueError("field name {} is already used by a collect_fn".format(new_f))
+        if old_f not in self.input2fn:
+            # the renamed field is not used by any collect_fn
+            return
+        self.input2fn[new_f] = self.input2fn.pop(old_f)
+        self._clear_fn2io()
+
+    def drop_field(self, f):
+        if f in self.input2fn:
+            raise ValueError("field {} is used by a collect_fn and cannot be dropped".format(f))
+
+    def outputs(self):
+        return self.output2fn.keys()
 
 
 class CollectFn:
     def __init__(self):
         self.fields = []
 
-    def collect(self, batch_dict, fields):
+    def collect(self, ins_list, inputs, outputs):
         raise NotImplementedError
 
-    def num_fields(self):
+    def num_inputs(self):
+        return 0
+
+    def num_outputs(self):
         return 0
 
     @staticmethod
@@ -95,24 +149,28 @@ class ConcatCollectFn(CollectFn):
         self.pad_val = pad_val
         self.max_len = max_len
 
-    def collect(self, batch_dict, fields):
+    @staticmethod
+    def _to_numpy(seq):
+        if torch.is_tensor(seq):
+            return seq.numpy()
+        else:
+            return np.array(seq)
+
+    def collect(self, ins_list, inputs, outputs):
         samples = []
-        dtype = _check_type(batch_dict, fields)
-        batch_size = self.get_batch_size(batch_dict)
-        for i in range(batch_size):
+        for _, ins in ins_list:
             sample = []
-            for n in fields:
-                seq = batch_dict[n][i]
-                if str(dtype).startswith('torch'):
-                    seq = seq.numpy()
-                else:
-                    seq = np.array(seq, dtype=dtype)
-                sample.append(seq)
+            for name in inputs:
+                sample.append(self._to_numpy(ins[name]))
             samples.append(np.concatenate(sample, axis=0))
+        seq_len = [s.shape[0] for s in samples]
         batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
-        if str(dtype).startswith('torch'):
-            batch = torch.tensor(batch, dtype=dtype)
-        return batch
+        o1, o2 = outputs
+        return {o1: batch, o2: seq_len}
 
-    def num_fields(self):
+    def num_inputs(self):
         return 0
+
+    def num_outputs(self):
+        # (concat_words, seq_len)
+        return 2
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 7d77344d..b13eab76 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -957,28 +957,30 @@ class DataSet(object):
         assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d
 
-    def add_collect_fn(self, fn, name, fields, is_input=True):
+    def add_collect_fn(self, fn, inputs, outputs, is_input, is_target):
         """
         Add a CollectFn that produces data for the batch from multiple fields.
 
        :param CollectFn fn: defines how the data is produced
-        :param str name: name of the produced data in the batch
-        :param list fields: ordered fields used to produce the data
+        :param list inputs: ordered list of field names used to produce the data
+        :param list outputs: names under which the produced data appears in the batch
        :param bool is_input: whether the outputs are put into the input batch (batch_x)
+        :param bool is_target: whether the outputs are put into the target batch (batch_y)
        """
        def check_fields(fields):
            for f in fields:
                if f not in self.field_arrays:
                    raise ValueError(f)
 
-        def check_name(name):
-            if name in self.field_arrays:
-                logger.warning('name of collect_fn will cover the field name in dataset')
+        def check_name(names):
+            for name in names:
+                if name in self.field_arrays:
+                    logger.warning('an output of collect_fn will cover the field with the same name in the dataset')
 
-        check_fields(fields)
-        check_name(name)
+        check_fields(inputs)
+        check_name(outputs)
 
-        self.collector.add_fn(fn, name, fields, is_input)
+        self.collector.add_fn(fn, inputs, outputs, is_input, is_target)
 
-    def _collect_batch(self, batch_dict):
-        return self.collector.collect_batch(batch_dict)
+    def _collect_batch(self, ins_list):
+        return self.collector.collect_batch(ins_list)
diff --git a/test/core/test_batch.py b/test/core/test_batch.py
index fde2e5a5..dfdf28b7 100644
--- a/test/core/test_batch.py
+++ b/test/core/test_batch.py
@@ -26,7 +26,7 @@ def generate_fake_dataset(num_samples=1000):
         data = []
         lengths = np.random.randint(min_len, max_len, size=(num_samples))
         for length in lengths:
-            data.append(np.random.randint(100, size=length))
+            data.append(np.random.randint(1, 100, size=length))
         data_dict[str(i)] = data
 
     dataset = DataSet(data_dict)
@@ -156,14 +156,21 @@ class TestCase1(unittest.TestCase):
         num_samples = 1000
         dataset = generate_fake_dataset(num_samples)
         dataset.set_input('1','2')
+        dataset.set_target('0', '3')
+
         fn = ConcatCollectFn()
-        dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True)
+        dataset.add_collect_fn(fn, inputs=['1', '2'],
+                               outputs=['12', 'seq_len'],
+                               is_input=True, is_target=False)
 
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
         for batch_x, batch_y in batch:
             for i in range(batch_size):
                 # print(i)
                 self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
+                self.assertEqual(
+                    batch_x['seq_len'][i],
+                    (batch_x['1'][i] != 0).sum() + (batch_x['2'][i] != 0).sum())
 
     def testTensorLoaderIter(self):
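
For reference, below is a minimal sketch of how the reworked collect_fn API fits together, mirroring the updated test case above. It assumes fastNLP's public imports at the time of this commit (DataSet, DataSetIter and SequentialSampler from fastNLP, ConcatCollectFn from fastNLP.core.collect_fn); the field names and toy values are illustrative only, not part of the library.

    from fastNLP import DataSet, DataSetIter, SequentialSampler
    from fastNLP.core.collect_fn import ConcatCollectFn

    # Two input fields holding word ids (hypothetical names and values).
    ds = DataSet({'words1': [[1, 2], [3]], 'words2': [[4], [5, 6]]})
    ds.set_input('words1', 'words2')

    # One fn now declares its inputs and names its outputs: ConcatCollectFn
    # concatenates the two id sequences per instance and also reports the
    # concatenated length, both placed into batch_x since is_input=True.
    ds.add_collect_fn(ConcatCollectFn(), inputs=['words1', 'words2'],
                      outputs=['words', 'seq_len'],
                      is_input=True, is_target=False)

    for batch_x, batch_y in DataSetIter(ds, batch_size=2, sampler=SequentialSampler()):
        # batch_x['words'] is the padded concatenation, batch_x['seq_len'] the true lengths
        print(batch_x['words'], batch_x['seq_len'])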