Browse Source

add data iter

tags/v0.2.0
yunfan 5 years ago
parent
commit
dd0bb0d791
2 changed files with 62 additions and 9 deletions
  1. +49
    -8
      fastNLP/core/dataset.py
  2. +13
    -1
      fastNLP/core/fieldarray.py

+ 49
- 8
fastNLP/core/dataset.py View File

@@ -1,5 +1,8 @@
import random
import sys
import sys, os
sys.path.append('../..')
sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path

from collections import defaultdict
from copy import deepcopy
import numpy as np
@@ -15,12 +18,35 @@ class DataSet(object):
"""A DataSet object is a list of Instance objects.

"""
class DataSetIter(object):
    """Cursor over the rows of a dataset.

    Iterating yields this same object once per row; at each step the
    cursor gives dict-style read/write access to the fields of the current
    row via ``cursor[name]`` and ``cursor[name] = value``.

    The wrapped ``dataset`` must support ``len(dataset)`` and
    ``dataset[name][idx]``; ``__repr__`` additionally uses
    ``dataset.get_fields()`` to enumerate the field names.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Start one position before the first row so that the first call
        # to __next__ lands on row 0.
        self.idx = -1

    def __iter__(self):
        # An iterator must return itself from __iter__ so that it can be
        # used directly in a ``for`` loop or passed to ``iter()``.
        return self

    def __next__(self):
        """Advance to the next row and return the cursor itself.

        Raises:
            StopIteration: when the cursor has moved past the last row.
        """
        self.idx += 1
        if self.idx >= len(self.dataset):
            raise StopIteration
        return self

    def __getitem__(self, name):
        """Return the value of field ``name`` in the current row."""
        return self.dataset[name][self.idx]

    def __setitem__(self, name, val):
        """Overwrite the value of field ``name`` in the current row."""
        # TODO check new field.
        self.dataset[name][self.idx] = val

    def __repr__(self):
        """Return a string showing the cursor position and current row.

        The row contents are only included while the cursor points at a
        valid row (not before the first ``__next__`` call and not after
        exhaustion).
        """
        cls_name = type(self).__name__
        if self.idx < 0 or self.idx >= len(self.dataset):
            return '{}(idx={})'.format(cls_name, self.idx)
        row = {name: self.dataset[name][self.idx]
               for name in self.dataset.get_fields()}
        return '{}(idx={}, row={!r})'.format(cls_name, self.idx, row)

def __init__(self, instance=None):
    """Create a dataset, optionally filled from ``instance``.

    Args:
        instance: when not None, it is handed to ``self._convert_ins`` to
            populate the dataset; when None the dataset starts empty.
    """
    # Always start from an empty mapping of field name -> field array so
    # that _convert_ins can rely on the attribute existing.
    self.field_arrays = {}
    if instance is not None:
        self._convert_ins(instance)

def __iter__(self):
    """Return a new ``DataSetIter`` cursor bound to this dataset."""
    cursor = self.DataSetIter(self)
    return cursor

def _convert_ins(self, ins_list):
if isinstance(ins_list, list):
@@ -32,23 +58,27 @@ class DataSet(object):
def append(self, ins):
    """Append one instance as a new row of the dataset.

    Args:
        ins: object whose ``fields`` attribute maps field names to values.

    Raises:
        AssertionError: if the dataset already has fields and the
            instance's field names do not match them exactly.
    """
    fields = ins.fields
    if len(self.field_arrays) == 0:
        # no field yet: the first instance defines the set of fields
        for name, field in fields.items():
            self.field_arrays[name] = FieldArray(name, [field])
        return

    # Validate everything before mutating anything, so a mismatching
    # instance cannot leave the field arrays with unequal lengths.
    assert len(self.field_arrays) == len(fields)
    for name in fields:
        assert name in self.field_arrays
    for name, field in fields.items():
        self.field_arrays[name].append(field)

def add_field(self, name, fields):
    """Add (or replace) a whole field of the dataset.

    Args:
        name: key under which the field is stored in ``self.field_arrays``.
        fields: the field's values; must contain exactly ``len(self)``
            entries.

    Raises:
        AssertionError: if ``len(fields)`` differs from ``len(self)``.
    """
    assert len(self) == len(fields)
    # Store the values wrapped in a FieldArray, like the fields created by
    # append(), rather than as the raw sequence.
    self.field_arrays[name] = FieldArray(name, fields)

def get_fields(self):
    """Return the mapping of field name to field array.

    The dataset's own mapping object is returned, not a copy.
    """
    arrays = self.field_arrays
    return arrays

def __getitem__(self, name):
    """Return the field array stored under ``name``.

    Raises:
        AssertionError: if no field called ``name`` exists.
    """
    arrays = self.field_arrays
    assert name in arrays
    return arrays[name]

def __len__(self):
    """Return the number of rows in the dataset.

    The length is taken from the first stored field array. A dataset
    with no fields has length 0.
    """
    if not self.field_arrays:
        # No fields at all: report an empty dataset instead of letting a
        # StopIteration escape from the lookup below.
        return 0
    # dict.values() is a view and cannot be indexed, so take the first
    # value through an iterator.
    first = next(iter(self.field_arrays.values()))
    return len(first)

def get_length(self):
@@ -125,3 +155,14 @@ class DataSet(object):
_READERS[method_name] = read_cls
return read_cls
return wrapper


# Manual smoke test, executed only when this module is run as a script:
# builds a DataSet from a single Instance, reads and overwrites its 'test'
# field through the dataset iterator, and prints the resulting field arrays.
if __name__ == '__main__':
from fastNLP.core.instance import Instance
ins = Instance(test='test0')
dataset = DataSet([ins])
for _iter in dataset:
# value before the in-place update
print(_iter['test'])
_iter['test'] = 'abc'
# value after the in-place update
print(_iter['test'])
print(dataset.field_arrays)

+ 13
- 1
fastNLP/core/fieldarray.py View File

@@ -2,19 +2,31 @@ import torch
import numpy as np

class FieldArray(object):
def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
    """Store one named field (column) of a dataset.

    Args:
        name: name of the field.
        content: container holding the field's values; it is kept by
            reference, not copied.
        padding_val: fill value used when the content is padded into an
            array by ``get``.
        is_target: flag stored on the instance as ``is_target``.
        need_tensor: flag stored on the instance as ``need_tensor``;
            ``get`` asserts it is True for non-integer indexes.
    """
    self.name = name
    self.content = content
    self.padding_val = padding_val
    self.is_target = is_target
    self.need_tensor = need_tensor

def __repr__(self):
    """Return ``'<name>: <repr of content>'``."""
    # TODO: settle on a final representation.
    content_text = self.content.__repr__()
    return '{}: {}'.format(self.name, content_text)

def append(self, val):
    """Add ``val`` to the end of the stored content."""
    content = self.content
    content.append(val)

def __getitem__(self, name):
    """Look up ``name`` by delegating to ``self.get``."""
    value = self.get(name)
    return value

def __setitem__(self, name, val):
    """Replace the element at integer position ``name`` with ``val``.

    Raises:
        AssertionError: if ``name`` is not an int.
    """
    assert isinstance(name, int)
    content = self.content
    content[name] = val

def get(self, idxes):
if isinstance(idxes, int):
return self.content[idxes]
assert self.need_tensor is True
batch_size = len(idxes)
max_len = max([len(self.content[i]) for i in idxes])
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)


Loading…
Cancel
Save