Browse Source

Merge branch 'dataset' of https://github.com/yhcc/fastNLP into dataset

tags/v0.2.0
yh 5 years ago
parent
commit
217cab94d1
3 changed files with 73 additions and 57 deletions
  1. +53
    -8
      fastNLP/core/dataset.py
  2. +0
    -29
      fastNLP/core/field.py
  3. +20
    -20
      fastNLP/core/fieldarray.py

+ 53
- 8
fastNLP/core/dataset.py View File

@@ -1,5 +1,8 @@
import random
import sys
import sys, os
sys.path.append('../..')
sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path

from collections import defaultdict
from copy import deepcopy
import numpy as np
@@ -15,36 +18,67 @@ class DataSet(object):
"""A DataSet object is a list of Instance objects.

"""
class DataSetIter(object):
    """Cursor-style iterator over a DataSet.

    ``__next__`` yields the iterator itself on every step; field values of
    the current row are then read and written through item syntax, e.g.
    ``row['text']`` / ``row['text'] = val``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.idx = -1  # advanced by __next__ before the first row is exposed

    def __next__(self):
        self.idx += 1
        if self.idx >= len(self.dataset):
            raise StopIteration
        return self

    def __getitem__(self, name):
        """Return the value of field `name` at the current row."""
        return self.dataset[name][self.idx]

    def __setitem__(self, name, val):
        """Overwrite the value of field `name` at the current row."""
        # TODO check that `name` is an existing field before writing
        self.dataset[name][self.idx] = val

    def __repr__(self):
        # bug fix: the original body was `pass`, so repr() returned None
        # and raised "TypeError: __repr__ returned non-string"
        return "DataSetIter(idx={})".format(self.idx)

def __init__(self, instance=None):
    """Create a DataSet, optionally pre-filled with instances.

    :param instance: a single Instance, a list of Instances, or None.
    """
    # The diff-garbled original assigned `self.field_arrays = {}` in both
    # branches; initialize it once unconditionally — _convert_ins fills it.
    self.field_arrays = {}
    if instance is not None:
        self._convert_ins(instance)

def __iter__(self):
    """Return a DataSetIter cursor that walks this dataset row by row."""
    return self.DataSetIter(self)

def _convert_ins(self, ins_list):
    """Append one Instance, or every Instance in a list, to this dataset."""
    if isinstance(ins_list, list):
        for ins in ins_list:
            self.append(ins)
    else:
        # bug fix: the pre-commit else-branch called self.append(ins),
        # but `ins` is undefined in that branch — it must be ins_list
        self.append(ins_list)

def append(self, ins):
    """Add one Instance as a new row.

    :param ins: an Instance whose ``fields`` dict maps field name -> value.
        After the first row, every appended instance must carry exactly
        the same set of field names.
    """
    # The diff garbled `ins.field` (old) vs `ins.fields` (new); the
    # Instance attribute is `fields`, used consistently below.
    if len(self.field_arrays) == 0:
        # first row: create one FieldArray column per field
        for name, field in ins.fields.items():
            self.field_arrays[name] = FieldArray(name, [field])
    else:
        # later rows must match the existing schema exactly
        assert len(self.field_arrays) == len(ins.fields)
        for name, field in ins.fields.items():
            assert name in self.field_arrays
            self.field_arrays[name].append(field)

def add_field(self, name, fields):
    """Register `fields` as a new column called `name`.

    The number of supplied values has to equal the current dataset length.
    """
    assert len(fields) == len(self)
    self.field_arrays[name] = FieldArray(name, fields)

def get_fields(self):
    """Return the underlying dict mapping field name -> FieldArray."""
    return self.field_arrays

def __getitem__(self, name):
    """Return the FieldArray stored under `name`.

    `name` must be an existing field; unknown names trip the assert.
    """
    assert name in self.field_arrays
    return self.field_arrays[name]

def __len__(self):
    """Number of rows, taken from any one FieldArray (all share the length).

    Returns 0 for a dataset with no fields.
    """
    # The diff garbled `values()[0]` (invalid on a py3 dict view) vs
    # `iter(values()).__next__()`; use the next() builtin. Also guard the
    # empty case, which previously raised a bare StopIteration.
    if not self.field_arrays:
        return 0
    field = next(iter(self.field_arrays.values()))
    return len(field)

def get_length(self):
@@ -121,3 +155,14 @@ class DataSet(object):
_READERS[method_name] = read_cls
return read_cls
return wrapper


if __name__ == '__main__':
    # Quick manual smoke test (requires the fastNLP package on sys.path).
    from fastNLP.core.instance import Instance
    ins = Instance(test='test0')
    dataset = DataSet([ins])
    for _iter in dataset:
        # read, mutate in place, and re-read the 'test' field of the current row
        print(_iter['test'])
        _iter['test'] = 'abc'
        print(_iter['test'])
    print(dataset.field_arrays)

+ 0
- 29
fastNLP/core/field.py View File

@@ -39,35 +39,6 @@ class TextField(Field):
super(TextField, self).__init__(text, is_target)


class IndexField(Field):
    """A Field storing sentences converted to index tensors via a vocabulary.

    :param name: field name.
    :param content: iterable of sentences to index with ``vocab.index_sent``.
    :param vocab: vocabulary exposing ``index_sent`` and ``padding_idx``.
    :param is_target: whether this field is a training target.
    """

    def __init__(self, name, content, vocab, is_target):
        super(IndexField, self).__init__(name, is_target)
        self.content = []
        self.padding_idx = vocab.padding_idx
        for sent in content:
            idx = vocab.index_sent(sent)
            if isinstance(idx, list):
                idx = torch.Tensor(idx)
            elif isinstance(idx, np.ndarray):
                # bug fix: was isinstance(idx, np.array) — np.array is a
                # function, not a type, so isinstance raised TypeError
                idx = torch.from_numpy(idx)
            elif not isinstance(idx, torch.Tensor):
                raise ValueError
            self.content.append(idx)

    def to_tensor(self, id_list, sort_within_batch=False):
        """Pad the sequences selected by `id_list` into a (batch, max_len) LongTensor.

        :param id_list: indices of the sequences to batch.
        :param sort_within_batch: if True, order rows by decreasing length.
        """
        batch_size = len(id_list)
        len_list = [(i, self.content[i].size(0)) for i in id_list]
        # bug fix: max_len was max(id_list) — the largest *index*, not the
        # longest sequence length — so batches of short sequences with big
        # ids over-padded, and long sequences with small ids overflowed
        max_len = max(length for _, length in len_list)
        if sort_within_batch:
            len_list = sorted(len_list, key=lambda x: x[1], reverse=True)
        tensor = torch.full((batch_size, max_len), self.padding_idx, dtype=torch.long)
        for i, (idx, length) in enumerate(len_list):
            # the slice assignment covers both full-length and padded rows
            tensor[i][:length] = self.content[idx]
        return tensor

class LabelField(Field):
"""The Field representing a single label. Can be a string or integer.



