|
|
@@ -37,8 +37,8 @@ class SemiCRFShiftRelay(nn.Module): |
|
|
|
# 当前时刻结束的分数是多少 |
|
|
|
scores = logits.new_zeros(batch_size, max_len+1) |
|
|
|
# golden的分数 |
|
|
|
gold_scores = relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(0), 0) + \ |
|
|
|
logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(0), 0) |
|
|
|
gold_scores = relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(False), 0) + \ |
|
|
|
logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(False), 0) |
|
|
|
# 初始化 |
|
|
|
scores[:, 1] = logits[:, 0, 0] |
|
|
|
batch_i = torch.arange(batch_size).to(logits.device).long() |
|
|
@@ -67,8 +67,8 @@ class SemiCRFShiftRelay(nn.Module): |
|
|
|
|
|
|
|
# 计算golden |
|
|
|
seg_i = relay_target[:, t] # batch_size |
|
|
|
gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(0), 0) # batch_size, 后向从0到L长度的segment的分数 |
|
|
|
relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(0), 0) |
|
|
|
gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(False), 0) # batch_size, 后向从0到L长度的segment的分数 |
|
|
|
relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(False), 0) |
|
|
|
gold_scores = gold_scores + relay_score + gold_segment_scores |
|
|
|
all_scores = scores.gather(dim=1, index=seq_len.unsqueeze(1)).squeeze(1) # batch_size |
|
|
|
return all_scores - gold_scores |
|
|
|