@@ -27,7 +27,7 @@ class DataSet(object): | |||||
for ins in ins_list: | for ins in ins_list: | ||||
self.append(ins) | self.append(ins) | ||||
else: | else: | ||||
self.append(ins) | |||||
self.append(ins_list) | |||||
def append(self, ins): | def append(self, ins): | ||||
# no field | # no field | ||||
@@ -40,6 +40,10 @@ class DataSet(object): | |||||
assert name in self.field_arrays | assert name in self.field_arrays | ||||
self.field_arrays[name].append(field) | self.field_arrays[name].append(field) | ||||
def add_field(self, name, fields):
    """Register ``fields`` as the column stored under ``name``.

    :param name: key under which the column is stored in ``self.field_arrays``.
    :param fields: sized container holding one entry per instance; its
        length must equal ``len(self)``.
    :raises AssertionError: if ``len(fields)`` differs from ``len(self)``.
    """
    expected = len(self)
    # A column that is shorter or longer than the data set cannot be lined up
    # with the existing instances, so reject it before storing anything.
    assert expected == len(fields), \
        "field '{}' has {} entries, expected {}".format(name, len(fields), expected)
    self.field_arrays[name] = fields
def get_fields(self):
    """Return the ``field_arrays`` mapping itself (not a copy)."""
    return self.field_arrays
@@ -39,35 +39,6 @@ class TextField(Field): | |||||
super(TextField, self).__init__(text, is_target) | super(TextField, self).__init__(text, is_target) | ||||
class IndexField(Field):
    """Field holding one index tensor per sentence, batchable with padding.

    Each sentence in ``content`` is passed through ``vocab.index_sent`` and
    the result is kept as a ``torch.Tensor`` in ``self.content``.
    """

    def __init__(self, name, content, vocab, is_target):
        """Index every sentence of ``content`` with ``vocab``.

        :param name: forwarded to the base ``Field`` constructor.
        :param content: iterable of sentences accepted by ``vocab.index_sent``.
        :param vocab: object providing ``padding_idx`` and ``index_sent``.
        :param is_target: forwarded to the base ``Field`` constructor.
        :raises ValueError: if ``vocab.index_sent`` returns something that is
            not a list, a NumPy array or a torch tensor.
        """
        super(IndexField, self).__init__(name, is_target)
        self.content = []
        self.padding_idx = vocab.padding_idx
        for sent in content:
            idx = vocab.index_sent(sent)
            if isinstance(idx, list):
                # torch.Tensor(list) would build a float tensor; indices
                # must stay integral.
                idx = torch.tensor(idx, dtype=torch.long)
            elif isinstance(idx, np.ndarray):
                # np.array is a factory function, not a type; the array
                # class is np.ndarray.
                idx = torch.from_numpy(idx)
            elif not isinstance(idx, torch.Tensor):
                raise ValueError(
                    "unsupported index type: {}".format(type(idx).__name__))
            self.content.append(idx)

    def to_tensor(self, id_list, sort_within_batch=False):
        """Stack the selected sentences into one padded ``torch.long`` tensor.

        :param id_list: positions in ``self.content`` to put in the batch.
        :param sort_within_batch: if true, rows are ordered by decreasing
            sentence length instead of following ``id_list`` order.
        :return: tensor of shape ``(len(id_list), longest_sentence)`` filled
            with ``self.padding_idx`` beyond each sentence's end.
        """
        len_list = [(i, self.content[i].size(0)) for i in id_list]
        # The row width is the longest selected sentence, not the largest id.
        max_len = max(length for _, length in len_list) if len_list else 0
        batch_size = len(len_list)
        tensor = torch.full((batch_size, max_len), self.padding_idx, dtype=torch.long)
        if sort_within_batch:
            len_list = sorted(len_list, key=lambda x: x[1], reverse=True)
        for row, (idx, length) in enumerate(len_list):
            tensor[row, :length] = self.content[idx]
        return tensor
class LabelField(Field): | class LabelField(Field): | ||||
"""The Field representing a single label. Can be a string or integer. | """The Field representing a single label. Can be a string or integer. | ||||
@@ -4,36 +4,24 @@ import numpy as np | |||||
class FieldArray(object):
    """A named column of per-instance values that can be fetched in batches.

    The values are kept as given in ``self.content``; padding into a dense
    NumPy array happens only when a batch is requested through ``get``.
    """

    def __init__(self, name, content, padding_val=0, is_target=True, need_tensor=True):
        """Store the column and its batching options.

        :param name: name of the column.
        :param content: list of per-instance values; kept by reference and
            extended in place by ``append``.
        :param padding_val: fill value used by ``get`` for short rows.
        :param is_target: stored as-is on the instance.
        :param need_tensor: stored as-is on the instance.
        """
        self.name = name
        self.content = content
        self.padding_val = padding_val
        self.is_target = is_target
        self.need_tensor = need_tensor

    def append(self, val):
        """Add the value of one more instance to the end of the column."""
        self.content.append(val)

    def get(self, idxes):
        """Fetch one value, or a padded batch of values.

        :param idxes: a single integer position, or a sequence of positions.
        :return: for an integer, the stored value unchanged; for a sequence,
            an ``np.int32`` array of shape ``(len(idxes), longest_row)`` where
            rows shorter than the longest one are filled with
            ``self.padding_val``. An empty sequence yields a ``(0, 0)`` array.
        """
        if isinstance(idxes, (int, np.integer)):
            return self.content[idxes]
        lengths = [len(self.content[i]) for i in idxes]
        # max() of an empty list raises, so an empty batch gets width 0.
        max_len = max(lengths) if lengths else 0
        array = np.full((len(lengths), max_len), self.padding_val, dtype=np.int32)
        for row, (idx, length) in enumerate(zip(idxes, lengths)):
            array[row, :length] = self.content[idx]
        return array

    def __len__(self):
        """Return the number of instances stored in the column."""
        return len(self.content)