From 4149eb9c0655fd75b2d2b64786c9a601bc5a53c9 Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 19 Nov 2018 15:12:07 +0800 Subject: [PATCH] add apply to dataset --- fastNLP/core/dataset.py | 52 ++++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 3e92e711..8375cf74 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -22,7 +22,7 @@ class DataSet(object): """ - class DataSetIter(object): + class Instance(object): def __init__(self, dataset, idx=-1): self.dataset = dataset self.idx = idx @@ -43,18 +43,32 @@ class DataSet(object): self.dataset[name][self.idx] = val def __repr__(self): - return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name in self.dataset.get_fields().keys()]) + return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name + in self.dataset.get_fields().keys()]) - def __init__(self, instance=None): + def __init__(self, data=None): self.field_arrays = {} - if instance is not None: - self._convert_ins(instance) + if data is not None: + if isinstance(data, dict): + length_set = set() + for key, value in data.items(): + length_set.add(len(value)) + assert len(length_set)==1, "Arrays must all be same length." + for key, value in data.items(): + self.add_field(name=key, fields=value) + elif isinstance(data, list): + for ins in data: + assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) + self.append(ins) + + else: + raise ValueError("data only be dict or list type.") def __contains__(self, item): return item in self.field_arrays def __iter__(self): - return self.DataSetIter(self) + return self.Instance(self) def _convert_ins(self, ins_list): if isinstance(ins_list, list): @@ -89,7 +103,7 @@ class DataSet(object): def __getitem__(self, name): if isinstance(name, int): - return self.DataSetIter(self, idx=name) + return self.Instance(self, idx=name) elif isinstance(name, str): return self.field_arrays[name] else: @@ -150,6 +164,12 @@ class DataSet(object): else: return object.__getattribute__(self, name) + def __getattr__(self, item): + if item in self.field_arrays: + return self.field_arrays[item] + else: + self.__getattribute__(item) + @classmethod def set_reader(cls, method_name): """decorator to add dataloader support @@ -162,14 +182,18 @@ class DataSet(object): return wrapper + def apply(self, func, new_field_name=None): + results = [] + for ins in self: + results.append(func(ins)) + if new_field_name is not None: + self.add_field(new_field_name, results) + return results if __name__ == '__main__': from fastNLP.core.instance import Instance - ins = Instance(test='test0') - dataset = DataSet([ins]) - for _iter in dataset: - print(_iter['test']) - _iter['test'] = 'abc' - print(_iter['test']) - print(dataset.field_arrays) + d = DataSet({'a': list('abc')}) + d.a + d.apply(lambda x: x['a']) + print(d[1])