Browse Source

fix Dataset

tags/v0.2.0
yunfan 6 years ago
parent
commit
92da53a65b
1 changed files with 14 additions and 11 deletions
  1. +14
    -11
      fastNLP/core/dataset.py

+ 14
- 11
fastNLP/core/dataset.py View File

@@ -54,12 +54,6 @@ class DataSet(object):
else: else:
raise AttributeError('{} does not exist.'.format(item)) raise AttributeError('{} does not exist.'.format(item))


def __setattr__(self, key, value):
if hasattr(self, 'fields'):
self.__setitem__(key, value)
else:
super().__setattr__(self, key, value)

def __repr__(self): def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
in self.dataset.get_fields().keys()]) in self.dataset.get_fields().keys()])
@@ -205,17 +199,23 @@ class DataSet(object):
return [name for name, field in self.field_arrays.items() if field.is_target] return [name for name, field in self.field_arrays.items() if field.is_target]


def __getattr__(self, item): def __getattr__(self, item):
if item in self.field_arrays:
return self.field_arrays[item]
elif item in _READERS:
# block infinite recursion for copy, pickle
if item == '__setstate__':
raise AttributeError(item)
try:
return self.field_arrays.__getitem__(item)
except KeyError:
pass
try:
reader_cls = _READERS[item]
# add read_*data() support # add read_*data() support
def _read(*args, **kwargs): def _read(*args, **kwargs):
data = _READERS[item]().load(*args, **kwargs)
data = reader_cls().load(*args, **kwargs)
self.extend(data) self.extend(data)
return self return self


return _read return _read
else:
except KeyError:
raise AttributeError('{} does not exist.'.format(item)) raise AttributeError('{} does not exist.'.format(item))


@classmethod @classmethod
@@ -269,3 +269,6 @@ if __name__ == '__main__':
_ = d.a _ = d.a
d.apply(lambda x: x['a']) d.apply(lambda x: x['a'])
print(d[1]) print(d[1])
import copy
dd = copy.deepcopy(d)
print(dd.a)

Loading…
Cancel
Save