@@ -229,7 +229,7 @@ class CrossEntropyLoss(LossBase):
     def get_loss(self, pred, target, seq_len=None):
         if seq_len is not None and target.dim()>1:
-            mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0)
+            mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
             target = target.masked_fill(mask, self.padding_idx)
         if pred.dim() > 2:
@@ -374,7 +374,7 @@ class CMRC2018Loss(LossBase):
         :return:
         """
         batch_size, max_len = pred_end.size()
-        mask = seq_len_to_mask(context_len, max_len).eq(0)
+        mask = seq_len_to_mask(context_len, max_len).eq(False)
         pred_start = pred_start.masked_fill(mask, float('-inf'))
         pred_end = pred_end.masked_fill(mask, float('-inf'))
@@ -358,7 +358,7 @@ class AccuracyMetric(MetricBase):
         target = target.to(pred)
         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(False), 0)).item()
             self.total += torch.sum(masks).item()
         else:
             self.acc_count += torch.sum(torch.eq(pred, target)).item()
@@ -465,7 +465,7 @@ class ClassifyFPreRecMetric(MetricBase):
             masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
         else:
             masks = torch.ones_like(target).long().to(target.device)
-        masks = masks.eq(0)
+        masks = masks.eq(False)
         if pred.dim() == target.dim():
             pass
@@ -1017,7 +1017,7 @@ class CMRC2018Metric(MetricBase):
         :return:
         """
         batch_size, max_len = pred_start.size()
-        context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(0)
+        context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False)
         pred_start.masked_fill_(context_mask, float('-inf'))
         pred_end.masked_fill_(context_mask, float('-inf'))
         max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True)  # batch_size,
@@ -325,7 +325,7 @@ class _WordBertModel(nn.Module):
         batch_size, max_word_len = words.size()
         word_mask = words.ne(self._word_pad_index)  # True where a real word sits
         seq_len = word_mask.sum(dim=-1)
-        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0),
+        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
                                                                                0)  # batch_size x max_len
         word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
         word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # word-piece length (including padding)
@@ -403,12 +403,12 @@ class _WordBertModel(nn.Module):
             truncate_output_layer = output_layer[:, 1:-1]  # drop [CLS] and [SEP], batch_size x len x hidden_size
             if self.pool_method == 'first':
                 tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
-                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
                 outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
             elif self.pool_method == 'last':
                 tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
-                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
                 outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
             elif self.pool_method == 'max':
                 for i in range(batch_size):
@@ -148,7 +148,7 @@ class CNNCharEmbedding(TokenEmbedding):
             chars, _ = torch.max(conv_chars, dim=-2)  # batch_size x max_len x sum(filters)
         else:
             conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
-            chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
+            chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
         chars = self.fc(chars)
         return self.dropout(chars)
@@ -266,7 +266,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
         chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
-        char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
+        char_seq_len = chars_masks.eq(False).sum(dim=-1).reshape(batch_size * max_len)
         lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
         # B x M x M x H
@@ -276,7 +276,7 @@ class LSTMCharEmbedding(TokenEmbedding):
             chars, _ = torch.max(lstm_chars, dim=-2)  # batch_size x max_len x H
         else:
             lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
-            chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
+            chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
         chars = self.fc(chars)
@@ -148,7 +148,7 @@ class GraphParser(BaseModel):
         """
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
-        flip_mask = mask.eq(0)
+        flip_mask = mask.eq(False)
         matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
         _, heads = torch.max(matrix, dim=2)
         if mask is not None:
@@ -441,7 +441,7 @@ class BiaffineParser(GraphParser):
         batch_size, length, _ = pred1.shape
         mask = seq_len_to_mask(seq_len, max_len=length)
-        flip_mask = (mask == 0)
+        flip_mask = (mask.eq(False))
         _arc_pred = pred1.clone()
         _arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
         arc_logits = F.log_softmax(_arc_pred, dim=2)
@@ -152,7 +152,7 @@ class BiRNN(nn.Module):
     def forward(self, x, x_mask):
         # Sort x
-        lengths = x_mask.data.eq(1).long().sum(1)
+        lengths = x_mask.data.eq(True).long().sum(1)
         _, idx_sort = torch.sort(lengths, dim=0, descending=True)
         _, idx_unsort = torch.sort(idx_sort, dim=0)
         lengths = list(lengths[idx_sort])
@@ -217,14 +217,14 @@ class ConditionalRandomField(nn.Module):
         if self.include_start_end_trans:
             alpha = alpha + self.start_scores.view(1, -1)
-        flip_mask = mask.eq(0)
+        flip_mask = mask.eq(False)
         for i in range(1, seq_len):
             emit_score = logits[i].view(batch_size, 1, n_tags)
             trans_score = self.trans_m.view(1, n_tags, n_tags)
             tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
             alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
-                    alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)
+                    alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0)
         if self.include_start_end_trans:
             alpha = alpha + self.end_scores.view(1, -1)
@@ -244,8 +244,8 @@ class ConditionalRandomField(nn.Module):
         seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
         # trans_score [L-1, B]
-        mask = mask.eq(1)
-        flip_mask = mask.eq(0)
+        mask = mask.eq(True)
+        flip_mask = mask.eq(False)
         trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
         # emit_score [L, B]
         emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
@@ -292,7 +292,7 @@ class ConditionalRandomField(nn.Module):
         """
         batch_size, seq_len, n_tags = logits.size()
         logits = logits.transpose(0, 1).data  # L, B, H
-        mask = mask.transpose(0, 1).data.eq(1)  # L, B
+        mask = mask.transpose(0, 1).data.eq(True)  # L, B
         # dp
         vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -311,7 +311,7 @@ class ConditionalRandomField(nn.Module):
             score = prev_score + trans_score + cur_score
             best_score, best_dst = score.max(1)
             vpath[i] = best_dst
-            vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
+            vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
                      vscore.masked_fill(mask[i].view(batch_size, 1), 0)
         if self.include_start_end_trans:
@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
             "compatible."
     logits = logits.transpose(0, 1).data  # L, B, H
     if mask is not None:
-        mask = mask.transpose(0, 1).data.eq(1)  # L, B
+        mask = mask.transpose(0, 1).data.eq(True)  # L, B
     else:
         mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
@@ -65,7 +65,7 @@ class StarTransformer(nn.Module):
             return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         B, L, H = data.size()
-        mask = (mask == 0)  # flip the mask for masked_fill_
+        mask = (mask.eq(False))  # flip the mask for masked_fill_
         smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
         embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
@@ -71,7 +71,7 @@ class TransformerEncoder(nn.Module):
         if seq_mask is None:
             atte_mask_out = None
         else:
-            atte_mask_out = (seq_mask == 0)[:, None, :]
+            atte_mask_out = (seq_mask.eq(False))[:, None, :]
             seq_mask = seq_mask[:, :, None]
         for layer in self.layers:
             output = layer(output, seq_mask, atte_mask_out)
@@ -47,7 +47,7 @@ class MyCrossEntropyLoss(LossBase):
         loss = F.cross_entropy(input=pred, target=target,
                                ignore_index=self.padding_idx, reduction=self.reduce)
         loss = loss.view(batch, -1)
-        loss = loss.masked_fill(mask.eq(0), 0)
+        loss = loss.masked_fill(mask.eq(False), 0)
         loss = loss.sum(1).mean()
         logger.debug("loss %f", loss)
         return loss
@@ -57,7 +57,7 @@ class LossMetric(MetricBase):
         loss = F.cross_entropy(input=pred, target=target,
                                ignore_index=self.padding_idx, reduction=self.reduce)
         loss = loss.view(batch, -1)
-        loss = loss.masked_fill(mask.eq(0), 0)
+        loss = loss.masked_fill(mask.eq(False), 0)
         loss = loss.sum(1).mean()
         self.loss += loss
         self.iteration += 1
@@ -33,8 +33,8 @@ class BertSum(nn.Module):
         # print(segment_id.device)
         # print(cls_id.device)
-        input_mask = 1 - (article == 0)
-        mask_cls = 1 - (cls_id == -1)
+        input_mask = 1 - (article == 0).long()
+        mask_cls = 1 - (cls_id == -1).long()
         assert input_mask.size() == article.size()
         assert mask_cls.size() == cls_id.size()
@@ -223,7 +223,7 @@ class CharBiaffineParser(BiaffineParser):
         """
         batch_size, seq_len, _ = arc_pred.shape
-        flip_mask = (mask == 0)
+        flip_mask = (mask.eq(False))
         # _arc_pred = arc_pred.clone()
         _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
@@ -119,7 +119,7 @@ class BiRNN(nn.Module):
     def forward(self, x, x_mask):
         # Sort x
-        lengths = x_mask.data.eq(1).long().sum(1)
+        lengths = x_mask.data.eq(True).long().sum(1)
         _, idx_sort = torch.sort(lengths, dim=0, descending=True)
         _, idx_unsort = torch.sort(idx_sort, dim=0)
         lengths = list(lengths[idx_sort])
@@ -91,7 +91,7 @@ class ConcatAttention_Param(nn.Module):
         s = self.v(tc.tanh(self.ln(tc.cat([h,vq],-1)))).squeeze(-1)  # (batch_size, len)
-        s = s - ((mask == 0).float() * 10000)
+        s = s - ((mask.eq(False)).float() * 10000)
         a = tc.softmax(s, dim=1)
         r = a.unsqueeze(-1) * h  # (batch_size, len, input_size)
@@ -121,7 +121,7 @@ def Attention(hq, hp, mask_hq, mask_hp, my_method):
     s = my_method(hq_mat, hp_mat)  # (batch_size, len_q, len_p)
-    s = s - ((mask_mat == 0).float() * 10000)
+    s = s - ((mask_mat.eq(False)).float() * 10000)
     a = tc.softmax(s, dim=1)
     q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
@@ -242,7 +242,7 @@ class BiLinearAttention(nn.Module):
         s = self.my_method(hq, hp, mask_hp)  # (batch_size, len_q, len_p)
-        s = s - ((mask_mat == 0).float() * 10000)
+        s = s - ((mask_mat.eq(False)).float() * 10000)
         a = tc.softmax(s, dim=1)
         hq_mat = hq.unsqueeze(2).expand(standard_size)
@@ -285,7 +285,7 @@ class AggAttention(nn.Module):
         s = self.v(tc.tanh(self.ln(tc.cat([hs,vq],-1)))).squeeze(-1)  # (4, batch_size, len_q)
-        s = s - ((mask.unsqueeze(0) == 0).float() * 10000)
+        s = s - ((mask.unsqueeze(0).eq(False)).float() * 10000)
         a = tc.softmax(s, dim=0)
         x = a.unsqueeze(-1) * hs
@@ -27,7 +27,7 @@ def attention(query, key, value, mask=None, dropout=None):
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
     if mask is not None:
         # print(scores.size(), mask.size())  # [bsz,1,1,len]
-        scores = scores.masked_fill(mask == 0, -1e9)
+        scores = scores.masked_fill(mask.eq(False), -1e9)
     p_attn = F.softmax(scores, dim=-1)
     if dropout is not None:
         p_attn = dropout(p_attn)
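
A minimal sketch of the bool-mask idiom the hunks above converge on, assuming PyTorch >= 1.2, where comparison operators and mask helpers return torch.bool tensors; the tensor names are illustrative and not taken from the repository:

import torch

article = torch.tensor([[5, 9, 3, 0, 0]])   # toy batch; assume 0 is the padding id
mask = article.ne(0)                         # torch.bool mask, True where a real token sits

# Inverting a bool mask: .eq(False) gives the same result as .eq(0) or ~mask,
# but compares against a value of the mask's own dtype.
flip_mask = mask.eq(False)
scores = torch.randn(1, 5).masked_fill(flip_mask, float('-inf'))

# Arithmetic on bool tensors is restricted: `1 - (article == 0)` raises an error
# on recent PyTorch versions, hence the explicit cast before subtracting.
input_mask = 1 - (article == 0).long()       # 1 for real tokens, 0 for padding
lengths = mask.eq(True).long().sum(dim=1)    # per-sequence token counts, as in x_mask.data.eq(True)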