@@ -44,6 +44,9 @@ __all__ = [
     "AutoPadder",
     "EngChar2DPadder",
 
+    "CollectFn",
+    "ConcatCollectFn",
+
     "MetricBase",
     "AccuracyMetric",
     "SpanFPreRecMetric",

@@ -20,6 +20,9 @@ __all__ = [
     "Padder",
     "AutoPadder",
     "EngChar2DPadder",
 
+    "CollectFn",
+    "ConcatCollectFn",
+
     "Vocabulary",

@@ -94,3 +97,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
+from .collect_fn import CollectFn, ConcatCollectFn
@@ -65,25 +65,49 @@ class DataSetGetter:
             for n, v in y.items():
                 batch_y[n].append(v)
 
+        def may_to_tensor(data, name):
+            if not self.as_numpy:
+                try:
+                    data, flag = _to_tensor(data, data.dtype)
+                except TypeError as e:
+                    logger.error(f"Field {name} cannot be converted to torch.tensor.")
+                    raise e
+            return data
+
+        def pad_collect(batch_dict):
+            batch_x, batch_y = self.dataset._collect_batch(batch_dict)
+            for b in [batch_x, batch_y]:
+                for n in b.keys():
+                    b[n] = may_to_tensor(b[n], n)
+            return batch_x, batch_y
+
         def pad_batch(batch_dict, field_array):
+            result = {}
             for n, vlist in batch_dict.items():
                 f = field_array[n]
                 if f.padder is None:
-                    batch_dict[n] = np.array(vlist)
+                    result[n] = np.array(vlist)
                 else:
                     data = f.pad(vlist)
-                    if not self.as_numpy:
-                        try:
-                            data, flag = _to_tensor(data, f.dtype)
-                        except TypeError as e:
-                            logger.error(f"Field {n} cannot be converted to torch.tensor.")
-                            raise e
-                    batch_dict[n] = data
-            return batch_dict
+                    result[n] = may_to_tensor(data, n)
+            return result
 
+        # pad ordinary fields with their FieldArray padders
+        pad_batch_x = pad_batch(batch_x, self.inputs)
+        pad_batch_y = pad_batch(batch_y, self.targets)
+        # let the dataset's collect_fns build extra entries from the raw batch
+        batch_dict = batch_x.copy()
+        batch_dict.update(batch_y)
+        pad_dict_x, pad_dict_y = pad_collect(batch_dict)
+        # merge the two result dicts; collect_fn entries win on name clashes
+        pad_batch_x.update(pad_dict_x)
+        pad_batch_y.update(pad_dict_y)
+
         return (indices,
-                pad_batch(batch_x, self.inputs),
-                pad_batch(batch_y, self.targets))
+                pad_batch_x,
+                pad_batch_y)
 
     def set_idx_list(self, idx_list):
         if len(idx_list) != len(self.idx_list):
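Collation now runs in two stages: per-field padders first, then the dataset's collect_fns over the raw batch, with the two result dicts merged via dict.update so a collect_fn entry wins on a name clash. A toy sketch of just that merge step (the names and values are stand-ins, not from the patch):

    import numpy as np

    # stage 1: what the field padders produced
    pad_batch_x = {'words': np.array([[1, 2, 0], [3, 4, 5]])}
    pad_batch_y = {'target': np.array([0, 1])}

    # stage 2: what the collect_fns produced from the unpadded batch
    pad_dict_x = {'words_concat': np.array([[1, 2, 3, 4, 5, 0]])}
    pad_dict_y = {}

    # merge: collect_fn outputs overwrite same-named padded fields
    pad_batch_x.update(pad_dict_x)
    pad_batch_y.update(pad_dict_y)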
@@ -0,0 +1,118 @@
+import torch
+import numpy as np
+
+from .field import _get_ele_type_and_dim
+
+
+def _check_type(batch_dict, fields):
+    if len(fields) == 0:
+        raise RuntimeError("CollectFn requires at least one field.")
+    types = []
+    dims = []
+    for f in fields:
+        t, d = _get_ele_type_and_dim(batch_dict[f])
+        types.append(t)
+        dims.append(d)
+    diff_types = set(types)
+    diff_dims = set(dims)
+    if len(diff_types) > 1 or len(diff_dims) > 1:
+        raise ValueError("Fields {} must share one dtype and one dim, "
+                         "got types {} and dims {}.".format(fields, diff_types, diff_dims))
+    return types[0]
+
+
+def batching(samples, max_len=0, padding_val=0):
+    if len(samples) == 0:
+        return samples
+    if max_len <= 0:
+        max_len = max(s.shape[0] for s in samples)
+    batch = np.full((len(samples), max_len), fill_value=padding_val)
+    for i, s in enumerate(samples):
+        slen = min(s.shape[0], max_len)
+        batch[i, :slen] = s[:slen]
+    return batch
+
+
+class Collector:
+    """Runs the CollectFns registered on a DataSet and routes each result
+    into batch_x or batch_y."""
+
+    def __init__(self):
+        self.fns = []
+        self.names = []
+        self.fields_list = []
+        self.is_input = []
+
+    def add_fn(self, fn, name, fields, is_input):
+        if name in self.names:
+            raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
+        if fn.num_fields() > 0 and len(fields) != fn.num_fields():
+            raise ValueError(
+                "Incorrect number of fields, should be {} not {}".format(
+                    fn.num_fields(), len(fields)))
+
+        self.fns.append(fn)
+        self.names.append(name)
+        self.fields_list.append(fields)
+        self.is_input.append(is_input)
+
+    def collect_batch(self, batch_dict):
+        if len(batch_dict) == 0:
+            return {}, {}
+        batch_x, batch_y = {}, {}
+        for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input):
+            batch = fn.collect(batch_dict, fields)
+            if is_input:
+                batch_x[name] = batch
+            else:
+                batch_y[name] = batch
+        return batch_x, batch_y
+
+
+class CollectFn:
+    """Base class: builds one batch entry out of one or more fields."""
+
+    def __init__(self):
+        self.fields = []
+
+    def collect(self, batch_dict, fields):
+        raise NotImplementedError
+
+    def num_fields(self):
+        # 0 means any number of fields is accepted
+        return 0
+
+    @staticmethod
+    def get_batch_size(batch_dict):
+        if len(batch_dict) == 0:
+            return 0
+        return len(next(iter(batch_dict.values())))
+
+
+class ConcatCollectFn(CollectFn):
+    """
+    Concatenates the given fields in order, sample by sample, then pads the
+    concatenated sequences into one batch. All fields must share the same
+    dtype and dim.
+
+    :param pad_val: value used for padding
+    :param max_len: maximum length after concatenation; 0 pads to the
+        longest sample in the batch
+    """
+
+    def __init__(self, pad_val=0, max_len=0):
+        super().__init__()
+        self.pad_val = pad_val
+        self.max_len = max_len
+
+    def collect(self, batch_dict, fields):
+        samples = []
+        dtype = _check_type(batch_dict, fields)
+        batch_size = self.get_batch_size(batch_dict)
+        for i in range(batch_size):
+            sample = []
+            for n in fields:
+                seq = batch_dict[n][i]
+                if str(dtype).startswith('torch'):
+                    seq = seq.numpy()
+                else:
+                    seq = np.array(seq, dtype=dtype)
+                sample.append(seq)
+            samples.append(np.concatenate(sample, axis=0))
+        batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
+        if str(dtype).startswith('torch'):
+            batch = torch.tensor(batch, dtype=dtype)
+        return batch
+
+    def num_fields(self):
+        return 0
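For orientation, a minimal sketch of ConcatCollectFn used on its own, outside any DataSet. The field names and values are illustrative, and the expected output assumes _get_ele_type_and_dim resolves both fields to the same integer dtype:

    import numpy as np

    # batch_dict maps field name -> list of per-sample values, the shape
    # DataSetGetter hands to collect_fns before padding
    batch_dict = {
        'premise':    [np.array([1, 2]), np.array([3])],
        'hypothesis': [np.array([4]),    np.array([5, 6, 7])],
    }

    fn = ConcatCollectFn(pad_val=0)
    batch = fn.collect(batch_dict, fields=['premise', 'hypothesis'])
    # batch -> [[1, 2, 4, 0],
    #           [3, 5, 6, 7]]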
@@ -302,6 +302,7 @@ from .field import SetInputOrTargetException
 from .instance import Instance
 from .utils import _get_func_signature
 from .utils import pretty_table_printer
+from .collect_fn import Collector
 
 
 class DataSet(object):

@@ -331,6 +332,7 @@ class DataSet(object):
         else:
             raise ValueError("data only be dict or list type.")
+        self.collector = Collector()
 
     def __contains__(self, item):
         return item in self.field_arrays

@@ -954,3 +956,29 @@ class DataSet(object):
             d = pickle.load(f)
         assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
         return d
+
+    def add_collect_fn(self, fn, name, fields, is_input=True):
+        """
+        Register a CollectFn that builds one batch entry from several fields.
+
+        :param CollectFn fn: defines how the entry is produced
+        :param str name: name of the produced entry in the batch
+        :param list fields: fields the entry is produced from, in order
+        :param bool is_input: if True the entry appears in batch_x,
+            otherwise in batch_y
+        """
+        def check_fields(fields):
+            for f in fields:
+                if f not in self.field_arrays:
+                    raise ValueError("Field {} does not exist in this DataSet.".format(f))
+
+        def check_name(name):
+            if name in self.field_arrays:
+                logger.warning("CollectFn name {} will shadow the field of the "
+                               "same name in the batch.".format(name))
+
+        check_fields(fields)
+        check_name(name)
+
+        self.collector.add_fn(fn, name, fields, is_input)
+
+    def _collect_batch(self, batch_dict):
+        return self.collector.collect_batch(batch_dict)
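A usage sketch of the new hook on a small DataSet; the names (words1, words2, merged) are illustrative, not from the patch:

    from fastNLP import DataSet, Instance, ConcatCollectFn

    ds = DataSet()
    ds.append(Instance(words1=[1, 2], words2=[3, 4, 5]))
    ds.append(Instance(words1=[6], words2=[7, 8]))
    ds.set_input('words1', 'words2')

    # 'merged' shows up in batch_x next to the padded words1/words2 entries
    ds.add_collect_fn(ConcatCollectFn(), name='merged',
                      fields=['words1', 'words2'], is_input=True)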
@@ -7,6 +7,7 @@ from fastNLP import DataSetIter, TorchLoaderIter
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SequentialSampler
+from fastNLP import ConcatCollectFn
 
 
 def generate_fake_dataset(num_samples=1000):

@@ -150,6 +151,21 @@ class TestCase1(unittest.TestCase):
         for batch_x, batch_y in batch:
             pass
 
+    def test_collect_fn(self):
+        batch_size = 32
+        num_samples = 1000
+        dataset = generate_fake_dataset(num_samples)
+        dataset.set_input('1', '2')
+
+        fn = ConcatCollectFn()
+        dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True)
+
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
+        for batch_x, batch_y in batch:
+            for i in range(batch_size):
+                self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
+
     def testTensorLoaderIter(self):
         class FakeData:
             def __init__(self, return_dict=True):
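Beyond the built-in ConcatCollectFn, the base class is easy to extend. A hypothetical subclass (not part of this diff) that emits the combined length of two fields, relying only on the CollectFn/Collector API shown above:

    import numpy as np
    from fastNLP import CollectFn

    class SumLenCollectFn(CollectFn):
        """Hypothetical example: one scalar per sample, len(f1[i]) + len(f2[i])."""

        def collect(self, batch_dict, fields):
            f1, f2 = fields
            batch_size = self.get_batch_size(batch_dict)
            return np.array([len(batch_dict[f1][i]) + len(batch_dict[f2][i])
                             for i in range(batch_size)])

        def num_fields(self):
            # exactly two fields, so Collector.add_fn can validate the call
            return 2

    # e.g. dataset.add_collect_fn(SumLenCollectFn(), 'total_len',
    #                             fields=['1', '2'], is_input=False)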