+ 20
- 20
fastNLP/core/fieldarray.py View File

@@ -2,38 +2,38 @@ import torch
import numpy as np

class FieldArray(object):
    """One named column of a DataSet: a list holding one value per instance.

    :param name: field name.
    :param content: list of values, one per instance.
    :param padding_val: value used to pad short sequences when batching.
    :param is_target: whether this field is a training target.
    :param need_tensor: whether batched access may build a padded array.
    """
    # The source span interleaved the pre- and post-commit bodies (two
    # __init__ signatures, a dead _convert_np helper, two get/__len__
    # bodies); this is the coherent post-commit version, which stores the
    # raw `content` list instead of eagerly converting to numpy.

    def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
        self.name = name
        self.content = content
        self.padding_val = padding_val
        self.is_target = is_target
        self.need_tensor = need_tensor

    def __repr__(self):
        return '{}: {}'.format(self.name, self.content.__repr__())

    def append(self, val):
        """Add one value at the end of the column."""
        self.content.append(val)

    def __getitem__(self, name):
        return self.get(name)

    def __setitem__(self, name, val):
        # only positional (per-instance) assignment is supported
        assert isinstance(name, int)
        self.content[name] = val

    def get(self, idxes):
        """Return content[idxes] for an int, or a padded 2-D int batch for a list.

        :param idxes: a single int index, or a list of indices (requires
            ``need_tensor=True``; each selected value must be a sequence).
        """
        if isinstance(idxes, int):
            return self.content[idxes]
        assert self.need_tensor is True
        batch_size = len(idxes)
        max_len = max([len(self.content[i]) for i in idxes])
        array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
        for i, idx in enumerate(idxes):
            # shorter rows keep padding_val in their tail
            array[i][:len(self.content[idx])] = self.content[idx]
        return array

    def __len__(self):
        return len(self.content)

Loading…
Cancel
Save