Browse Source

修复crf中typo; 以及可能导致数值不稳定的地方

tags/v0.4.10
yh 5 years ago
parent
commit
22661ea866
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      fastNLP/modules/decoder/CRF.py

+ 4
- 3
fastNLP/modules/decoder/CRF.py View File

@@ -205,7 +205,7 @@ class ConditionalRandomField(nn.Module):


return log_sum_exp(alpha, 1) return log_sum_exp(alpha, 1)


def _glod_score(self, logits, tags, mask):
def _gold_score(self, logits, tags, mask):
""" """
Compute the score for the gold path. Compute the score for the gold path.
:param logits: FloatTensor, max_len x batch_size x num_tags :param logits: FloatTensor, max_len x batch_size x num_tags
@@ -244,7 +244,7 @@ class ConditionalRandomField(nn.Module):
tags = tags.transpose(0, 1).long() tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float() mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask) all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._glod_score(feats, tags, mask)
gold_path_score = self._gold_score(feats, tags, mask)


return all_path_score - gold_path_score return all_path_score - gold_path_score


@@ -284,7 +284,8 @@ class ConditionalRandomField(nn.Module):
score = prev_score + trans_score + cur_score score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1) best_score, best_dst = score.max(1)
vpath[i] = best_dst vpath[i] = best_dst
vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)
best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)


vscore += transitions[:n_tags, n_tags+1].view(1, -1) vscore += transitions[:n_tags, n_tags+1].view(1, -1)




Loading…
Cancel
Save