Browse Source

fix crf

tags/v0.2.0
yunfan 6 years ago
parent
commit
64a9bacbc2
1 changed files with 5 additions and 4 deletions
  1. +5
    -4
      fastNLP/modules/decoder/CRF.py

+ 5
- 4
fastNLP/modules/decoder/CRF.py View File

@@ -89,8 +89,9 @@ class ConditionalRandomField(nn.Module):
score = score.sum(0) + emit_score[-1]
if self.include_start_end_trans:
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
last_idx = mask.long().sum(0)
last_idx = mask.long().sum(0) - 1
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
print(score.size(), st_scores.size(), ed_scores.size())
score += st_scores + ed_scores
# return [B,]
return score
@@ -104,8 +105,8 @@ class ConditionalRandomField(nn.Module):
:return:FloatTensor, batch_size
"""
feats = feats.transpose(0, 1)
tags = tags.transpose(0, 1)
mask = mask.transpose(0, 1)
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._glod_score(feats, tags, mask)

@@ -156,4 +157,4 @@ class ConditionalRandomField(nn.Module):

if get_score:
return ans_score, ans.transpose(0, 1)
return ans.transpose(0, 1)
return ans.transpose(0, 1)

Loading…
Cancel
Save