Browse Source

fix Dataset

tags/v0.2.0
yunfan 6 years ago
parent
commit
92da53a65b
1 changed files with 14 additions and 11 deletions
  1. +14
    -11
      fastNLP/core/dataset.py

+ 14
- 11
fastNLP/core/dataset.py View File

@@ -54,12 +54,6 @@ class DataSet(object):
else:
raise AttributeError('{} does not exist.'.format(item))

def __setattr__(self, key, value):
if hasattr(self, 'fields'):
self.__setitem__(key, value)
else:
super().__setattr__(self, key, value)

def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
in self.dataset.get_fields().keys()])
@@ -205,17 +199,23 @@ class DataSet(object):
return [name for name, field in self.field_arrays.items() if field.is_target]

def __getattr__(self, item):
if item in self.field_arrays:
return self.field_arrays[item]
elif item in _READERS:
# block infinite recursion for copy, pickle
if item == '__setstate__':
raise AttributeError(item)
try:
return self.field_arrays.__getitem__(item)
except KeyError:
pass
try:
reader_cls = _READERS[item]
# add read_*data() support
def _read(*args, **kwargs):
data = _READERS[item]().load(*args, **kwargs)
data = reader_cls().load(*args, **kwargs)
self.extend(data)
return self

return _read
else:
except KeyError:
raise AttributeError('{} does not exist.'.format(item))

@classmethod
@@ -269,3 +269,6 @@ if __name__ == '__main__':
_ = d.a
d.apply(lambda x: x['a'])
print(d[1])
import copy
dd = copy.deepcopy(d)
print(dd.a)

Loading…
Cancel
Save