@@ -14,10 +14,12 @@ from numbers import Number | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import torch.utils.data | import torch.utils.data | ||||
from collections import defaultdict | |||||
from ._logger import logger | from ._logger import logger | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .sampler import SequentialSampler | from .sampler import SequentialSampler | ||||
from .field import _get_ele_type_and_dim | |||||
_python_is_exit = False | _python_is_exit = False | ||||
@@ -33,81 +35,75 @@ atexit.register(_set_python_is_exit) | |||||
class DataSetGetter: | class DataSetGetter: | ||||
def __init__(self, dataset: DataSet, as_numpy=False): | def __init__(self, dataset: DataSet, as_numpy=False): | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input} | |||||
self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target} | |||||
self.as_numpy = as_numpy | self.as_numpy = as_numpy | ||||
self.idx_list = list(range(len(dataset))) | self.idx_list = list(range(len(dataset))) | ||||
self.x_names = {n for n, f in dataset.get_all_fields().items() if f.is_input} | |||||
self.y_names = {n for n, f in dataset.get_all_fields().items() if f.is_target} | |||||
def __getitem__(self, idx: int): | def __getitem__(self, idx: int): | ||||
# mapping idx to sampled idx | # mapping idx to sampled idx | ||||
idx = self.idx_list[idx] | idx = self.idx_list[idx] | ||||
inputs = {n:f.get(idx) for n, f in self.inputs.items()} | |||||
targets = {n:f.get(idx) for n, f in self.targets.items()} | |||||
return idx, inputs, targets | |||||
ins = self.dataset[idx] | |||||
return idx, ins | |||||
def __len__(self): | def __len__(self): | ||||
return len(self.dataset) | return len(self.dataset) | ||||
def collate_fn(self, batch: list): | |||||
def collate_fn(self, ins_list: list): | |||||
""" | """ | ||||
:param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] | :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] | ||||
:return: | :return: | ||||
""" | """ | ||||
# TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 | # TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 | ||||
batch_x = {n:[] for n in self.inputs.keys()} | |||||
batch_y = {n:[] for n in self.targets.keys()} | |||||
indices = [] | indices = [] | ||||
for idx, x, y in batch: | |||||
sin_x, sin_y = defaultdict(list), defaultdict(list) | |||||
for idx, ins in ins_list: | |||||
indices.append(idx) | indices.append(idx) | ||||
for n, v in x.items(): | |||||
batch_x[n].append(v) | |||||
for n, v in y.items(): | |||||
batch_y[n].append(v) | |||||
for n, v in ins.items(): | |||||
if n in self.x_names: | |||||
sin_x[n].append(v) | |||||
if n in self.y_names: | |||||
sin_y[n].append(v) | |||||
def may_to_tensor(data): | def may_to_tensor(data): | ||||
dtype, dim = _get_ele_type_and_dim(data) | |||||
print(dtype, type(dtype)) | |||||
if not self.as_numpy: | if not self.as_numpy: | ||||
try: | try: | ||||
data, flag = _to_tensor(data, data.dtype) | |||||
data, flag = _to_tensor(data, dtype) | |||||
except TypeError as e: | except TypeError as e: | ||||
logger.error(f"Field {n} cannot be converted to torch.tensor.") | logger.error(f"Field {n} cannot be converted to torch.tensor.") | ||||
raise e | raise e | ||||
return data | return data | ||||
def pad_collect(batch_dict): | |||||
batch_x, batch_y = self.dataset._collect_batch(batch_dict) | |||||
for b in [batch_x, batch_y]: | |||||
for n in b.keys(): | |||||
b[n] = may_to_tensor(b[n]) | |||||
return batch_x, batch_y | |||||
def pad_batch(batch_dict, field_array): | |||||
def pad(batch_dict): | |||||
result = {} | result = {} | ||||
for n, vlist in batch_dict.items(): | for n, vlist in batch_dict.items(): | ||||
f = field_array[n] | |||||
f = self.dataset.field_arrays[n] | |||||
if f.padder is None: | if f.padder is None: | ||||
result[n] = np.array(vlist) | result[n] = np.array(vlist) | ||||
else: | else: | ||||
data = f.pad(vlist) | |||||
result[n] = may_to_tensor(data) | |||||
result[n] = f.pad(vlist) | |||||
return result | return result | ||||
# do padding on field_array | |||||
pad_batch_x = pad_batch(batch_x, self.inputs) | |||||
pad_batch_y = pad_batch(batch_y, self.targets) | |||||
sin_x = pad(sin_x) | |||||
sin_y = pad(sin_y) | |||||
bx, by = self.dataset._collect_batch(ins_list) | |||||
def convert_tensor(batch_dict): | |||||
for n, v in batch_dict.items(): | |||||
batch_dict[n] = may_to_tensor(v) | |||||
# do padding on dataset collect_fn | |||||
batch_dict = batch_x.copy() | |||||
batch_dict.update(batch_y) | |||||
pad_dict_x, pad_dict_y = pad_collect(batch_dict) | |||||
# collect_fn replaces single field | |||||
sin_x.update(bx) | |||||
sin_y.update(by) | |||||
# group together | |||||
pad_batch_x.update(pad_dict_x) | |||||
pad_batch_y.update(pad_dict_y) | |||||
convert_tensor(sin_x) | |||||
convert_tensor(sin_y) | |||||
return (indices, | |||||
pad_batch_x, | |||||
pad_batch_y) | |||||
return (indices, sin_x, sin_y) | |||||
def set_idx_list(self, idx_list): | def set_idx_list(self, idx_list): | ||||
if len(idx_list) != len(self.idx_list): | if len(idx_list) != len(self.idx_list): | ||||
@@ -297,9 +293,9 @@ def _to_tensor(batch, field_dtype): | |||||
if field_dtype is not None and isinstance(field_dtype, type)\ | if field_dtype is not None and isinstance(field_dtype, type)\ | ||||
and issubclass(field_dtype, Number) \ | and issubclass(field_dtype, Number) \ | ||||
and not isinstance(batch, torch.Tensor): | and not isinstance(batch, torch.Tensor): | ||||
if issubclass(batch.dtype.type, np.floating): | |||||
if issubclass(field_dtype, np.floating): | |||||
new_batch = torch.as_tensor(batch).float() # 默认使用float32 | new_batch = torch.as_tensor(batch).float() # 默认使用float32 | ||||
elif issubclass(batch.dtype.type, np.integer): | |||||
elif issubclass(field_dtype, np.integer): | |||||
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制 | new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制 | ||||
else: | else: | ||||
new_batch = torch.as_tensor(batch) | new_batch = torch.as_tensor(batch) | ||||
@@ -1,6 +1,9 @@ | |||||
from builtins import sorted | |||||
import torch | import torch | ||||
import numpy as np | import numpy as np | ||||
from .field import _get_ele_type_and_dim | from .field import _get_ele_type_and_dim | ||||
from collections import defaultdict | |||||
def _check_type(batch_dict, fields): | def _check_type(batch_dict, fields): | ||||
@@ -33,46 +36,99 @@ def batching(samples, max_len=0, padding_val=0): | |||||
class Collector: | class Collector: | ||||
def __init__(self): | def __init__(self): | ||||
self.fns = [] | |||||
self.names = [] | |||||
self.fields_list = [] | |||||
self.is_input = [] | |||||
def add_fn(self, fn, name, fields, is_input): | |||||
if name in self.names: | |||||
raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn)) | |||||
if fn.num_fields() > 0 and len(fields) != fn.num_fields(): | |||||
self.fns = {} | |||||
self.input2fn = defaultdict(list) | |||||
self.output2fn = defaultdict(list) | |||||
self.fn2input = {} | |||||
self.fn2output = {} | |||||
def add_fn(self, fn, inputs, outputs, is_input, is_target): | |||||
for name in outputs: | |||||
if name in self.output2fn: | |||||
raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn)) | |||||
if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs(): | |||||
raise ValueError( | raise ValueError( | ||||
"Incorrect num of fields, should be {} not {}".format( | |||||
fn.num_fields(), len(fields) | |||||
"Incorrect num of inputs, should be {} not {}".format( | |||||
fn.num_inputs(), len(inputs) | |||||
)) | )) | ||||
self.fns.append(fn) | |||||
self.names.append(name) | |||||
self.fields_list.append(fields) | |||||
self.is_input.append(is_input) | |||||
def collect_batch(self, batch_dict): | |||||
if len(batch_dict) == 0: | |||||
if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs(): | |||||
raise ValueError("Incorrect num of inputs, should be {} not {}".format( | |||||
fn.num_outputs(), len(outputs))) | |||||
self.fns[fn] = {'is_input': is_input, 'is_target': is_target} | |||||
for i, field in enumerate(inputs): | |||||
self.input2fn[field].append((fn, i)) | |||||
for i, name in enumerate(outputs): | |||||
self.output2fn[name].append((fn, i)) | |||||
def _rebuild_fn2io(self): | |||||
def transpose(name2fn): | |||||
fn2names = defaultdict(list) | |||||
for name, vlist in name2fn.items(): | |||||
for fn, i in vlist: | |||||
fn2names[fn].append((name, i)) | |||||
for fn, vlist in fn2names.items(): | |||||
vlist = sorted(vlist, key=lambda x: x[1]) | |||||
fn2names[fn] = [name for name, i in vlist] | |||||
return fn2names | |||||
self.fn2input = transpose(self.input2fn) | |||||
self.fn2output = transpose(self.output2fn) | |||||
def _clear_fn2io(self): | |||||
self.fn2input.clear() | |||||
self.fn2output.clear() | |||||
def collect_batch(self, ins_list): | |||||
if len(ins_list) == 0: | |||||
return {}, {} | return {}, {} | ||||
batch_x, batch_y = {}, {} | |||||
for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input): | |||||
batch = fn.collect(batch_dict, fields) | |||||
if is_input: | |||||
batch_x[name] = batch | |||||
else: | |||||
batch_y[name] = batch | |||||
return batch_x, batch_y | |||||
if len(self.fn2output) == 0: | |||||
self._rebuild_fn2io() | |||||
bx = {} | |||||
by = {} | |||||
for fn, attr in self.fns.items(): | |||||
inputs = self.fn2input.get(fn, None) | |||||
outputs = self.fn2output.get(fn, None) | |||||
res = fn.collect(ins_list, inputs, outputs) | |||||
if attr.get('is_input', False): | |||||
bx.update(res) | |||||
if attr.get('is_target', False): | |||||
by.update(res) | |||||
return bx, by | |||||
def rename_field(self, old_f, new_f): | |||||
if new_f in self.input2fn: | |||||
# name conflict | |||||
raise ValueError | |||||
if old_f not in self.input2fn: | |||||
# renamed field not affect collectors | |||||
return | |||||
self.input2fn[new_f] = self.input2fn[old_f] | |||||
self._clear_fn2io() | |||||
def drop_field(self, f): | |||||
if f in self.input2fn: | |||||
raise ValueError | |||||
def outputs(self): | |||||
return self.output2fn.keys() | |||||
class CollectFn: | class CollectFn: | ||||
def __init__(self): | def __init__(self): | ||||
self.fields = [] | self.fields = [] | ||||
def collect(self, batch_dict, fields): | |||||
def collect(self, ins_list, inputs, outputs): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def num_fields(self): | |||||
def num_inputs(self): | |||||
return 0 | |||||
def num_outputs(self): | |||||
return 0 | return 0 | ||||
@staticmethod | @staticmethod | ||||
@@ -95,24 +151,28 @@ class ConcatCollectFn(CollectFn): | |||||
self.pad_val = pad_val | self.pad_val = pad_val | ||||
self.max_len = max_len | self.max_len = max_len | ||||
def collect(self, batch_dict, fields): | |||||
@staticmethod | |||||
def _to_numpy(seq): | |||||
if torch.is_tensor(seq): | |||||
return seq.numpy() | |||||
else: | |||||
return np.array(seq) | |||||
def collect(self, ins_list, inputs, outputs): | |||||
samples = [] | samples = [] | ||||
dtype = _check_type(batch_dict, fields) | |||||
batch_size = self.get_batch_size(batch_dict) | |||||
for i in range(batch_size): | |||||
for i, ins in ins_list: | |||||
sample = [] | sample = [] | ||||
for n in fields: | |||||
seq = batch_dict[n][i] | |||||
if str(dtype).startswith('torch'): | |||||
seq = seq.numpy() | |||||
else: | |||||
seq = np.array(seq, dtype=dtype) | |||||
sample.append(seq) | |||||
for i in inputs: | |||||
sample.append(self._to_numpy(ins[i])) | |||||
samples.append(np.concatenate(sample, axis=0)) | samples.append(np.concatenate(sample, axis=0)) | ||||
seq_len = [s.shape[0] for s in samples] | |||||
batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val) | batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val) | ||||
if str(dtype).startswith('torch'): | |||||
batch = torch.tensor(batch, dtype=dtype) | |||||
return batch | |||||
o1, o2 = outputs | |||||
return {o1: batch, o2: seq_len} | |||||
def num_fields(self): | |||||
def num_inputs(self): | |||||
return 0 | return 0 | ||||
def num_outputs(self): | |||||
# (concat_words, seq_len) | |||||
return 2 |
@@ -957,28 +957,30 @@ class DataSet(object): | |||||
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | ||||
return d | return d | ||||
def add_collect_fn(self, fn, name, fields, is_input=True): | |||||
def add_collect_fn(self, fn, inputs, outputs, is_input, is_target): | |||||
""" | """ | ||||
添加 CollectFn,使用多个field产生batch中的数据 | 添加 CollectFn,使用多个field产生batch中的数据 | ||||
:param CollectFn fn: 定义产生数据的方式 | :param CollectFn fn: 定义产生数据的方式 | ||||
:param str name: 生成的数据在batch中的名称 | |||||
:param list fields: 用于产生数据的 fields,有序 | |||||
:param list inputs: 生成的数据在batch中的名称 | |||||
:param list outputs: 用于产生数据的 fields,有序 | |||||
:param bool is_input: 是否出现在input中,为否则出现在target batch中 | :param bool is_input: 是否出现在input中,为否则出现在target batch中 | ||||
:param bool is_target: | |||||
""" | """ | ||||
def check_fields(fields): | def check_fields(fields): | ||||
for f in fields: | for f in fields: | ||||
if f not in self.field_arrays: | if f not in self.field_arrays: | ||||
raise ValueError(f) | raise ValueError(f) | ||||
def check_name(name): | |||||
if name in self.field_arrays: | |||||
logger.warning('name of collect_fn will cover the field name in dataset') | |||||
def check_name(names): | |||||
for name in names: | |||||
if name in self.field_arrays: | |||||
logger.warning('name of collect_fn will cover the field name in dataset') | |||||
check_fields(fields) | |||||
check_name(name) | |||||
check_fields(inputs) | |||||
check_name(outputs) | |||||
self.collector.add_fn(fn, name, fields, is_input) | |||||
self.collector.add_fn(fn, inputs, outputs, is_input, is_target) | |||||
def _collect_batch(self, batch_dict): | |||||
return self.collector.collect_batch(batch_dict) | |||||
def _collect_batch(self, ins_list): | |||||
return self.collector.collect_batch(ins_list) |
@@ -26,7 +26,7 @@ def generate_fake_dataset(num_samples=1000): | |||||
data = [] | data = [] | ||||
lengths = np.random.randint(min_len, max_len, size=(num_samples)) | lengths = np.random.randint(min_len, max_len, size=(num_samples)) | ||||
for length in lengths: | for length in lengths: | ||||
data.append(np.random.randint(100, size=length)) | |||||
data.append(np.random.randint(1, 100, size=length)) | |||||
data_dict[str(i)] = data | data_dict[str(i)] = data | ||||
dataset = DataSet(data_dict) | dataset = DataSet(data_dict) | ||||
@@ -156,14 +156,21 @@ class TestCase1(unittest.TestCase): | |||||
num_samples = 1000 | num_samples = 1000 | ||||
dataset = generate_fake_dataset(num_samples) | dataset = generate_fake_dataset(num_samples) | ||||
dataset.set_input('1','2') | dataset.set_input('1','2') | ||||
dataset.set_target('0','3') | |||||
fn = ConcatCollectFn() | fn = ConcatCollectFn() | ||||
dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True) | |||||
dataset.add_collect_fn(fn, inputs=['1', '2'], | |||||
outputs=['12', 'seq_len'], | |||||
is_input=True, is_target=False) | |||||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True) | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
for i in range(batch_size): | for i in range(batch_size): | ||||
# print(i) | # print(i) | ||||
self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) | self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum()) | ||||
self.assertEqual( | |||||
batch_x['seq_len'][i], | |||||
(batch_x['1'][i]!=0).sum() + (batch_x['2'][i]!=0).sum()) | |||||
def testTensorLoaderIter(self): | def testTensorLoaderIter(self): | ||||