@@ -14,10 +14,12 @@ from numbers import Number
import numpy as np
import torch
import torch.utils.data
from collections import defaultdict
from ._logger import logger
from .dataset import DataSet
from .sampler import SequentialSampler
from .field import _get_ele_type_and_dim
_python_is_exit = False
@@ -33,81 +35,75 @@ atexit.register(_set_python_is_exit)
class DataSetGetter:
def __init__(self, dataset: DataSet, as_numpy=False):
self.dataset = dataset
self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
self.as_numpy = as_numpy
self.idx_list = list(range(len(dataset)))
self.x_names = {n for n, f in dataset.get_all_fields().items() if f.is_input}
self.y_names = {n for n, f in dataset.get_all_fields().items() if f.is_target}
def __getitem__(self, idx: int):
# mapping idx to sampled idx
idx = self.idx_list[idx]
inputs = {n:f.get(idx) for n, f in self.inputs.items()}
targets = {n:f.get(idx) for n, f in self.targets.items()}
return idx, inputs, targets
ins = self.dataset[idx]
return idx, ins
def __len__(self):
return len(self.dataset)
def collate_fn(self, batch: list):
def collate_fn(self, ins_list: list):
"""
:param ins_list: [(idx1, instance1), (idx2, instance2), ...]
:return:
"""
# TODO support defining a collate_fn on the DataSet itself, since some scenarios (e.g. BERT) need to fuse several fields
batch_x = {n:[] for n in self.inputs.keys()}
batch_y = {n:[] for n in self.targets.keys()}
indices = []
for idx, x, y in batch:
sin_x, sin_y = defaultdict(list), defaultdict(list)
for idx, ins in ins_list:
indices.append(idx)
for n, v in x.items():
batch_x[n].append(v)
for n, v in y.items():
batch_y[n].append(v)
for n, v in ins.items():
if n in self.x_names:
sin_x[n].append(v)
if n in self.y_names:
sin_y[n].append(v)
def may_to_tensor(data):
dtype, dim = _get_ele_type_and_dim(data)
if not self.as_numpy:
try:
data, flag = _to_tensor(data, data.dtype)
data, flag = _to_tensor(data, dtype)
except TypeError as e:
logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e
return data
def pad_collect(batch_dict):
batch_x, batch_y = self.dataset._collect_batch(batch_dict)
for b in [batch_x, batch_y]:
for n in b.keys():
b[n] = may_to_tensor(b[n])
return batch_x, batch_y
def pad_batch(batch_dict, field_array):
def pad(batch_dict):
result = {}
for n, vlist in batch_dict.items():
f = field_array[n]
f = self.dataset.field_arrays[n]
if f.padder is None:
result[n] = np.array(vlist)
else:
data = f.pad(vlist)
result[n] = may_to_tensor(data)
result[n] = f.pad(vlist)
return result
# do padding on field_array
pad_batch_x = pad_batch(batch_x, self.inputs)
pad_batch_y = pad_batch(batch_y, self.targets)
sin_x = pad(sin_x)
sin_y = pad(sin_y)
bx, by = self.dataset._collect_batch(ins_list)
def convert_tensor(batch_dict):
for n, v in batch_dict.items():
batch_dict[n] = may_to_tensor(v)
# do padding on dataset collect_fn
batch_dict = batch_x.copy()
batch_dict.update(batch_y)
pad_dict_x, pad_dict_y = pad_collect(batch_dict)
# collect_fn outputs override the corresponding single-field values
sin_x.update(bx)
sin_y.update(by)
# group together
pad_batch_x.update(pad_dict_x)
pad_batch_y.update(pad_dict_y)
convert_tensor(sin_x)
convert_tensor(sin_y)
return (indices,
pad_batch_x,
pad_batch_y)
return (indices, sin_x, sin_y)
def set_idx_list(self, idx_list):
if len(idx_list) != len(self.idx_list):
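To make the new `(indices, sin_x, sin_y)` return shape concrete, here is a minimal sketch of plugging `DataSetGetter` into a torch `DataLoader` (the toy dataset and field names are made up, and `DataSet`/`DataSetGetter` are assumed importable from this module):

```python
from torch.utils.data import DataLoader

# Hypothetical two-sample dataset; 'words' is an input field, 'label' a target field.
ds = DataSet({'words': [[1, 2, 3], [4, 5]], 'label': [0, 1]})
ds.set_input('words')
ds.set_target('label')

getter = DataSetGetter(ds, as_numpy=False)
loader = DataLoader(getter, batch_size=2, collate_fn=getter.collate_fn)

for indices, batch_x, batch_y in loader:
    # batch_x holds the padded input fields, batch_y the target fields,
    # each converted to torch tensors by may_to_tensor/convert_tensor above.
    print(indices, batch_x['words'].shape, batch_y['label'])
```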
@@ -297,9 +293,9 @@ def _to_tensor(batch, field_dtype):
if field_dtype is not None and isinstance(field_dtype, type)\
and issubclass(field_dtype, Number) \
and not isinstance(batch, torch.Tensor):
if issubclass(batch.dtype.type, np.floating):
if issubclass(field_dtype, np.floating):
new_batch = torch.as_tensor(batch).float()  # use float32 by default
elif issubclass(batch.dtype.type, np.integer):
elif issubclass(field_dtype, np.integer):
new_batch = torch.as_tensor(batch).long()  # as_tensor reuses the underlying memory, avoiding a copy
else:
new_batch = torch.as_tensor(batch)
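The effect of switching from `batch.dtype.type` to the `field_dtype` argument is that the dtype decision no longer requires `batch` to be an ndarray, so plain lists (e.g. the `seq_len` output of a collect_fn) can be coerced too. A small sketch, assuming `_to_tensor` is in scope and returns a `(tensor, flag)` pair as the callers above unpack it:

```python
import numpy as np
import torch

int_batch = np.array([[1, 2, 0], [3, 4, 5]])
float_batch = [[0.1, 0.2], [0.3, 0.4]]  # a plain nested list, no .dtype attribute

# np.int64 is a subclass of np.integer -> the .long() branch
t_int, _ = _to_tensor(int_batch, np.int64)
assert t_int.dtype == torch.int64

# np.float64 is a subclass of np.floating -> the .float() branch (float32 by default)
t_float, _ = _to_tensor(float_batch, np.float64)
assert t_float.dtype == torch.float32
```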
@@ -1,6 +1,9 @@
from builtins import sorted
import torch
import numpy as np
from .field import _get_ele_type_and_dim
from collections import defaultdict
def _check_type(batch_dict, fields):
@@ -33,46 +36,99 @@ def batching(samples, max_len=0, padding_val=0):
class Collector:
def __init__(self):
self.fns = []
self.names = []
self.fields_list = []
self.is_input = []
def add_fn(self, fn, name, fields, is_input):
if name in self.names:
raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
if fn.num_fields() > 0 and len(fields) != fn.num_fields():
self.fns = {}
self.input2fn = defaultdict(list)
self.output2fn = defaultdict(list)
self.fn2input = {}
self.fn2output = {}
def add_fn(self, fn, inputs, outputs, is_input, is_target):
for name in outputs:
if name in self.output2fn:
raise ValueError("Duplicated name: {} for CollectFn: {}".format(name, fn))
if fn.num_inputs() > 0 and len(inputs) != fn.num_inputs():
raise ValueError(
"Incorrect num of fields, should be {} not {}".format(
fn.num_fields(), len(fields)
"Incorrect num of inputs, should be {} not {}".format(
fn.num_inputs(), len(inputs)
))
self.fns.append(fn)
self.names.append(name)
self.fields_list.append(fields)
self.is_input.append(is_input)
def collect_batch(self, batch_dict):
if len(batch_dict) == 0:
if fn.num_outputs() > 0 and len(outputs) != fn.num_outputs():
raise ValueError("Incorrect num of outputs, should be {} not {}".format(
fn.num_outputs(), len(outputs)))
self.fns[fn] = {'is_input': is_input, 'is_target': is_target}
for i, field in enumerate(inputs):
self.input2fn[field].append((fn, i))
for i, name in enumerate(outputs):
self.output2fn[name].append((fn, i))
def _rebuild_fn2io(self):
def transpose(name2fn):
fn2names = defaultdict(list)
for name, vlist in name2fn.items():
for fn, i in vlist:
fn2names[fn].append((name, i))
for fn, vlist in fn2names.items():
vlist = sorted(vlist, key=lambda x: x[1])
fn2names[fn] = [name for name, i in vlist]
return fn2names
self.fn2input = transpose(self.input2fn)
self.fn2output = transpose(self.output2fn)
def _clear_fn2io(self):
self.fn2input.clear()
self.fn2output.clear()
def collect_batch(self, ins_list):
if len(ins_list) == 0:
return {}, {}
batch_x, batch_y = {}, {}
for fn, name, fields, is_input in zip(self.fns, self.names, self.fields_list, self.is_input):
batch = fn.collect(batch_dict, fields)
if is_input:
batch_x[name] = batch
else:
batch_y[name] = batch
return batch_x, batch_y
if len(self.fn2output) == 0:
self._rebuild_fn2io()
bx = {}
by = {}
for fn, attr in self.fns.items():
inputs = self.fn2input.get(fn, None)
outputs = self.fn2output.get(fn, None)
res = fn.collect(ins_list, inputs, outputs)
if attr.get('is_input', False):
bx.update(res)
if attr.get('is_target', False):
by.update(res)
return bx, by
def rename_field(self, old_f, new_f):
if new_f in self.input2fn:
# name conflict
raise ValueError("Field name {} is already used by a collect_fn".format(new_f))
if old_f not in self.input2fn:
# the renamed field does not affect any collector
return
self.input2fn[new_f] = self.input2fn.pop(old_f)  # move the mapping to the new name
self._clear_fn2io()
def drop_field(self, f):
if f in self.input2fn:
raise ValueError("Field {} is still used by a collect_fn and cannot be dropped".format(f))
def outputs(self):
return self.output2fn.keys()
class CollectFn:
def __init__(self):
self.fields = []
def collect(self, batch_dict, fields):
def collect(self, ins_list, inputs, outputs):
raise NotImplementedError
def num_fields(self):
def num_inputs(self):
return 0
def num_outputs(self):
return 0
@staticmethod
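To illustrate the new interface end to end, here is a hypothetical `CollectFn` that sums two fields into one output and is registered on a `Collector` (all names and values are made up; plain dicts stand in for Instance objects):

```python
class SumCollectFn(CollectFn):
    """Toy example: one output that is the element-wise sum of two input fields."""
    def collect(self, ins_list, inputs, outputs):
        a_name, b_name = inputs
        out_name = outputs[0]
        values = [ins[a_name] + ins[b_name] for _, ins in ins_list]
        return {out_name: np.array(values)}

    def num_inputs(self):
        return 2

    def num_outputs(self):
        return 1


collector = Collector()
collector.add_fn(SumCollectFn(), inputs=['a', 'b'], outputs=['a_plus_b'],
                 is_input=True, is_target=False)

# ins_list mirrors what DataSetGetter.collate_fn receives: (index, instance) pairs.
bx, by = collector.collect_batch([(0, {'a': 1, 'b': 2}), (1, {'a': 3, 'b': 4})])
# bx == {'a_plus_b': array([3, 7])}, by == {}
```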
@@ -95,24 +151,28 @@ class ConcatCollectFn(CollectFn):
self.pad_val = pad_val
self.max_len = max_len
def collect(self, batch_dict, fields):
@staticmethod
def _to_numpy(seq):
if torch.is_tensor(seq):
return seq.numpy()
else:
return np.array(seq)
def collect(self, ins_list, inputs, outputs):
samples = []
dtype = _check_type(batch_dict, fields)
batch_size = self.get_batch_size(batch_dict)
for i in range(batch_size):
for i, ins in ins_list:
sample = []
for n in fields:
seq = batch_dict[n][i]
if str(dtype).startswith('torch'):
seq = seq.numpy()
else:
seq = np.array(seq, dtype=dtype)
sample.append(seq)
for n in inputs:  # distinct name avoids shadowing the outer loop index
sample.append(self._to_numpy(ins[n]))
samples.append(np.concatenate(sample, axis=0))
seq_len = [s.shape[0] for s in samples]
batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
if str(dtype).startswith('torch'):
batch = torch.tensor(batch, dtype=dtype)
return batch
o1, o2 = outputs
return {o1: batch, o2: seq_len}
def num_fields(self):
def num_inputs(self):
return 0
def num_outputs(self):
# (concat_words, seq_len)
return 2
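A direct call shows the two produced entries; the field names and values below are invented, and the default pad value is assumed to be 0:

```python
fn = ConcatCollectFn()
ins_list = [
    (0, {'question': [1, 2], 'answer': [3]}),
    (1, {'question': [4], 'answer': [5, 6, 7]}),
]
res = fn.collect(ins_list, inputs=['question', 'answer'],
                 outputs=['concat_words', 'seq_len'])
# res['concat_words'] is the per-instance concatenation, padded to the longest one:
#   [[1, 2, 3, 0], [4, 5, 6, 7]]
# res['seq_len'] holds the unpadded lengths: [3, 4]
```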
@@ -957,28 +957,30 @@ class DataSet(object):
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
return d
def add_collect_fn(self, fn, name, fields, is_input=True):
def add_collect_fn(self, fn, inputs, outputs, is_input, is_target):
"""
Add a CollectFn that uses several fields to produce extra data in the batch.
:param CollectFn fn: defines how the data is produced
:param str name: name of the produced data in the batch
:param list fields: ordered list of fields used to produce the data
:param list inputs: ordered list of field names used to produce the data
:param list outputs: names of the produced data in the batch
:param bool is_input: whether the produced data appears in the input batch
:param bool is_target: whether the produced data appears in the target batch
"""
def check_fields(fields):
for f in fields:
if f not in self.field_arrays:
raise ValueError(f)
def check_name(name):
if name in self.field_arrays:
logger.warning('the name of a collect_fn output will override the field of the same name in the dataset')
def check_name(names):
for name in names:
if name in self.field_arrays:
logger.warning('the name of a collect_fn output will override the field of the same name in the dataset')
check_fields(fields)
check_name(name)
check_fields(inputs)
check_name(outputs)
self.collector.add_fn(fn, name, fields, is_input)
self.collector.add_fn(fn, inputs, outputs, is_input, is_target)
def _collect_batch(self, batch_dict):
return self.collector.collect_batch(batch_dict)
def _collect_batch(self, ins_list):
return self.collector.collect_batch(ins_list)
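The two flags route the produced data independently; a quick hypothetical sketch of sending the output to the target batch instead of the input batch (field names are made up):

```python
ds = DataSet({'words1': [[1, 2], [3]], 'words2': [[4], [5, 6]]})
ds.set_input('words1', 'words2')

ds.add_collect_fn(ConcatCollectFn(),
                  inputs=['words1', 'words2'],
                  outputs=['words', 'seq_len'],
                  is_input=False, is_target=True)
# When iterating the dataset, 'words' and 'seq_len' now appear in batch_y,
# while 'words1' and 'words2' remain ordinary input fields in batch_x.
```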
@@ -26,7 +26,7 @@ def generate_fake_dataset(num_samples=1000):
data = []
lengths = np.random.randint(min_len, max_len, size=(num_samples))
for length in lengths:
data.append(np.random.randint(100, size=length))
data.append(np.random.randint(1, 100, size=length))
data_dict[str(i)] = data
dataset = DataSet(data_dict)
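# values are drawn from [1, 100) so no element equals the pad value 0; the test below can then recover seq_len by counting non-zero entries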
@@ -156,14 +156,21 @@ class TestCase1(unittest.TestCase):
num_samples = 1000
dataset = generate_fake_dataset(num_samples)
dataset.set_input('1','2')
dataset.set_target('0','3')
fn = ConcatCollectFn()
dataset.add_collect_fn(fn, '12', fields=['1', '2'], is_input=True)
dataset.add_collect_fn(fn, inputs=['1', '2'],
outputs=['12', 'seq_len'],
is_input=True, is_target=False)
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler(), drop_last=True)
for batch_x, batch_y in batch:
for i in range(batch_size):
self.assertEqual(batch_x['12'][i].sum(), batch_x['1'][i].sum() + batch_x['2'][i].sum())
self.assertEqual(
batch_x['seq_len'][i],
(batch_x['1'][i]!=0).sum() + (batch_x['2'][i]!=0).sum())
def testTensorLoaderIter(self):