|
@@ -22,7 +22,7 @@ class DataSet(object): |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
class DataSetIter(object): |
|
|
|
|
|
|
|
|
class Instance(object): |
|
|
def __init__(self, dataset, idx=-1): |
|
|
def __init__(self, dataset, idx=-1): |
|
|
self.dataset = dataset |
|
|
self.dataset = dataset |
|
|
self.idx = idx |
|
|
self.idx = idx |
|
@@ -43,18 +43,32 @@ class DataSet(object): |
|
|
self.dataset[name][self.idx] = val |
|
|
self.dataset[name][self.idx] = val |
|
|
|
|
|
|
|
|
def __repr__(self): |
|
|
def __repr__(self): |
|
|
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name in self.dataset.get_fields().keys()]) |
|
|
|
|
|
|
|
|
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name |
|
|
|
|
|
in self.dataset.get_fields().keys()]) |
|
|
|
|
|
|
|
|
def __init__(self, instance=None): |
|
|
|
|
|
|
|
|
def __init__(self, data=None): |
|
|
self.field_arrays = {} |
|
|
self.field_arrays = {} |
|
|
if instance is not None: |
|
|
|
|
|
self._convert_ins(instance) |
|
|
|
|
|
|
|
|
if data is not None: |
|
|
|
|
|
if isinstance(data, dict): |
|
|
|
|
|
length_set = set() |
|
|
|
|
|
for key, value in data.items(): |
|
|
|
|
|
length_set.add(len(value)) |
|
|
|
|
|
assert len(length_set)==1, "Arrays must all be same length." |
|
|
|
|
|
for key, value in data.items(): |
|
|
|
|
|
self.add_field(name=key, fields=value) |
|
|
|
|
|
elif isinstance(data, list): |
|
|
|
|
|
for ins in data: |
|
|
|
|
|
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) |
|
|
|
|
|
self.append(ins) |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
raise ValueError("data only be dict or list type.") |
|
|
|
|
|
|
|
|
def __contains__(self, item): |
|
|
def __contains__(self, item): |
|
|
return item in self.field_arrays |
|
|
return item in self.field_arrays |
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
def __iter__(self): |
|
|
return self.DataSetIter(self) |
|
|
|
|
|
|
|
|
return self.Instance(self) |
|
|
|
|
|
|
|
|
def _convert_ins(self, ins_list): |
|
|
def _convert_ins(self, ins_list): |
|
|
if isinstance(ins_list, list): |
|
|
if isinstance(ins_list, list): |
|
@@ -89,7 +103,7 @@ class DataSet(object): |
|
|
|
|
|
|
|
|
def __getitem__(self, name): |
|
|
def __getitem__(self, name): |
|
|
if isinstance(name, int): |
|
|
if isinstance(name, int): |
|
|
return self.DataSetIter(self, idx=name) |
|
|
|
|
|
|
|
|
return self.Instance(self, idx=name) |
|
|
elif isinstance(name, str): |
|
|
elif isinstance(name, str): |
|
|
return self.field_arrays[name] |
|
|
return self.field_arrays[name] |
|
|
else: |
|
|
else: |
|
@@ -150,6 +164,12 @@ class DataSet(object): |
|
|
else: |
|
|
else: |
|
|
return object.__getattribute__(self, name) |
|
|
return object.__getattribute__(self, name) |
|
|
|
|
|
|
|
|
|
|
|
def __getattr__(self, item):
    """Expose the entries of ``field_arrays`` as attributes.

    Python invokes this hook only after regular attribute lookup has
    failed, so ``dataset.some_field`` resolves to
    ``dataset.field_arrays['some_field']`` when such a key exists.

    :param str item: attribute name that regular lookup could not resolve.
    :return: the value stored in ``field_arrays`` under ``item``; otherwise
        whatever ``self.__getattribute__(item)`` returns.
    :raises AttributeError: propagated from ``__getattribute__`` when
        ``item`` is neither a field name nor a resolvable attribute.
    """
    # Fetch field_arrays without going through ``self.field_arrays``: if the
    # attribute is not set yet (e.g. the object was created without running
    # __init__), that access would land in __getattr__ again and recurse
    # until RecursionError.
    try:
        field_arrays = object.__getattribute__(self, 'field_arrays')
    except AttributeError:
        field_arrays = {}

    if item in field_arrays:
        return field_arrays[item]
    # Return the result instead of discarding it; an explicit call does not
    # re-trigger __getattr__, so a failed lookup surfaces as AttributeError.
    return self.__getattribute__(item)
|
|
|
|
|
|
|
|
@classmethod |
|
|
@classmethod |
|
|
def set_reader(cls, method_name): |
|
|
def set_reader(cls, method_name): |
|
|
"""decorator to add dataloader support |
|
|
"""decorator to add dataloader support |
|
@@ -162,14 +182,18 @@ class DataSet(object): |
|
|
|
|
|
|
|
|
return wrapper |
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
def apply(self, func, new_field_name=None):
    """Call ``func`` on every item obtained by iterating over ``self``.

    :param func: callable taking one item yielded by iterating this object
        and returning a value.
    :param new_field_name: when not ``None``, the collected results are
        also passed to ``self.add_field(new_field_name, results)``.
    :return: list with one result per item, in iteration order.
    """
    results = [func(ins) for ins in self]
    if new_field_name is not None:
        self.add_field(new_field_name, results)
    return results
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
if __name__ == '__main__': |
|
|
from fastNLP.core.instance import Instance |
|
|
from fastNLP.core.instance import Instance |
|
|
|
|
|
|
|
|
ins = Instance(test='test0') |
|
|
|
|
|
dataset = DataSet([ins]) |
|
|
|
|
|
for _iter in dataset: |
|
|
|
|
|
print(_iter['test']) |
|
|
|
|
|
_iter['test'] = 'abc' |
|
|
|
|
|
print(_iter['test']) |
|
|
|
|
|
print(dataset.field_arrays) |
|
|
|
|
|
|
|
|
d = DataSet({'a': list('abc')}) |
|
|
|
|
|
d.a |
|
|
|
|
|
d.apply(lambda x: x['a']) |
|
|
|
|
|
print(d[1]) |