@@ -1,5 +1,8 @@ | |||||
import random | import random | ||||
import sys | |||||
import sys, os | |||||
sys.path.append('../..') | |||||
sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path | |||||
from collections import defaultdict | from collections import defaultdict | ||||
from copy import deepcopy | from copy import deepcopy | ||||
import numpy as np | import numpy as np | ||||
class DataSet(object):
    """A DataSet object is a list of Instance objects.

    Field values are stored column-wise: one ``FieldArray`` per field name in
    ``self.field_arrays``, so every instance is a cross-section at one index.
    """

    class DataSetIter(object):
        """Cursor-style iterator over a DataSet.

        ``__next__`` advances an index and returns *itself*; field values of
        the current instance are then read/written through ``__getitem__`` /
        ``__setitem__``.
        """

        def __init__(self, dataset):
            self.dataset = dataset
            self.idx = -1  # advanced before first use, so first item is 0

        def __next__(self):
            self.idx += 1
            if self.idx >= len(self.dataset):
                raise StopIteration
            return self

        def __getitem__(self, name):
            # Read field `name` of the instance the cursor points at.
            return self.dataset[name][self.idx]

        def __setitem__(self, name, val):
            # TODO check new field.
            self.dataset[name][self.idx] = val

        def __repr__(self):
            # Fixed: was `pass`, so repr() raised TypeError (None returned).
            return "DataSetIter(idx={})".format(self.idx)

    def __init__(self, instance=None):
        """Build a DataSet, optionally from an Instance or list of Instances."""
        self.field_arrays = {}
        if instance is not None:
            self._convert_ins(instance)

    def __iter__(self):
        return self.DataSetIter(self)

    def _convert_ins(self, ins_list):
        """Append a single Instance or every Instance in a list."""
        if isinstance(ins_list, list):
            for ins in ins_list:
                self.append(ins)
        else:
            self.append(ins_list)

    def append(self, ins):
        """Append one Instance; its fields must match the existing columns."""
        if len(self.field_arrays) == 0:
            # no field yet: this instance defines the columns
            for name, field in ins.fields.items():
                self.field_arrays[name] = FieldArray(name, [field])
        else:
            assert len(self.field_arrays) == len(ins.fields)
            for name, field in ins.fields.items():
                assert name in self.field_arrays
                self.field_arrays[name].append(field)

    def add_field(self, name, fields):
        """Add a whole column at once; must have one value per instance."""
        assert len(self) == len(fields)
        self.field_arrays[name] = FieldArray(name, fields)

    def get_fields(self):
        return self.field_arrays

    def __getitem__(self, name):
        assert name in self.field_arrays
        return self.field_arrays[name]

    def __len__(self):
        # Fixed: `iter(...).__next__()` raised a bare StopIteration when the
        # dataset had no fields; an empty DataSet now reports length 0.
        if not self.field_arrays:
            return 0
        field = next(iter(self.field_arrays.values()))
        return len(field)
def get_length(self): | def get_length(self): | ||||
@@ -121,3 +155,14 @@ class DataSet(object): | |||||
_READERS[method_name] = read_cls | _READERS[method_name] = read_cls | ||||
return read_cls | return read_cls | ||||
return wrapper | return wrapper | ||||
if __name__ == '__main__':
    # Manual smoke test: build a one-instance DataSet, then exercise the
    # cursor iterator's read and write access to a field.
    from fastNLP.core.instance import Instance

    demo_instance = Instance(test='test0')
    demo_set = DataSet([demo_instance])
    for cursor in demo_set:
        print(cursor['test'])
        cursor['test'] = 'abc'
        print(cursor['test'])
    print(demo_set.field_arrays)
@@ -39,35 +39,6 @@ class TextField(Field): | |||||
super(TextField, self).__init__(text, is_target) | super(TextField, self).__init__(text, is_target) | ||||
class IndexField(Field):
    """A Field holding sentences converted to vocabulary-index tensors.

    Each sentence in ``content`` is indexed through ``vocab.index_sent`` and
    stored as a 1-D ``torch.Tensor``; ``to_tensor`` pads a batch of them into
    one ``(batch, max_len)`` long tensor using ``vocab.padding_idx``.
    """

    def __init__(self, name, content, vocab, is_target):
        super(IndexField, self).__init__(name, is_target)
        self.content = []
        self.padding_idx = vocab.padding_idx
        for sent in content:
            idx = vocab.index_sent(sent)
            if isinstance(idx, list):
                idx = torch.Tensor(idx)
            # Fixed: was `isinstance(idx, np.array)` — np.array is a factory
            # function, not a type, so isinstance raised TypeError.
            elif isinstance(idx, np.ndarray):
                idx = torch.from_numpy(idx)
            elif not isinstance(idx, torch.Tensor):
                raise ValueError
            self.content.append(idx)

    def to_tensor(self, id_list, sort_within_batch=False):
        """Pad the sentences selected by ``id_list`` into one long tensor.

        :param id_list: indices into ``self.content`` to batch together.
        :param sort_within_batch: if True, order rows by descending length
            (e.g. for packed RNN input).
        :return: ``(len(id_list), max_len)`` ``torch.long`` tensor.
        """
        batch_size = len(id_list)
        len_list = [(i, self.content[i].size(0)) for i in id_list]
        # Fixed: max_len was `max(id_list)` — the largest *index*, not the
        # longest sequence length — producing a wrongly sized tensor.
        max_len = max(length for _, length in len_list)
        tensor = torch.full((batch_size, max_len), self.padding_idx, dtype=torch.long)
        if sort_within_batch:
            len_list = sorted(len_list, key=lambda x: x[1], reverse=True)
        for i, (idx, length) in enumerate(len_list):
            # slice assignment handles both full-length and shorter rows
            tensor[i][:length] = self.content[idx]
        return tensor
class LabelField(Field): | class LabelField(Field): | ||||
"""The Field representing a single label. Can be a string or integer. | """The Field representing a single label. Can be a string or integer. | ||||
@@ -2,38 +2,38 @@ import torch | |||||
import numpy as np | import numpy as np | ||||
class FieldArray(object):
    """A named column of field values, one entry per dataset instance.

    :param name: field name.
    :param content: list of per-instance values (kept by reference).
    :param padding_val: value used to pad shorter entries when batching.
    :param is_target: whether this field is a training target.
    :param need_tensor: whether batch access (``get`` with a list) is allowed.
    """

    def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
        self.name = name
        self.content = content
        self.padding_val = padding_val
        self.is_target = is_target
        self.need_tensor = need_tensor

    def __repr__(self):
        return '{}: {}'.format(self.name, self.content.__repr__())

    def append(self, val):
        """Append one instance's value to the column."""
        self.content.append(val)

    def __getitem__(self, name):
        return self.get(name)

    def __setitem__(self, name, val):
        assert isinstance(name, int)
        self.content[name] = val

    def get(self, idxes):
        """Return one raw value (int index) or a padded batch (list of indices).

        Batch access pads every selected entry to the longest one with
        ``self.padding_val`` and returns an int32 numpy array of shape
        ``(len(idxes), max_len)``.
        """
        if isinstance(idxes, int):
            return self.content[idxes]
        assert self.need_tensor is True
        batch_size = len(idxes)
        # Fixed: max() over an empty batch raised ValueError; an empty
        # idxes list now yields a (0, 0) array.
        max_len = max((len(self.content[i]) for i in idxes), default=0)
        array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
        for i, idx in enumerate(idxes):
            array[i][:len(self.content[idx])] = self.content[idx]
        return array

    def __len__(self):
        return len(self.content)