Browse Source

add apply to dataset

tags/v0.2.0
yh yunfan 6 years ago
parent
commit
4149eb9c06
1 changed files with 38 additions and 14 deletions
  1. +38
    -14
      fastNLP/core/dataset.py

+ 38
- 14
fastNLP/core/dataset.py View File

@@ -22,7 +22,7 @@ class DataSet(object):


""" """


class DataSetIter(object):
class Instance(object):
def __init__(self, dataset, idx=-1): def __init__(self, dataset, idx=-1):
self.dataset = dataset self.dataset = dataset
self.idx = idx self.idx = idx
@@ -43,18 +43,32 @@ class DataSet(object):
self.dataset[name][self.idx] = val self.dataset[name][self.idx] = val


def __repr__(self): def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name in self.dataset.get_fields().keys()])
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
in self.dataset.get_fields().keys()])


def __init__(self, instance=None):
def __init__(self, data=None):
self.field_arrays = {} self.field_arrays = {}
if instance is not None:
self._convert_ins(instance)
if data is not None:
if isinstance(data, dict):
length_set = set()
for key, value in data.items():
length_set.add(len(value))
assert len(length_set)==1, "Arrays must all be same length."
for key, value in data.items():
self.add_field(name=key, fields=value)
elif isinstance(data, list):
for ins in data:
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
self.append(ins)

else:
raise ValueError("data only be dict or list type.")


def __contains__(self, item): def __contains__(self, item):
return item in self.field_arrays return item in self.field_arrays


def __iter__(self): def __iter__(self):
return self.DataSetIter(self)
return self.Instance(self)


def _convert_ins(self, ins_list): def _convert_ins(self, ins_list):
if isinstance(ins_list, list): if isinstance(ins_list, list):
@@ -89,7 +103,7 @@ class DataSet(object):


def __getitem__(self, name): def __getitem__(self, name):
if isinstance(name, int): if isinstance(name, int):
return self.DataSetIter(self, idx=name)
return self.Instance(self, idx=name)
elif isinstance(name, str): elif isinstance(name, str):
return self.field_arrays[name] return self.field_arrays[name]
else: else:
@@ -150,6 +164,12 @@ class DataSet(object):
else: else:
return object.__getattribute__(self, name) return object.__getattribute__(self, name)


def __getattr__(self, item):
if item in self.field_arrays:
return self.field_arrays[item]
else:
self.__getattribute__(item)

@classmethod @classmethod
def set_reader(cls, method_name): def set_reader(cls, method_name):
"""decorator to add dataloader support """decorator to add dataloader support
@@ -162,14 +182,18 @@ class DataSet(object):


return wrapper return wrapper


def apply(self, func, new_field_name=None):
results = []
for ins in self:
results.append(func(ins))
if new_field_name is not None:
self.add_field(new_field_name, results)
return results


if __name__ == '__main__': if __name__ == '__main__':
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance


ins = Instance(test='test0')
dataset = DataSet([ins])
for _iter in dataset:
print(_iter['test'])
_iter['test'] = 'abc'
print(_iter['test'])
print(dataset.field_arrays)
d = DataSet({'a': list('abc')})
d.a
d.apply(lambda x: x['a'])
print(d[1])

Loading…
Cancel
Save