diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 19fb5724..97e3a9fb 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -229,7 +229,7 @@ class CrossEntropyLoss(LossBase):

     def get_loss(self, pred, target, seq_len=None):
         if seq_len is not None and target.dim()>1:
-            mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0)
+            mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
             target = target.masked_fill(mask, self.padding_idx)

         if pred.dim() > 2:
@@ -374,7 +374,7 @@ class CMRC2018Loss(LossBase):
         :return:
         """
         batch_size, max_len = pred_end.size()
-        mask = seq_len_to_mask(context_len, max_len).eq(0)
+        mask = seq_len_to_mask(context_len, max_len).eq(False)

         pred_start = pred_start.masked_fill(mask, float('-inf'))
         pred_end = pred_end.masked_fill(mask, float('-inf'))
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index f1f97b17..4247d1de 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -358,7 +358,7 @@ class AccuracyMetric(MetricBase):
             target = target.to(pred)

         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(False), 0)).item()
             self.total += torch.sum(masks).item()
         else:
             self.acc_count += torch.sum(torch.eq(pred, target)).item()
@@ -465,7 +465,7 @@ class ClassifyFPreRecMetric(MetricBase):
             masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
         else:
             masks = torch.ones_like(target).long().to(target.device)
-        masks = masks.eq(0)
+        masks = masks.eq(False)

         if pred.dim() == target.dim():
             pass
@@ -1017,7 +1017,7 @@ class CMRC2018Metric(MetricBase):
         :return:
         """
         batch_size, max_len = pred_start.size()
-        context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(0)
+        context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False)
         pred_start.masked_fill_(context_mask, float('-inf'))
         pred_end.masked_fill_(context_mask, float('-inf'))
         max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True)  # batch_size,
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 36670a0b..62344441 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -325,7 +325,7 @@ class _WordBertModel(nn.Module):
         batch_size, max_word_len = words.size()
         word_mask = words.ne(self._word_pad_index)  # 为1的地方有word
         seq_len = word_mask.sum(dim=-1)
-        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0),
+        batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
                                                                                0)  # batch_size x max_len
         word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
         word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # 表示word piece的长度(包括padding)
@@ -403,12 +403,12 @@ class _WordBertModel(nn.Module):
             truncate_output_layer = output_layer[:, 1:-1]  # 删除[CLS]与[SEP] batch_size x len x hidden_size
             if self.pool_method == 'first':
                 tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
-                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
                 outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
             elif self.pool_method == 'last':
                 tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
-                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0),
-                                      0)
+                tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False),
+                                      0)
                 outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
             elif self.pool_method == 'max':
                 for i in range(batch_size):
diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py
index 0624d07f..93d3ce00 100644
--- a/fastNLP/embeddings/char_embedding.py
+++ b/fastNLP/embeddings/char_embedding.py
@@ -148,7 +148,7 @@ class CNNCharEmbedding(TokenEmbedding):
             chars, _ = torch.max(conv_chars, dim=-2)  # batch_size x max_len x sum(filters)
         else:
             conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
-            chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
+            chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
         chars = self.fc(chars)
         return self.dropout(chars)

@@ -266,7 +266,7 @@ class LSTMCharEmbedding(TokenEmbedding):
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
         chars = self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
-        char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
+        char_seq_len = chars_masks.eq(False).sum(dim=-1).reshape(batch_size * max_len)
         lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
         # B x M x M x H

@@ -276,7 +276,7 @@ class LSTMCharEmbedding(TokenEmbedding):
             chars, _ = torch.max(lstm_chars, dim=-2)  # batch_size x max_len x H
         else:
             lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
-            chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
+            chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()

         chars = self.fc(chars)
diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index 45f8adb7..e10a1fc4 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -148,7 +148,7 @@ class GraphParser(BaseModel):
         """
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
-        flip_mask = mask.eq(0)
+        flip_mask = mask.eq(False)
         matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
         _, heads = torch.max(matrix, dim=2)
         if mask is not None:
@@ -441,7 +441,7 @@ class BiaffineParser(GraphParser):
         batch_size, length, _ = pred1.shape

         mask = seq_len_to_mask(seq_len, max_len=length)
-        flip_mask = (mask == 0)
+        flip_mask = (mask.eq(False))
         _arc_pred = pred1.clone()
         _arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
         arc_logits = F.log_softmax(_arc_pred, dim=2)
diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py
index 9a9f0c80..d8b5ccb4 100644
--- a/fastNLP/models/snli.py
+++ b/fastNLP/models/snli.py
@@ -152,7 +152,7 @@ class BiRNN(nn.Module):

     def forward(self, x, x_mask):
         # Sort x
-        lengths = x_mask.data.eq(1).long().sum(1)
+        lengths = x_mask.data.eq(True).long().sum(1)
         _, idx_sort = torch.sort(lengths, dim=0, descending=True)
         _, idx_unsort = torch.sort(idx_sort, dim=0)
         lengths = list(lengths[idx_sort])
diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py
index 669501e9..0f121595 100644
--- a/fastNLP/modules/decoder/crf.py
+++ b/fastNLP/modules/decoder/crf.py
@@ -217,14 +217,14 @@ class ConditionalRandomField(nn.Module):
         if self.include_start_end_trans:
             alpha = alpha + self.start_scores.view(1, -1)

-        flip_mask = mask.eq(0)
+        flip_mask = mask.eq(False)

         for i in range(1, seq_len):
             emit_score = logits[i].view(batch_size, 1, n_tags)
             trans_score = self.trans_m.view(1, n_tags, n_tags)
             tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
             alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
-                    alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)
+                    alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0)

         if self.include_start_end_trans:
             alpha = alpha + self.end_scores.view(1, -1)
@@ -244,8 +244,8 @@ class ConditionalRandomField(nn.Module):
         seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

         # trans_socre [L-1, B]
-        mask = mask.eq(1)
-        flip_mask = mask.eq(0)
+        mask = mask.eq(True)
+        flip_mask = mask.eq(False)
         trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
         # emit_score [L, B]
         emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
@@ -292,7 +292,7 @@ class ConditionalRandomField(nn.Module):
         """
         batch_size, seq_len, n_tags = logits.size()
         logits = logits.transpose(0, 1).data  # L, B, H
-        mask = mask.transpose(0, 1).data.eq(1)  # L, B
+        mask = mask.transpose(0, 1).data.eq(True)  # L, B

         # dp
         vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -311,7 +311,7 @@ class ConditionalRandomField(nn.Module):
             score = prev_score + trans_score + cur_score
             best_score, best_dst = score.max(1)
             vpath[i] = best_dst
-            vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
+            vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
                      vscore.masked_fill(mask[i].view(batch_size, 1), 0)

         if self.include_start_end_trans:
diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py
index e0d2af68..a1c60b61 100644
--- a/fastNLP/modules/decoder/utils.py
+++ b/fastNLP/modules/decoder/utils.py
@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
                                                                         "compatible."
     logits = logits.transpose(0, 1).data  # L, B, H
     if mask is not None:
-        mask = mask.transpose(0, 1).data.eq(1)  # L, B
+        mask = mask.transpose(0, 1).data.eq(True)  # L, B
     else:
         mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)

diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py
index 85b1ac4d..90cb6a2b 100644
--- a/fastNLP/modules/encoder/star_transformer.py
+++ b/fastNLP/modules/encoder/star_transformer.py
@@ -65,7 +65,7 @@ class StarTransformer(nn.Module):
             return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

         B, L, H = data.size()
-        mask = (mask == 0)  # flip the mask for masked_fill_
+        mask = (mask.eq(False))  # flip the mask for masked_fill_
         smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

         embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py
index 323091b0..2a999039 100644
--- a/fastNLP/modules/encoder/transformer.py
+++ b/fastNLP/modules/encoder/transformer.py
@@ -71,7 +71,7 @@ class TransformerEncoder(nn.Module):
         if seq_mask is None:
             atte_mask_out = None
         else:
-            atte_mask_out = (seq_mask == 0)[:, None, :]
+            atte_mask_out = (seq_mask.eq(False))[:, None, :]
             seq_mask = seq_mask[:, :, None]
         for layer in self.layers:
             output = layer(output, seq_mask, atte_mask_out)
diff --git a/reproduction/Summarization/Baseline/model/Loss.py b/reproduction/Summarization/Baseline/model/Loss.py
index e5244261..6ff6f0b9 100644
--- a/reproduction/Summarization/Baseline/model/Loss.py
+++ b/reproduction/Summarization/Baseline/model/Loss.py
@@ -47,7 +47,7 @@ class MyCrossEntropyLoss(LossBase):
         loss = F.cross_entropy(input=pred, target=target,
                                ignore_index=self.padding_idx, reduction=self.reduce)
         loss = loss.view(batch, -1)
-        loss = loss.masked_fill(mask.eq(0), 0)
+        loss = loss.masked_fill(mask.eq(False), 0)
         loss = loss.sum(1).mean()
         logger.debug("loss %f", loss)
         return loss
diff --git a/reproduction/Summarization/Baseline/model/Metric.py b/reproduction/Summarization/Baseline/model/Metric.py
index df2cd9eb..91c25184 100644
--- a/reproduction/Summarization/Baseline/model/Metric.py
+++ b/reproduction/Summarization/Baseline/model/Metric.py
@@ -57,7 +57,7 @@ class LossMetric(MetricBase):
         loss = F.cross_entropy(input=pred, target=target,
                                ignore_index=self.padding_idx, reduction=self.reduce)
         loss = loss.view(batch, -1)
-        loss = loss.masked_fill(mask.eq(0), 0)
+        loss = loss.masked_fill(mask.eq(False), 0)
         loss = loss.sum(1).mean()
         self.loss += loss
         self.iteration += 1
diff --git a/reproduction/Summarization/BertSum/model.py b/reproduction/Summarization/BertSum/model.py
index 1ee821fc..34a05495 100644
--- a/reproduction/Summarization/BertSum/model.py
+++ b/reproduction/Summarization/BertSum/model.py
@@ -33,8 +33,8 @@ class BertSum(nn.Module):
         # print(segment_id.device)
         # print(cls_id.device)

-        input_mask = 1 - (article == 0)
-        mask_cls = 1 - (cls_id == -1)
+        input_mask = 1 - (article == 0).long()
+        mask_cls = 1 - (cls_id == -1).long()

         assert input_mask.size() == article.size()
         assert mask_cls.size() == cls_id.size()
diff --git a/reproduction/joint_cws_parse/models/CharParser.py b/reproduction/joint_cws_parse/models/CharParser.py
index 7d89cacb..bfb5da4e 100644
--- a/reproduction/joint_cws_parse/models/CharParser.py
+++ b/reproduction/joint_cws_parse/models/CharParser.py
@@ -223,7 +223,7 @@ class CharBiaffineParser(BiaffineParser):
         """
         batch_size, seq_len, _ = arc_pred.shape

-        flip_mask = (mask == 0)
+        flip_mask = (mask.eq(False))
         # _arc_pred = arc_pred.clone()
         _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))

diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py
index d704e2f8..f3f93bb6 100644
--- a/reproduction/matching/model/esim.py
+++ b/reproduction/matching/model/esim.py
@@ -119,7 +119,7 @@ class BiRNN(nn.Module):

     def forward(self, x, x_mask):
         # Sort x
-        lengths = x_mask.data.eq(1).long().sum(1)
+        lengths = x_mask.data.eq(True).long().sum(1)
         _, idx_sort = torch.sort(lengths, dim=0, descending=True)
         _, idx_unsort = torch.sort(idx_sort, dim=0)
         lengths = list(lengths[idx_sort])
diff --git a/reproduction/matching/model/mwan.py b/reproduction/matching/model/mwan.py
index 7ca6df3b..9af1e134 100644
--- a/reproduction/matching/model/mwan.py
+++ b/reproduction/matching/model/mwan.py
@@ -91,7 +91,7 @@ class ConcatAttention_Param(nn.Module):

         s = self.v(tc.tanh(self.ln(tc.cat([h,vq],-1)))).squeeze(-1)  # (batch_size, len)

-        s = s - ((mask == 0).float() * 10000)
+        s = s - ((mask.eq(False)).float() * 10000)
         a = tc.softmax(s, dim=1)

         r = a.unsqueeze(-1) * h  # (batch_size, len, input_size)
@@ -121,7 +121,7 @@ def Attention(hq, hp, mask_hq, mask_hp, my_method):

     s = my_method(hq_mat, hp_mat)  # (batch_size, len_q, len_p)

-    s = s - ((mask_mat == 0).float() * 10000)
+    s = s - ((mask_mat.eq(False)).float() * 10000)
     a = tc.softmax(s, dim=1)

     q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
@@ -242,7 +242,7 @@ class BiLinearAttention(nn.Module):

         s = self.my_method(hq, hp, mask_hp)  # (batch_size, len_q, len_p)

-        s = s - ((mask_mat == 0).float() * 10000)
+        s = s - ((mask_mat.eq(False)).float() * 10000)
         a = tc.softmax(s, dim=1)

         hq_mat = hq.unsqueeze(2).expand(standard_size)
@@ -285,7 +285,7 @@ class AggAttention(nn.Module):

         s = self.v(tc.tanh(self.ln(tc.cat([hs,vq],-1)))).squeeze(-1)  # (4, batch_size, len_q)

-        s = s - ((mask.unsqueeze(0) == 0).float() * 10000)
+        s = s - ((mask.unsqueeze(0).eq(False)).float() * 10000)
         a = tc.softmax(s, dim=0)

         x = a.unsqueeze(-1) * hs
diff --git a/reproduction/multi-criteria-cws/transformer.py b/reproduction/multi-criteria-cws/transformer.py
index fc352e44..f102e5aa 100644
--- a/reproduction/multi-criteria-cws/transformer.py
+++ b/reproduction/multi-criteria-cws/transformer.py
@@ -27,7 +27,7 @@ def attention(query, key, value, mask=None, dropout=None):
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
     if mask is not None:
         # print(scores.size(),mask.size())  # [bsz,1,1,len]
-        scores = scores.masked_fill(mask == 0, -1e9)
+        scores = scores.masked_fill(mask.eq(False), -1e9)
     p_attn = F.softmax(scores, dim=-1)
     if dropout is not None:
         p_attn = dropout(p_attn)
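
Note (editor's addition, not part of the patch): every hunk above follows one pattern. Integer comparisons on mask tensors (.eq(0), .eq(1), mask == 0) become boolean comparisons (.eq(False), .eq(True)), and the BertSum hunk additionally casts its masks to long before subtracting them from 1. This is consistent with a compatibility pass for PyTorch >= 1.2, where comparison ops and mask helpers such as seq_len_to_mask return torch.bool tensors and subtraction on a bool tensor raises a RuntimeError. A minimal sketch of both behaviors, using illustrative tensor names that do not come from the patch:

    import torch

    # Assumes PyTorch >= 1.2, where comparisons return torch.bool tensors.
    article = torch.tensor([[5, 9, 0, 0]])  # hypothetical token ids; 0 = padding
    mask = article.ne(0)                    # dtype: torch.bool

    flip_mask = mask.eq(False)              # the inversion idiom the patch adopts
    assert torch.equal(flip_mask, ~mask)    # same result as logical not

    # Why BertSum also needed .long(): on PyTorch >= 1.2,
    # 1 - (article == 0) raises "Subtraction, the `-` operator, with a bool
    # tensor is not supported". Casting to an integer dtype first keeps the
    # original arithmetic working.
    input_mask = 1 - (article == 0).long()

    # Typical downstream use, mirroring the loss/metric hunks above: fill the
    # padding positions so they cannot win a max or contribute to a sum.
    scores = torch.randn(1, 4)
    scores = scores.masked_fill(flip_mask, float('-inf'))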