diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index e626ff26..c6f0de35 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -27,7 +27,7 @@ class DataSet(object): for ins in ins_list: self.append(ins) else: - self.append(ins) + self.append(ins_list) def append(self, ins): # no field @@ -40,6 +40,10 @@ class DataSet(object): assert name in self.field_arrays self.field_arrays[name].append(field) + def add_field(self, name, fields): + assert len(self) == len(fields) + self.field_arrays[name] = fields + def get_fields(self): return self.field_arrays diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 5b9c1b63..cf34abf8 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -39,35 +39,6 @@ class TextField(Field): super(TextField, self).__init__(text, is_target) -class IndexField(Field): - def __init__(self, name, content, vocab, is_target): - super(IndexField, self).__init__(name, is_target) - self.content = [] - self.padding_idx = vocab.padding_idx - for sent in content: - idx = vocab.index_sent(sent) - if isinstance(idx, list): - idx = torch.Tensor(idx) - elif isinstance(idx, np.array): - idx = torch.from_numpy(idx) - elif not isinstance(idx, torch.Tensor): - raise ValueError - self.content.append(idx) - - def to_tensor(self, id_list, sort_within_batch=False): - max_len = max(id_list) - batch_size = len(id_list) - tensor = torch.full((batch_size, max_len), self.padding_idx, dtype=torch.long) - len_list = [(i, self.content[i].size(0)) for i in id_list] - if sort_within_batch: - len_list = sorted(len_list, key=lambda x: x[1], reverse=True) - for i, (idx, length) in enumerate(len_list): - if length == max_len: - tensor[i] = self.content[idx] - else: - tensor[i][:length] = self.content[idx] - return tensor - class LabelField(Field): """The Field representing a single label. Can be a string or integer. diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 9710f991..9d0f8e9e 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -4,36 +4,24 @@ import numpy as np class FieldArray(object): def __init__(self, name, content, padding_val=0, is_target=True, need_tensor=True): self.name = name - self.data = [self._convert_np(val) for val in content] + self.content = content self.padding_val = padding_val self.is_target = is_target self.need_tensor = need_tensor - def _convert_np(self, val): - if not isinstance(val, np.array): - return np.array(val) - else: - return val - def append(self, val): - self.data.append(self._convert_np(val)) + self.content.append(val) def get(self, idxes): if isinstance(idxes, int): - return self.data[idxes] - elif isinstance(idxes, list): - id_list = np.array(idxes) - batch_size = len(id_list) - len_list = [(i, self.data[i].shape[0]) for i in id_list] - _, max_len = max(len_list, key=lambda x: x[1]) + return self.content[idxes] + batch_size = len(idxes) + max_len = max([len(self.content[i]) for i in idxes]) array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32) - for i, (idx, length) in enumerate(len_list): - if length == max_len: - array[i] = self.data[idx] - else: - array[i][:length] = self.data[idx] + for i, idx in enumerate(idxes): + array[i][:len(self.content[idx])] = self.content[idx] return array def __len__(self): - return len(self.data) + return len(self.content)