|
|
@@ -89,7 +89,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
score = score.sum(0) + emit_score[-1] |
|
|
|
if self.include_start_end_trans: |
|
|
|
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] |
|
|
|
last_idx = masks.long().sum(0) |
|
|
|
last_idx = mask.long().sum(0) |
|
|
|
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] |
|
|
|
score += st_scores + ed_scores |
|
|
|
# return [B,] |
|
|
|