From dd0bb0d7913dd93e064817356caf585c7513c5f3 Mon Sep 17 00:00:00 2001 From: yunfan Date: Fri, 9 Nov 2018 22:02:34 +0800 Subject: [PATCH] add data iter --- fastNLP/core/dataset.py | 57 ++++++++++++++++++++++++++++++++------ fastNLP/core/fieldarray.py | 14 +++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index c6f0de35..131ba28d 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -1,5 +1,8 @@ import random -import sys +import sys, os +sys.path.append('../..') +sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path + from collections import defaultdict from copy import deepcopy import numpy as np @@ -15,12 +18,35 @@ class DataSet(object): """A DataSet object is a list of Instance objects. """ + class DataSetIter(object): + def __init__(self, dataset): + self.dataset = dataset + self.idx = -1 + + def __next__(self): + self.idx += 1 + if self.idx >= len(self.dataset): + raise StopIteration + return self + + def __getitem__(self, name): + return self.dataset[name][self.idx] + + def __setitem__(self, name, val): + # TODO check new field. + self.dataset[name][self.idx] = val + + def __repr__(self): + # TODO + pass def __init__(self, instance=None): + self.field_arrays = {} if instance is not None: self._convert_ins(instance) - else: - self.field_arrays = {} + + def __iter__(self): + return self.DataSetIter(self) def _convert_ins(self, ins_list): if isinstance(ins_list, list): @@ -32,23 +58,27 @@ class DataSet(object): def append(self, ins): # no field if len(self.field_arrays) == 0: - for name, field in ins.field.items(): + for name, field in ins.fields.items(): self.field_arrays[name] = FieldArray(name, [field]) else: - assert len(self.field_arrays) == len(ins.field) - for name, field in ins.field.items(): + assert len(self.field_arrays) == len(ins.fields) + for name, field in ins.fields.items(): assert name in self.field_arrays self.field_arrays[name].append(field) def add_field(self, name, fields): assert len(self) == len(fields) - self.field_arrays[name] = fields + self.field_arrays[name] = FieldArray(name, fields) def get_fields(self): return self.field_arrays + def __getitem__(self, name): + assert name in self.field_arrays + return self.field_arrays[name] + def __len__(self): - field = self.field_arrays.values()[0] + field = iter(self.field_arrays.values()).__next__() return len(field) def get_length(self): @@ -125,3 +155,14 @@ class DataSet(object): _READERS[method_name] = read_cls return read_cls return wrapper + + +if __name__ == '__main__': + from fastNLP.core.instance import Instance + ins = Instance(test='test0') + dataset = DataSet([ins]) + for _iter in dataset: + print(_iter['test']) + _iter['test'] = 'abc' + print(_iter['test']) + print(dataset.field_arrays) \ No newline at end of file diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 9d0f8e9e..a08e7f12 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -2,19 +2,31 @@ import torch import numpy as np class FieldArray(object): - def __init__(self, name, content, padding_val=0, is_target=True, need_tensor=True): + def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False): self.name = name self.content = content self.padding_val = padding_val self.is_target = is_target self.need_tensor = need_tensor + def __repr__(self): + #TODO + return '{}: {}'.format(self.name, self.content.__repr__()) + def append(self, val): self.content.append(val) + def __getitem__(self, name): + return self.get(name) + + def __setitem__(self, name, val): + assert isinstance(name, int) + self.content[name] = val + def get(self, idxes): if isinstance(idxes, int): return self.content[idxes] + assert self.need_tensor is True batch_size = len(idxes) max_len = max([len(self.content[i]) for i in idxes]) array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)