From 64a9bacbc25d3890b6112c512e5823f4a4e3e338 Mon Sep 17 00:00:00 2001 From: yunfan Date: Sat, 10 Nov 2018 16:50:56 +0800 Subject: [PATCH] fix crf --- fastNLP/modules/decoder/CRF.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index 11cde48a..e24f4d27 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -89,8 +89,9 @@ class ConditionalRandomField(nn.Module): score = score.sum(0) + emit_score[-1] if self.include_start_end_trans: st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] - last_idx = mask.long().sum(0) + last_idx = mask.long().sum(0) - 1 ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] + print(score.size(), st_scores.size(), ed_scores.size()) score += st_scores + ed_scores # return [B,] return score @@ -104,8 +105,8 @@ class ConditionalRandomField(nn.Module): :return:FloatTensor, batch_size """ feats = feats.transpose(0, 1) - tags = tags.transpose(0, 1) - mask = mask.transpose(0, 1) + tags = tags.transpose(0, 1).long() + mask = mask.transpose(0, 1).float() all_path_score = self._normalizer_likelihood(feats, mask) gold_path_score = self._glod_score(feats, tags, mask) @@ -156,4 +157,4 @@ class ConditionalRandomField(nn.Module): if get_score: return ans_score, ans.transpose(0, 1) - return ans.transpose(0, 1) \ No newline at end of file + return ans.transpose(0, 1)