|
|
@@ -89,8 +89,9 @@ class ConditionalRandomField(nn.Module): |
|
|
|
score = score.sum(0) + emit_score[-1] |
|
|
|
if self.include_start_end_trans: |
|
|
|
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] |
|
|
|
last_idx = mask.long().sum(0) |
|
|
|
last_idx = mask.long().sum(0) - 1 |
|
|
|
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] |
|
|
|
print(score.size(), st_scores.size(), ed_scores.size()) |
|
|
|
score += st_scores + ed_scores |
|
|
|
# return [B,] |
|
|
|
return score |
|
|
@@ -104,8 +105,8 @@ class ConditionalRandomField(nn.Module): |
|
|
|
:return:FloatTensor, batch_size |
|
|
|
""" |
|
|
|
feats = feats.transpose(0, 1) |
|
|
|
tags = tags.transpose(0, 1) |
|
|
|
mask = mask.transpose(0, 1) |
|
|
|
tags = tags.transpose(0, 1).long() |
|
|
|
mask = mask.transpose(0, 1).float() |
|
|
|
all_path_score = self._normalizer_likelihood(feats, mask) |
|
|
|
gold_path_score = self._glod_score(feats, tags, mask) |
|
|
|
|
|
|
@@ -156,4 +157,4 @@ class ConditionalRandomField(nn.Module): |
|
|
|
|
|
|
|
if get_score: |
|
|
|
return ans_score, ans.transpose(0, 1) |
|
|
|
return ans.transpose(0, 1) |
|
|
|
return ans.transpose(0, 1) |