diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index e626ff26..131ba28d 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -1,5 +1,8 @@ import random -import sys +import sys, os +sys.path.append('../..') +sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path + from collections import defaultdict from copy import deepcopy import numpy as np @@ -15,36 +18,67 @@ class DataSet(object): """A DataSet object is a list of Instance objects. """ + class DataSetIter(object): + def __init__(self, dataset): + self.dataset = dataset + self.idx = -1 + + def __next__(self): + self.idx += 1 + if self.idx >= len(self.dataset): + raise StopIteration + return self + + def __getitem__(self, name): + return self.dataset[name][self.idx] + + def __setitem__(self, name, val): + # TODO check new field. + self.dataset[name][self.idx] = val + + def __repr__(self): + # TODO + pass def __init__(self, instance=None): + self.field_arrays = {} if instance is not None: self._convert_ins(instance) - else: - self.field_arrays = {} + + def __iter__(self): + return self.DataSetIter(self) def _convert_ins(self, ins_list): if isinstance(ins_list, list): for ins in ins_list: self.append(ins) else: - self.append(ins) + self.append(ins_list) def append(self, ins): # no field if len(self.field_arrays) == 0: - for name, field in ins.field.items(): + for name, field in ins.fields.items(): self.field_arrays[name] = FieldArray(name, [field]) else: - assert len(self.field_arrays) == len(ins.field) - for name, field in ins.field.items(): + assert len(self.field_arrays) == len(ins.fields) + for name, field in ins.fields.items(): assert name in self.field_arrays self.field_arrays[name].append(field) + def add_field(self, name, fields): + assert len(self) == len(fields) + self.field_arrays[name] = FieldArray(name, fields) + def get_fields(self): return self.field_arrays + def __getitem__(self, name): + assert name in self.field_arrays + return 
self.field_arrays[name] + def __len__(self): - field = self.field_arrays.values()[0] + field = iter(self.field_arrays.values()).__next__() return len(field) def get_length(self): @@ -121,3 +155,14 @@ class DataSet(object): _READERS[method_name] = read_cls return read_cls return wrapper + + +if __name__ == '__main__': + from fastNLP.core.instance import Instance + ins = Instance(test='test0') + dataset = DataSet([ins]) + for _iter in dataset: + print(_iter['test']) + _iter['test'] = 'abc' + print(_iter['test']) + print(dataset.field_arrays) \ No newline at end of file diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 5b9c1b63..cf34abf8 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -39,35 +39,6 @@ class TextField(Field): super(TextField, self).__init__(text, is_target) -class IndexField(Field): - def __init__(self, name, content, vocab, is_target): - super(IndexField, self).__init__(name, is_target) - self.content = [] - self.padding_idx = vocab.padding_idx - for sent in content: - idx = vocab.index_sent(sent) - if isinstance(idx, list): - idx = torch.Tensor(idx) - elif isinstance(idx, np.array): - idx = torch.from_numpy(idx) - elif not isinstance(idx, torch.Tensor): - raise ValueError - self.content.append(idx) - - def to_tensor(self, id_list, sort_within_batch=False): - max_len = max(id_list) - batch_size = len(id_list) - tensor = torch.full((batch_size, max_len), self.padding_idx, dtype=torch.long) - len_list = [(i, self.content[i].size(0)) for i in id_list] - if sort_within_batch: - len_list = sorted(len_list, key=lambda x: x[1], reverse=True) - for i, (idx, length) in enumerate(len_list): - if length == max_len: - tensor[i] = self.content[idx] - else: - tensor[i][:length] = self.content[idx] - return tensor - class LabelField(Field): """The Field representing a single label. Can be a string or integer. 
import numpy as np


class FieldArray(object):
    """Hold the values of one named field, one entry per appended item.

    Single entries are read and written by integer position; a collection
    of positions can be fetched as one padded 2-D ``numpy`` array.
    """

    def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
        """Create the array.

        :param name: field name, used by ``__repr__``.
        :param content: iterable with the initial entries.
        :param padding_val: fill value for the padded batch built by ``get``.
        :param is_target: stored as-is; not read by this class.
        :param need_tensor: must be ``True`` for ``get`` to build a batch.
        """
        self.name = name
        # Keep a private list: ``append`` must not grow the caller's object,
        # and any iterable (tuple, generator, ...) becomes appendable.
        self.content = list(content)
        self.padding_val = padding_val
        self.is_target = is_target
        self.need_tensor = need_tensor

    def __repr__(self):
        """Return ``'<name>: <repr of the stored entries>'``."""
        return '{}: {!r}'.format(self.name, self.content)

    def append(self, val):
        """Add ``val`` as the last entry."""
        self.content.append(val)

    def __getitem__(self, name):
        """Delegate to ``get``; ``name`` is a position or a collection of positions."""
        return self.get(name)

    def __setitem__(self, name, val):
        """Overwrite the entry at integer position ``name`` with ``val``.

        :raises AssertionError: if ``name`` is not an integer.
        """
        # NumPy integers (e.g. produced by a shuffled index array) are valid
        # list positions too, so accept them next to the builtin ``int``.
        assert isinstance(name, (int, np.integer)), \
            'index must be an integer, got {}'.format(type(name).__name__)
        self.content[name] = val

    def get(self, idxes):
        """Fetch one entry, or a padded batch of entries.

        :param idxes: an integer position (builtin or NumPy), or a sized
            iterable of positions.
        :return: the stored entry for an integer position; otherwise an
            ``int32`` array of shape ``(len(idxes), longest entry)`` in which
            shorter entries are right-padded with ``padding_val``. An empty
            ``idxes`` yields an array of shape ``(0, 0)``.
        :raises AssertionError: if a batch is requested while ``need_tensor``
            is not ``True``.
        """
        if isinstance(idxes, (int, np.integer)):
            return self.content[idxes]
        assert self.need_tensor is True, \
            'field {!r} was not created with need_tensor=True'.format(self.name)
        batch_size = len(idxes)
        if batch_size == 0:
            # ``max`` of an empty sequence would raise; an empty batch is a
            # well-defined result.
            return np.full((0, 0), self.padding_val, dtype=np.int32)
        rows = [self.content[i] for i in idxes]
        max_len = max(len(row) for row in rows)
        array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
        for i, row in enumerate(rows):
            # Left-align each entry; the remainder keeps the padding value.
            array[i][:len(row)] = row
        return array

    def __len__(self):
        """Return the number of stored entries."""
        return len(self.content)