|
|
@@ -1,5 +1,8 @@ |
|
|
|
import random |
|
|
|
import sys |
|
|
|
import sys, os |
|
|
|
sys.path.append('../..') |
|
|
|
sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path |
|
|
|
|
|
|
|
from collections import defaultdict |
|
|
|
from copy import deepcopy |
|
|
|
import numpy as np |
|
|
@@ -15,12 +18,35 @@ class DataSet(object): |
|
|
|
"""A DataSet object is a list of Instance objects. |
|
|
|
|
|
|
|
""" |
|
|
|
class DataSetIter(object):
    """Cursor over the rows of a dataset.

    Every ``__next__`` call advances an internal row index and returns the
    cursor itself. ``cursor[name]`` and ``cursor[name] = val`` then read or
    write ``dataset[name][idx]``, i.e. the value of field ``name`` at the
    current row.
    """

    def __init__(self, dataset):
        """Bind the cursor to ``dataset`` and park it before the first row."""
        self.dataset = dataset
        # Start at -1 so the first __next__ call lands on row 0.
        self.idx = -1

    def __iter__(self):
        """Return the cursor itself, completing the iterator protocol."""
        return self

    def __next__(self):
        """Advance to the next row and return this cursor.

        Raises:
            StopIteration: once the index reaches ``len(self.dataset)``.
        """
        self.idx += 1
        if self.idx >= len(self.dataset):
            raise StopIteration
        return self

    def __getitem__(self, name):
        """Return the value of field ``name`` at the current row."""
        return self.dataset[name][self.idx]

    def __setitem__(self, name, val):
        """Overwrite the value of field ``name`` at the current row."""
        # TODO check new field.
        self.dataset[name][self.idx] = val

    def __repr__(self):
        """Return a short description; __repr__ must always yield a str."""
        return "{}(idx={})".format(type(self).__name__, self.idx)
|
|
|
|
|
|
|
def __init__(self, instance=None): |
|
|
|
self.field_arrays = {} |
|
|
|
if instance is not None: |
|
|
|
self._convert_ins(instance) |
|
|
|
else: |
|
|
|
self.field_arrays = {} |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
return self.DataSetIter(self) |
|
|
|
|
|
|
|
def _convert_ins(self, ins_list): |
|
|
|
if isinstance(ins_list, list): |
|
|
@@ -32,23 +58,27 @@ class DataSet(object): |
|
|
|
def append(self, ins): |
|
|
|
# no field |
|
|
|
if len(self.field_arrays) == 0: |
|
|
|
for name, field in ins.field.items(): |
|
|
|
for name, field in ins.fields.items(): |
|
|
|
self.field_arrays[name] = FieldArray(name, [field]) |
|
|
|
else: |
|
|
|
assert len(self.field_arrays) == len(ins.field) |
|
|
|
for name, field in ins.field.items(): |
|
|
|
assert len(self.field_arrays) == len(ins.fields) |
|
|
|
for name, field in ins.fields.items(): |
|
|
|
assert name in self.field_arrays |
|
|
|
self.field_arrays[name].append(field) |
|
|
|
|
|
|
|
def add_field(self, name, fields):
    """Store ``fields`` under ``name``, replacing any existing entry.

    :param name: key under which the new field array is stored.
    :param fields: sized collection of values; its length must equal
        ``len(self)``.
    :raises AssertionError: if ``len(fields)`` differs from ``len(self)``.
    """
    assert len(self) == len(fields)
    # Wrap the values in a FieldArray, the same type ``append`` stores;
    # the earlier raw assignment was immediately overwritten and is dropped.
    self.field_arrays[name] = FieldArray(name, fields)
|
|
|
|
|
|
|
def get_fields(self): |
|
|
|
return self.field_arrays |
|
|
|
|
|
|
|
def __getitem__(self, name): |
|
|
|
assert name in self.field_arrays |
|
|
|
return self.field_arrays[name] |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
field = self.field_arrays.values()[0] |
|
|
|
field = iter(self.field_arrays.values()).__next__() |
|
|
|
return len(field) |
|
|
|
|
|
|
|
def get_length(self): |
|
|
@@ -125,3 +155,14 @@ class DataSet(object): |
|
|
|
_READERS[method_name] = read_cls |
|
|
|
return read_cls |
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: build a one-instance dataset, read and overwrite the
    # 'test' field through the row cursor, then dump the stored field arrays.
    from fastNLP.core.instance import Instance

    dataset = DataSet([Instance(test='test0')])
    for row in dataset:
        print(row['test'])
        row['test'] = 'abc'
        print(row['test'])
    print(dataset.field_arrays)