Browse Source

update Instance

tags/v0.2.0
yunfan 6 years ago
parent
commit
713510f65b
2 changed files with 28 additions and 28 deletions
  1. +19
    -15
      fastNLP/core/dataset.py
  2. +9
    -13
      fastNLP/core/instance.py

+ 19
- 15
fastNLP/core/dataset.py View File

@@ -34,21 +34,29 @@ class DataSet(object):


def __next__(self): def __next__(self):
self.idx += 1 self.idx += 1
try:
self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()}
except IndexError:
if self.idx >= len(self.dataset):
raise StopIteration raise StopIteration
return self return self


def __getitem__(self, name): def __getitem__(self, name):
return self.fields[name]
return self.dataset[name][self.idx]


def __setitem__(self, name, val): def __setitem__(self, name, val):
if name not in self.dataset: if name not in self.dataset:
new_fields = [None] * len(self.dataset) new_fields = [None] * len(self.dataset)
self.dataset.add_field(name, new_fields) self.dataset.add_field(name, new_fields)
self.dataset[name][self.idx] = val self.dataset[name][self.idx] = val
self.fields[name] = val

def __getattr__(self, item):
if item == 'fields':
self.fields = {name: field[self.idx] for name, field in self.dataset.get_fields().items()}
return self.fields
else:
raise AttributeError('{} does not exist.'.format(item))

def __setattr__(self, key, value):
self.__setitem__(key, value)



def __repr__(self): def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
@@ -201,23 +209,19 @@ class DataSet(object):
raise KeyError raise KeyError
return self return self


def __getattribute__(self, name):
if name in _READERS:
def __getattr__(self, item):
if item in self.field_arrays:
return self.field_arrays[item]
elif item in _READERS:
# add read_*data() support # add read_*data() support
def _read(*args, **kwargs): def _read(*args, **kwargs):
data = _READERS[name]().load(*args, **kwargs)
data = _READERS[item]().load(*args, **kwargs)
self.extend(data) self.extend(data)
return self return self


return _read return _read
else: else:
return object.__getattribute__(self, name)

def __getattr__(self, item):
if item in self.field_arrays:
return self.field_arrays[item]
else:
self.__getattribute__(item)
raise AttributeError('{} does not exist.'.format(item))


@classmethod @classmethod
def set_reader(cls, method_name): def set_reader(cls, method_name):


+ 9
- 13
fastNLP/core/instance.py View File

@@ -12,19 +12,6 @@ class Instance(object):
self.fields[field_name] = field self.fields[field_name] = field
return self return self


def rename_field(self, old_name, new_name):
if old_name in self.fields:
self.fields[new_name] = self.fields.pop(old_name)
else:
raise KeyError("error, no such field: {}".format(old_name))
return self

def set_target(self, **fields):
for name, val in fields.items():
if name in self.fields:
self.fields[name].is_target = val
return self

def __getitem__(self, name): def __getitem__(self, name):
if name in self.fields: if name in self.fields:
return self.fields[name] return self.fields[name]
@@ -34,5 +21,14 @@ class Instance(object):
def __setitem__(self, name, field): def __setitem__(self, name, field):
return self.add_field(name, field) return self.add_field(name, field)


def __getattr__(self, item):
if item in self.fields:
return self.fields[item]
else:
raise AttributeError('{} does not exist.'.format(item))

def __setattr__(self, key, value):
self.__setitem__(key, value)

def __repr__(self): def __repr__(self):
return self.fields.__repr__() return self.fields.__repr__()

Loading…
Cancel
Save