diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 702d37a1..2075515e 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -34,21 +34,29 @@ class DataSet(object): def __next__(self): self.idx += 1 - try: - self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()} - except IndexError: + if self.idx >= len(self.dataset): raise StopIteration return self def __getitem__(self, name): - return self.fields[name] + return self.dataset[name][self.idx] def __setitem__(self, name, val): if name not in self.dataset: new_fields = [None] * len(self.dataset) self.dataset.add_field(name, new_fields) self.dataset[name][self.idx] = val - self.fields[name] = val + + def __getattr__(self, item): + if item == 'fields': + self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()} + return self.fields + else: + raise AttributeError('{} does not exist.'.format(item)) + + def __setattr__(self, key, value): + self.__setitem__(key, value) + def __repr__(self): return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name @@ -201,23 +209,19 @@ class DataSet(object): raise KeyError return self - def __getattribute__(self, name): - if name in _READERS: + def __getattr__(self, item): + if item in self.field_arrays: + return self.field_arrays[item] + elif item in _READERS: # add read_*data() support def _read(*args, **kwargs): - data = _READERS[name]().load(*args, **kwargs) + data = _READERS[item]().load(*args, **kwargs) self.extend(data) return self return _read else: - return object.__getattribute__(self, name) - - def __getattr__(self, item): - if item in self.field_arrays: - return self.field_arrays[item] - else: - self.__getattribute__(item) + raise AttributeError('{} does not exist.'.format(item)) @classmethod def set_reader(cls, method_name): diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index 12de4efa..89cf1221 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -12,19 +12,6 @@ class Instance(object): self.fields[field_name] = field return self - def rename_field(self, old_name, new_name): - if old_name in self.fields: - self.fields[new_name] = self.fields.pop(old_name) - else: - raise KeyError("error, no such field: {}".format(old_name)) - return self - - def set_target(self, **fields): - for name, val in fields.items(): - if name in self.fields: - self.fields[name].is_target = val - return self - def __getitem__(self, name): if name in self.fields: return self.fields[name] @@ -34,5 +21,14 @@ class Instance(object): def __setitem__(self, name, field): return self.add_field(name, field) + def __getattr__(self, item): + if item in self.fields: + return self.fields[item] + else: + raise AttributeError('{} does not exist.'.format(item)) + + def __setattr__(self, key, value): + self.__setitem__(key, value) + def __repr__(self): return self.fields.__repr__()