From 22661ea866fe97db2721187a51b637aca7dc89a2 Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 12 Mar 2019 19:51:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dcrf=E4=B8=ADtypo;=20=E4=BB=A5?= =?UTF-8?q?=E5=8F=8A=E5=8F=AF=E8=83=BD=E5=AF=BC=E8=87=B4=E6=95=B0=E5=80=BC?= =?UTF-8?q?=E4=B8=8D=E7=A8=B3=E5=AE=9A=E7=9A=84=E5=9C=B0=E6=96=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/decoder/CRF.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index e1b68e7a..e17b04f3 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -205,7 +205,7 @@ class ConditionalRandomField(nn.Module): return log_sum_exp(alpha, 1) - def _glod_score(self, logits, tags, mask): + def _gold_score(self, logits, tags, mask): """ Compute the score for the gold path. :param logits: FloatTensor, max_len x batch_size x num_tags @@ -244,7 +244,7 @@ class ConditionalRandomField(nn.Module): tags = tags.transpose(0, 1).long() mask = mask.transpose(0, 1).float() all_path_score = self._normalizer_likelihood(feats, mask) - gold_path_score = self._glod_score(feats, tags, mask) + gold_path_score = self._gold_score(feats, tags, mask) return all_path_score - gold_path_score @@ -284,7 +284,8 @@ class ConditionalRandomField(nn.Module): score = prev_score + trans_score + cur_score best_score, best_dst = score.max(1) vpath[i] = best_dst - vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1) + best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ + vscore.masked_fill(mask[i].view(batch_size, 1), 0) vscore += transitions[:n_tags, n_tags+1].view(1, -1)