|
|
@@ -33,7 +33,6 @@ class CNNText(torch.nn.Module): |
|
|
|
padding=padding) |
|
|
|
self.dropout = nn.Dropout(dropout) |
|
|
|
self.fc = encoder.Linear(sum(kernel_nums), num_classes) |
|
|
|
self._loss = nn.CrossEntropyLoss() |
|
|
|
|
|
|
|
def forward(self, word_seq): |
|
|
|
""" |
|
|
@@ -56,25 +55,3 @@ class CNNText(torch.nn.Module): |
|
|
|
output = self(word_seq) |
|
|
|
_, predict = output['output'].max(dim=1) |
|
|
|
return {'predict': predict} |
|
|
|
|
|
|
|
def get_loss(self, output, label_seq): |
|
|
|
""" |
|
|
|
|
|
|
|
:param output: output of forward(), [batch_size, seq_len] |
|
|
|
:param label_seq: true label in DataSet, [batch_size, seq_len] |
|
|
|
:return loss: torch.Tensor |
|
|
|
""" |
|
|
|
return self._loss(output, label_seq) |
|
|
|
|
|
|
|
def evaluate(self, predict, label_seq): |
|
|
|
""" |
|
|
|
|
|
|
|
:param predict: iterable predict tensors |
|
|
|
:param label_seq: iterable true label tensors |
|
|
|
:return accuracy: dict of float |
|
|
|
""" |
|
|
|
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0) |
|
|
|
predict, label_seq = predict.squeeze(), label_seq.squeeze() |
|
|
|
correct = (predict == label_seq).long().sum().item() |
|
|
|
total = label_seq.size(0) |
|
|
|
return {'acc': 1.0 * correct / total} |