Browse Source

fix crf

tags/v0.2.0
yunfan 6 years ago
parent
commit
64a9bacbc2
1 changed files with 5 additions and 4 deletions
  1. +5
    -4
      fastNLP/modules/decoder/CRF.py

+ 5
- 4
fastNLP/modules/decoder/CRF.py View File

@@ -89,8 +89,9 @@ class ConditionalRandomField(nn.Module):
score = score.sum(0) + emit_score[-1] score = score.sum(0) + emit_score[-1]
if self.include_start_end_trans: if self.include_start_end_trans:
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
last_idx = mask.long().sum(0)
last_idx = mask.long().sum(0) - 1
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
print(score.size(), st_scores.size(), ed_scores.size())
score += st_scores + ed_scores score += st_scores + ed_scores
# return [B,] # return [B,]
return score return score
@@ -104,8 +105,8 @@ class ConditionalRandomField(nn.Module):
:return:FloatTensor, batch_size :return:FloatTensor, batch_size
""" """
feats = feats.transpose(0, 1) feats = feats.transpose(0, 1)
tags = tags.transpose(0, 1)
mask = mask.transpose(0, 1)
tags = tags.transpose(0, 1).long()
mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask) all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._glod_score(feats, tags, mask) gold_path_score = self._glod_score(feats, tags, mask)


@@ -156,4 +157,4 @@ class ConditionalRandomField(nn.Module):


if get_score: if get_score:
return ans_score, ans.transpose(0, 1) return ans_score, ans.transpose(0, 1)
return ans.transpose(0, 1)
return ans.transpose(0, 1)

Loading…
Cancel
Save