import os
import random
import sys
from collections import defaultdict
from copy import deepcopy

import numpy as np

# Prepend the package root (two levels up from this file) so the local
# package wins over any installed copy when this module runs as a script.
sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path
@@ -15,36 +18,67 @@ class DataSet(object): | |||
"""A DataSet object is a list of Instance objects. | |||
""" | |||
class DataSetIter(object): | |||
def __init__(self, dataset): | |||
self.dataset = dataset | |||
self.idx = -1 | |||
def __next__(self): | |||
self.idx += 1 | |||
if self.idx >= len(self.dataset): | |||
raise StopIteration | |||
return self | |||
def __getitem__(self, name): | |||
return self.dataset[name][self.idx] | |||
def __setitem__(self, name, val): | |||
# TODO check new field. | |||
self.dataset[name][self.idx] = val | |||
def __repr__(self): | |||
# TODO | |||
pass | |||
def __init__(self, instance=None): | |||
self.field_arrays = {} | |||
if instance is not None: | |||
self._convert_ins(instance) | |||
else: | |||
self.field_arrays = {} | |||
def __iter__(self): | |||
return self.DataSetIter(self) | |||
def _convert_ins(self, ins_list): | |||
if isinstance(ins_list, list): | |||
for ins in ins_list: | |||
self.append(ins) | |||
else: | |||
self.append(ins) | |||
self.append(ins_list) | |||
def append(self, ins): | |||
# no field | |||
if len(self.field_arrays) == 0: | |||
for name, field in ins.field.items(): | |||
for name, field in ins.fields.items(): | |||
self.field_arrays[name] = FieldArray(name, [field]) | |||
else: | |||
assert len(self.field_arrays) == len(ins.field) | |||
for name, field in ins.field.items(): | |||
assert len(self.field_arrays) == len(ins.fields) | |||
for name, field in ins.fields.items(): | |||
assert name in self.field_arrays | |||
self.field_arrays[name].append(field) | |||
def add_field(self, name, fields): | |||
assert len(self) == len(fields) | |||
self.field_arrays[name] = FieldArray(name, fields) | |||
def get_fields(self): | |||
return self.field_arrays | |||
def __getitem__(self, name): | |||
assert name in self.field_arrays | |||
return self.field_arrays[name] | |||
def __len__(self): | |||
field = self.field_arrays.values()[0] | |||
field = iter(self.field_arrays.values()).__next__() | |||
return len(field) | |||
def get_length(self): | |||
@@ -121,3 +155,14 @@ class DataSet(object): | |||
_READERS[method_name] = read_cls | |||
return read_cls | |||
return wrapper | |||
if __name__ == '__main__':
    from fastNLP.core.instance import Instance

    # Manual smoke test: iterate and mutate a single-instance DataSet.
    dataset = DataSet([Instance(test='test0')])
    for row in dataset:
        print(row['test'])
        row['test'] = 'abc'
        print(row['test'])
    print(dataset.field_arrays)
@@ -39,35 +39,6 @@ class TextField(Field): | |||
super(TextField, self).__init__(text, is_target) | |||
class IndexField(Field):
    """A Field holding sentences already mapped to index tensors.

    Each sentence in `content` is converted through `vocab.index_sent`
    and stored as a 1-D torch tensor; `to_tensor` pads a batch of them.
    """

    def __init__(self, name, content, vocab, is_target):
        super(IndexField, self).__init__(name, is_target)
        self.content = []
        self.padding_idx = vocab.padding_idx
        for sent in content:
            idx = vocab.index_sent(sent)
            if isinstance(idx, list):
                idx = torch.Tensor(idx)
            elif isinstance(idx, np.ndarray):
                # Fixed: was `isinstance(idx, np.array)` -- np.array is a
                # function, not a type, so isinstance() raised TypeError.
                idx = torch.from_numpy(idx)
            elif not isinstance(idx, torch.Tensor):
                raise ValueError
            self.content.append(idx)

    def to_tensor(self, id_list, sort_within_batch=False):
        """Pad the selected sentences into a (batch, max_len) long tensor.

        :param id_list: indices into self.content to batch together.
        :param sort_within_batch: if True, order rows by decreasing length
            (e.g. as required by pack_padded_sequence).
        """
        batch_size = len(id_list)
        len_list = [(i, self.content[i].size(0)) for i in id_list]
        # Fixed: max_len must be the longest *sequence length*; the
        # original used `max(id_list)`, i.e. the largest id value.
        max_len = max(length for _, length in len_list)
        if sort_within_batch:
            len_list = sorted(len_list, key=lambda x: x[1], reverse=True)
        tensor = torch.full((batch_size, max_len), self.padding_idx, dtype=torch.long)
        for i, (idx, length) in enumerate(len_list):
            # Slice-assign covers both the full-length and padded cases.
            tensor[i][:length] = self.content[idx]
        return tensor
class LabelField(Field): | |||
"""The Field representing a single label. Can be a string or integer. | |||
@@ -2,38 +2,38 @@ import torch | |||
import numpy as np | |||
class FieldArray(object):
    """One column of a DataSet: a single field's value for every instance.

    :param name: field name.
    :param content: list with one value per instance (kept by reference).
    :param padding_val: value used to right-pad variable-length entries
        when batching.
    :param is_target: whether this field is a training target.
    :param need_tensor: whether get() on a list of indices may build a
        padded numpy batch.
    """

    def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
        self.name = name
        self.content = content
        self.padding_val = padding_val
        self.is_target = is_target
        self.need_tensor = need_tensor

    def _convert_np(self, val):
        """Coerce `val` to a numpy array.

        Fixed: was `isinstance(val, np.array)` -- np.array is a function,
        not a type, so isinstance() raised TypeError on every call.
        """
        if not isinstance(val, np.ndarray):
            return np.array(val)
        return val

    def __repr__(self):
        return '{}: {}'.format(self.name, self.content.__repr__())

    def append(self, val):
        """Add one more instance's value to the column."""
        self.content.append(val)

    def __getitem__(self, name):
        return self.get(name)

    def __setitem__(self, name, val):
        assert isinstance(name, int)
        self.content[name] = val

    def get(self, idxes):
        """Fetch one value (int index) or a padded batch (list of indices).

        For a list, the selected entries are right-padded with
        `padding_val` to the longest selected length and returned as an
        int32 numpy array of shape (len(idxes), max_len).
        """
        if isinstance(idxes, int):
            return self.content[idxes]
        assert self.need_tensor is True
        batch_size = len(idxes)
        # Fixed: max() over an empty selection raised ValueError; an empty
        # idxes list now yields an empty (0, 0) batch.
        max_len = max((len(self.content[i]) for i in idxes), default=0)
        array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
        for i, idx in enumerate(idxes):
            array[i][:len(self.content[idx])] = self.content[idx]
        return array

    def __len__(self):
        return len(self.content)