diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 668bb93e..5e72106f 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -54,12 +54,6 @@ class DataSet(object): else: raise AttributeError('{} does not exist.'.format(item)) - def __setattr__(self, key, value): - if hasattr(self, 'fields'): - self.__setitem__(key, value) - else: - super().__setattr__(self, key, value) - def __repr__(self): return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name in self.dataset.get_fields().keys()]) @@ -205,17 +199,23 @@ class DataSet(object): return [name for name, field in self.field_arrays.items() if field.is_target] def __getattr__(self, item): - if item in self.field_arrays: - return self.field_arrays[item] - elif item in _READERS: + # block infinite recursion for copy, pickle + if item == '__setstate__': + raise AttributeError(item) + try: + return self.field_arrays.__getitem__(item) + except KeyError: + pass + try: + reader_cls = _READERS[item] # add read_*data() support def _read(*args, **kwargs): - data = _READERS[item]().load(*args, **kwargs) + data = reader_cls().load(*args, **kwargs) self.extend(data) return self return _read - else: + except KeyError: raise AttributeError('{} does not exist.'.format(item)) @classmethod @@ -269,3 +269,6 @@ if __name__ == '__main__': _ = d.a d.apply(lambda x: x['a']) print(d[1]) + import copy + dd = copy.deepcopy(d) + print(dd.a)