@@ -44,6 +44,9 @@ __all__ = [
    "AutoPadder",
    "EngChar2DPadder",
    "CollectFn",
    "ConcatCollectFn",
    "MetricBase",
    "AccuracyMetric",
    "SpanFPreRecMetric",
@@ -20,6 +20,9 @@ __all__ = [
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",
    "CollectFn",
    "ConcatCollectFn",
    "Vocabulary",
@@ -94,3 +97,4 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
from .collect_fn import CollectFn, ConcatCollectFn
@@ -65,25 +65,49 @@ class DataSetGetter:
            for n, v in y.items():
                batch_y[n].append(v)

        def may_to_tensor(data, field_name):
            if not self.as_numpy:
                try:
                    data, flag = _to_tensor(data, data.dtype)
                except TypeError as e:
                    logger.error(f"Field {field_name} cannot be converted to torch.tensor.")
                    raise e
            return data

        def pad_collect(batch_dict):
            batch_x, batch_y = self.dataset._collect_batch(batch_dict)
            for b in (batch_x, batch_y):
                for n in b.keys():
                    b[n] = may_to_tensor(b[n], n)
            return batch_x, batch_y

        def pad_batch(batch_dict, field_array):
            result = {}
            for n, vlist in batch_dict.items():
                f = field_array[n]
                if f.padder is None:
                    result[n] = np.array(vlist)
                else:
                    data = f.pad(vlist)
                    result[n] = may_to_tensor(data, n)
            return result

        # do padding on field_array
        pad_batch_x = pad_batch(batch_x, self.inputs)
        pad_batch_y = pad_batch(batch_y, self.targets)

        # do padding on dataset collect_fn
        batch_dict = batch_x.copy()
        batch_dict.update(batch_y)
        pad_dict_x, pad_dict_y = pad_collect(batch_dict)

        # group together
        pad_batch_x.update(pad_dict_x)
        pad_batch_y.update(pad_dict_y)

        return (indices, pad_batch_x, pad_batch_y)

    def set_idx_list(self, idx_list):
        if len(idx_list) != len(self.idx_list):
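Note the merge order above: the field padders run first, then the outputs of the dataset's collect_fns are merged on top with dict.update(), so a collect_fn whose name collides with a field replaces the padded field in the batch. A tiny standalone sketch of that semantics (the keys are hypothetical):

# collect_fn output wins on name collisions, because update() runs last
pad_batch_x = {'words': 'padded field', 'seq_len': 'padded field'}
pad_dict_x = {'words': 'collect_fn output'}
pad_batch_x.update(pad_dict_x)
assert pad_batch_x['words'] == 'collect_fn output'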
@@ -0,0 +1,118 @@
import torch
import numpy as np

from .field import _get_ele_type_and_dim


def _check_type(batch_dict, fields):
    if len(fields) == 0:
        raise RuntimeError("CollectFn requires at least one field.")
    types = []
    dims = []
    for f in fields:
        t, d = _get_ele_type_and_dim(batch_dict[f])
        types.append(t)
        dims.append(d)
    diff_types = set(types)
    diff_dims = set(dims)
    if len(diff_types) > 1 or len(diff_dims) > 1:
        raise ValueError("Fields must share one type and dim, got types {} and dims {}.".format(diff_types, diff_dims))
    return types[0]


def batching(samples, max_len=0, padding_val=0):
    if len(samples) == 0:
        return samples
    if max_len <= 0:
        max_len = max(s.shape[0] for s in samples)
    batch = np.full((len(samples), max_len), fill_value=padding_val)
    for i, s in enumerate(samples):
        slen = min(s.shape[0], max_len)
        batch[i][:slen] = s[:slen]
    return batch
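For reference, a quick check of batching's behavior on 1-d numpy samples (the shape ConcatCollectFn below produces):

import numpy as np

# shorter samples are right-padded with padding_val
out = batching([np.array([1, 2, 3]), np.array([4])], padding_val=0)
# out: [[1, 2, 3],
#       [4, 0, 0]]

# a positive max_len truncates longer samples
out = batching([np.array([1, 2, 3])], max_len=2)
# out: [[1, 2]]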
class Collector:
    def __init__(self):
        self.fns = []
        self.names = []
        self.fields_list = []
        self.is_input = []

    def add_fn(self, fn, name, fields, is_input):
        if name in self.names:
            raise ValueError("Duplicate name: {} for CollectFn: {}".format(name, fn))
        if fn.num_fields() > 0 and len(fields) != fn.num_fields():
            raise ValueError(
                "Incorrect number of fields, should be {} not {}".format(
                    fn.num_fields(), len(fields)))

        self.fns.append(fn)
        self.names.append(name)
        self.fields_list.append(fields)
        self.is_input.append(is_input)

    def collect_batch(self, batch_dict):
        if len(batch_dict) == 0:
            return {}, {}
        batch_x, batch_y = {}, {}
        for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input):
            batch = fn.collect(batch_dict, fields)
            if is_input:
                batch_x[name] = batch
            else:
                batch_y[name] = batch
        return batch_x, batch_y
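Collector is the internal registry that DataSet delegates to. A short usage sketch, assuming the classes in this file are importable; the field names are made up:

collector = Collector()
collector.add_fn(ConcatCollectFn(), name='words', fields=['premise', 'hypothesis'], is_input=True)
# batch_dict maps each field name to the list of raw values in the batch
batch_x, batch_y = collector.collect_batch({
    'premise': [[1, 2], [3]],
    'hypothesis': [[4], [5, 6]],
})
# batch_x: {'words': <2x3 padded array>}; batch_y: {} (since is_input=True)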
class CollectFn:
    def __init__(self):
        self.fields = []

    def collect(self, batch_dict, fields):
        raise NotImplementedError

    def num_fields(self):
        # 0 means the fn accepts a variable number of fields
        return 0

    @staticmethod
    def get_batch_size(batch_dict):
        if len(batch_dict) == 0:
            return 0
        return len(next(iter(batch_dict.values())))
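collect() is abstract, and Collector.add_fn reads num_fields() == 0 as "accepts any number of fields". A hypothetical custom subclass (not part of this PR) that sums two numeric fields element-wise:

class SumCollectFn(CollectFn):
    def collect(self, batch_dict, fields):
        a, b = fields
        return np.array(batch_dict[a]) + np.array(batch_dict[b])

    def num_fields(self):
        # require exactly two fields
        return 2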
class ConcatCollectFn(CollectFn):
    """
    A CollectFn that concatenates the given fields in order, then pads the result
    to produce the batch data. All fields must have the same dim.

    :param pad_val: value used for padding
    :param max_len: maximum length after concatenation (0 means unlimited)
    """

    def __init__(self, pad_val=0, max_len=0):
        super().__init__()
        self.pad_val = pad_val
        self.max_len = max_len

    def collect(self, batch_dict, fields):
        samples = []
        dtype = _check_type(batch_dict, fields)
        batch_size = self.get_batch_size(batch_dict)
        for i in range(batch_size):
            sample = []
            for n in fields:
                seq = batch_dict[n][i]
                if str(dtype).startswith('torch'):
                    seq = seq.numpy()
                else:
                    seq = np.array(seq, dtype=dtype)
                sample.append(seq)
            samples.append(np.concatenate(sample, axis=0))
        batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
        if str(dtype).startswith('torch'):
            batch = torch.tensor(batch, dtype=dtype)
        return batch

    def num_fields(self):
        # accepts a variable number of fields
        return 0
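A concrete call for illustration; the expected output assumes both fields hold int sequences of dim 1:

fn = ConcatCollectFn(pad_val=0)
batch = fn.collect({'a': [[1, 2], [3]], 'b': [[4], [5, 6]]}, fields=['a', 'b'])
# batch: array([[1, 2, 4],
#               [3, 5, 6]])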
@@ -302,6 +302,7 @@ from .field import SetInputOrTargetException
from .instance import Instance
from .utils import _get_func_signature
from .utils import pretty_table_printer
from .collect_fn import Collector


class DataSet(object):
@@ -331,6 +332,7 @@ class DataSet(object):
        else:
            raise ValueError("data can only be dict or list type.")
        self.collector = Collector()

    def __contains__(self, item):
        return item in self.field_arrays
@@ -954,3 +956,29 @@ class DataSet(object):
            d = pickle.load(f)
        assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
        return d

    def add_collect_fn(self, fn, name, fields, is_input=True):
        """
        Add a CollectFn that produces batch data from multiple fields.

        :param CollectFn fn: defines how the data is produced
        :param str name: name of the produced data in the batch
        :param list fields: fields used to produce the data, in order
        :param bool is_input: whether the result appears in the input batch; if False, it appears in the target batch
        """
        def check_fields(fields):
            for f in fields:
                if f not in self.field_arrays:
                    raise ValueError("Field {} not found in the dataset.".format(f))

        def check_name(name):
            if name in self.field_arrays:
                logger.warning('the output of this collect_fn will override the field of the same name in the dataset')

        check_fields(fields)
        check_name(name)
        self.collector.add_fn(fn, name, fields, is_input)

    def _collect_batch(self, batch_dict):
        return self.collector.collect_batch(batch_dict)
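An end-to-end sketch of the new public API (field names and values are made up):

from fastNLP import DataSet, DataSetIter, ConcatCollectFn

ds = DataSet({'premise': [[1, 2], [3]], 'hypothesis': [[4], [5, 6]]})
ds.set_input('premise', 'hypothesis')
# 'words' will appear in batch_x alongside the padded original fields
ds.add_collect_fn(ConcatCollectFn(), 'words', fields=['premise', 'hypothesis'], is_input=True)
for batch_x, batch_y in DataSetIter(ds, batch_size=2):
    print(batch_x['words'])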
@@ -7,6 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import SequentialSampler
from fastNLP import ConcatCollectFn


def generate_fake_dataset(num_samples=1000):
@@ -150,6 +151,21 @@ class TestCase1(unittest.TestCase):
        for batch_x, batch_y in batch:
            pass

    def test_collect_fn(self):
        batch_size = 32
        num_samples = 1000
        dataset = generate_fake_dataset(num_samples)
        dataset.set_input('1', '2')

        fn = ConcatCollectFn()
        dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True)

        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
        for batch_x, batch_y in batch:
            for i in range(batch_size):
                # sums match because both the field padders and ConcatCollectFn pad with 0
                self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())

    def testTensorLoaderIter(self):
        class FakeData:
            def __init__(self, return_dict=True):