Browse Source

修复crf中typo; 以及可能导致数值不稳定的地方

tags/v0.4.10
yh 5 years ago
parent
commit
22661ea866
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      fastNLP/modules/decoder/CRF.py

+ 4
- 3
fastNLP/modules/decoder/CRF.py View File

@@ -205,7 +205,7 @@ class ConditionalRandomField(nn.Module):

return log_sum_exp(alpha, 1)

def _glod_score(self, logits, tags, mask):
def _gold_score(self, logits, tags, mask):
"""
Compute the score for the gold path.
:param logits: FloatTensor, max_len x batch_size x num_tags
@@ -244,7 +244,7 @@ class ConditionalRandomField(nn.Module):
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._glod_score(feats, tags, mask)
gold_path_score = self._gold_score(feats, tags, mask)

return all_path_score - gold_path_score

@@ -284,7 +284,8 @@ class ConditionalRandomField(nn.Module):
score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)
best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

vscore += transitions[:n_tags, n_tags+1].view(1, -1)



Loading…
Cancel
Save