@@ -42,7 +42,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): | |||||
score = prev_score + trans_score + cur_score | score = prev_score + trans_score + cur_score | ||||
best_score, best_dst = score.max(1) | best_score, best_dst = score.max(1) | ||||
vpath[i] = best_dst | vpath[i] = best_dst | ||||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||||
vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \ | |||||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | vscore.masked_fill(mask[i].view(batch_size, 1), 0) | ||||
# backtrace | # backtrace | ||||
@@ -80,7 +80,7 @@ class ConvMaxpool(nn.Module): | |||||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | ||||
if mask is not None: | if mask is not None: | ||||
mask = mask.unsqueeze(1) # B x 1 x L | mask = mask.unsqueeze(1) # B x 1 x L | ||||
xs = [x.masked_fill_(mask.eq(0), float('-inf')) for x in xs] | |||||
xs = [x.masked_fill_(mask.eq(False), float('-inf')) for x in xs] | |||||
# max-pooling | # max-pooling | ||||
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | ||||
for i in xs] # [[N, C], ...] | for i in xs] # [[N, C], ...] | ||||
@@ -158,7 +158,7 @@ class CWSModel(nn.Module): | |||||
else: | else: | ||||
if tags is not None: | if tags is not None: | ||||
out = out.contiguous().view(-1, self.tag_size) | out = out.contiguous().view(-1, self.tag_size) | ||||
tags = tags.data.masked_fill_(mask == 0, -100).view(-1) | |||||
tags = tags.data.masked_fill_(mask.eq(False), -100).view(-1) | |||||
loss = self.loss_f(out, tags) | loss = self.loss_f(out, tags) | ||||
return {"loss": loss} | return {"loss": loss} | ||||
else: | else: | ||||
@@ -18,7 +18,7 @@ def subsequent_mask(size): | |||||
"Mask out subsequent positions." | "Mask out subsequent positions." | ||||
attn_shape = (1, size, size) | attn_shape = (1, size, size) | ||||
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") | ||||
return torch.from_numpy(subsequent_mask) == 0 | |||||
return torch.from_numpy(subsequent_mask).eq(False) | |||||
def attention(query, key, value, mask=None, dropout=None): | def attention(query, key, value, mask=None, dropout=None): | ||||
@@ -37,8 +37,8 @@ class SemiCRFShiftRelay(nn.Module): | |||||
# 当前时刻结束的分数是多少 | # 当前时刻结束的分数是多少 | ||||
scores = logits.new_zeros(batch_size, max_len+1) | scores = logits.new_zeros(batch_size, max_len+1) | ||||
# golden的分数 | # golden的分数 | ||||
gold_scores = relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(0), 0) + \ | |||||
logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(0), 0) | |||||
gold_scores = relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(False), 0) + \ | |||||
logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(False), 0) | |||||
# 初始化 | # 初始化 | ||||
scores[:, 1] = logits[:, 0, 0] | scores[:, 1] = logits[:, 0, 0] | ||||
batch_i = torch.arange(batch_size).to(logits.device).long() | batch_i = torch.arange(batch_size).to(logits.device).long() | ||||
@@ -67,8 +67,8 @@ class SemiCRFShiftRelay(nn.Module): | |||||
# 计算golden | # 计算golden | ||||
seg_i = relay_target[:, t] # batch_size | seg_i = relay_target[:, t] # batch_size | ||||
gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(0), 0) # batch_size, 后向从0到L长度的segment的分数 | |||||
relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(0), 0) | |||||
gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(False), 0) # batch_size, 后向从0到L长度的segment的分数 | |||||
relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(False), 0) | |||||
gold_scores = gold_scores + relay_score + gold_segment_scores | gold_scores = gold_scores + relay_score + gold_segment_scores | ||||
all_scores = scores.gather(dim=1, index=seq_len.unsqueeze(1)).squeeze(1) # batch_size | all_scores = scores.gather(dim=1, index=seq_len.unsqueeze(1)).squeeze(1) # batch_size | ||||
return all_scores - gold_scores | return all_scores - gold_scores | ||||
@@ -106,7 +106,7 @@ class IDCNN(nn.Module): | |||||
if self.crf is not None and target is not None: | if self.crf is not None and target is not None: | ||||
loss = self.crf(y.transpose(1, 2), t, mask) | loss = self.crf(y.transpose(1, 2), t, mask) | ||||
else: | else: | ||||
y.masked_fill_((mask == 0)[:,None,:], -100) | |||||
y.masked_fill_((mask.eq(False))[:,None,:], -100) | |||||
# f_mask = mask.float() | # f_mask = mask.float() | ||||
# t = f_mask * t + (1-f_mask) * -100 | # t = f_mask * t + (1-f_mask) * -100 | ||||
loss = F.cross_entropy(y, t, ignore_index=-100) | loss = F.cross_entropy(y, t, ignore_index=-100) | ||||