From 5edd9de84178db51c7492da86d76f3468092bde3 Mon Sep 17 00:00:00 2001 From: yunfan Date: Tue, 4 Dec 2018 15:49:01 +0800 Subject: [PATCH] fix bugs --- fastNLP/core/dataset.py | 2 +- fastNLP/models/cnn_text_classification.py | 23 ----------------------- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index bc4dcf57..cdca4356 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -67,8 +67,8 @@ class DataSet(object): self.dataset = dataset self.idx = idx def __getitem__(self, item): - assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx) assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx]) + assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) return self.dataset.field_arrays[item][self.idx] def __repr__(self): return self.dataset[self.idx].__repr__() diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py index 9aa07e66..c8fe5181 100644 --- a/fastNLP/models/cnn_text_classification.py +++ b/fastNLP/models/cnn_text_classification.py @@ -33,7 +33,6 @@ class CNNText(torch.nn.Module): padding=padding) self.dropout = nn.Dropout(dropout) self.fc = encoder.Linear(sum(kernel_nums), num_classes) - self._loss = nn.CrossEntropyLoss() def forward(self, word_seq): """ @@ -56,25 +55,3 @@ class CNNText(torch.nn.Module): output = self(word_seq) _, predict = output['output'].max(dim=1) return {'predict': predict} - - def get_loss(self, output, label_seq): - """ - - :param output: output of forward(), [batch_size, seq_len] - :param label_seq: true label in DataSet, [batch_size, seq_len] - :return loss: torch.Tensor - """ - return self._loss(output, label_seq) - - def evaluate(self, predict, label_seq): - """ - - :param predict: iterable predict tensors - :param label_seq: iterable true label tensors - :return accuracy: dict of float - """ - predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0) - predict, label_seq = predict.squeeze(), label_seq.squeeze() - correct = (predict == label_seq).long().sum().item() - total = label_seq.size(0) - return {'acc': 1.0 * correct / total}