|
@@ -37,8 +37,9 @@ class CNNText(torch.nn.Module): |
|
|
|
|
|
|
|
|
def forward(self, word_seq): |
|
|
def forward(self, word_seq): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
:param word_seq: torch.LongTensor, [batch_size, seq_len] |
|
|
:param word_seq: torch.LongTensor, [batch_size, seq_len] |
|
|
:return x: torch.LongTensor, [batch_size, num_classes] |
|
|
|
|
|
|
|
|
:return output: dict of torch.LongTensor, [batch_size, num_classes] |
|
|
""" |
|
|
""" |
|
|
x = self.embed(word_seq) # [N,L] -> [N,L,C] |
|
|
x = self.embed(word_seq) # [N,L] -> [N,L,C] |
|
|
x = self.conv_pool(x) # [N,L,C] -> [N,C] |
|
|
x = self.conv_pool(x) # [N,L,C] -> [N,C] |
|
@@ -47,14 +48,31 @@ class CNNText(torch.nn.Module): |
|
|
return {'output':x} |
|
|
return {'output':x} |
|
|
|
|
|
|
|
|
def predict(self, word_seq): |
|
|
def predict(self, word_seq): |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
:param word_seq: torch.LongTensor, [batch_size, seq_len] |
|
|
|
|
|
:return predict: dict of torch.LongTensor, [batch_size, seq_len] |
|
|
|
|
|
""" |
|
|
output = self(word_seq) |
|
|
output = self(word_seq) |
|
|
_, predict = output['output'].max(dim=1) |
|
|
_, predict = output['output'].max(dim=1) |
|
|
return {'predict': predict} |
|
|
return {'predict': predict} |
|
|
|
|
|
|
|
|
def get_loss(self, output, label_seq): |
|
|
def get_loss(self, output, label_seq): |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
:param output: output of forward(), [batch_size, seq_len] |
|
|
|
|
|
:param label_seq: true label in DataSet, [batch_size, seq_len] |
|
|
|
|
|
:return loss: torch.Tensor |
|
|
|
|
|
""" |
|
|
return self._loss(output, label_seq) |
|
|
return self._loss(output, label_seq) |
|
|
|
|
|
|
|
|
def evaluate(self, predict, label_seq): |
|
|
def evaluate(self, predict, label_seq): |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
:param predict: iterable predict tensors |
|
|
|
|
|
:param label_seq: iterable true label tensors |
|
|
|
|
|
:return accuracy: dict of float |
|
|
|
|
|
""" |
|
|
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0) |
|
|
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0) |
|
|
predict, label_seq = predict.squeeze(), label_seq.squeeze() |
|
|
predict, label_seq = predict.squeeze(), label_seq.squeeze() |
|
|
correct = (predict == label_seq).long().sum().item() |
|
|
correct = (predict == label_seq).long().sum().item() |
|
|