From f18ab642d70cb304212e71fd9b22e16fe3aa5699 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Thu, 22 Aug 2019 15:51:44 +0800
Subject: [PATCH] PyTorch 1.2 adds the BoolTensor type and masked_fill no
 longer expects ByteTensor-typed masks; update fastNLP for compatibility
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/embeddings/bert_embedding.py | 4 ++--
 fastNLP/embeddings/embedding.py      | 4 ++--
 fastNLP/models/biaffine_parser.py    | 2 +-
 fastNLP/modules/decoder/crf.py       | 6 +++---
 fastNLP/modules/decoder/utils.py     | 2 +-
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index bc0d46e2..6a10c489 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -115,7 +115,7 @@ class BertEmbedding(ContextualEmbedding):
             if self._word_sep_index:  # the sep token must not be dropped
                 sep_mask = words.eq(self._word_sep_index)
             mask = torch.ones_like(words).float() * self.word_dropout
-            mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are 1
+            mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are 1
             words = words.masked_fill(mask, self._word_unk_index)
             if self._word_sep_index:
                 words.masked_fill_(sep_mask, self._word_sep_index)
@@ -252,7 +252,7 @@ class BertWordPieceEncoder(nn.Module):
             if self._word_sep_index:  # the sep token must not be dropped
                 sep_mask = words.eq(self._wordpiece_unk_index)
             mask = torch.ones_like(words).float() * self.word_dropout
-            mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are 1
+            mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are 1
             words = words.masked_fill(mask, self._word_unk_index)
             if self._word_sep_index:
                 words.masked_fill_(sep_mask, self._wordpiece_unk_index)
diff --git a/fastNLP/embeddings/embedding.py b/fastNLP/embeddings/embedding.py
index 8c5396b7..8b746c0d 100644
--- a/fastNLP/embeddings/embedding.py
+++ b/fastNLP/embeddings/embedding.py
@@ -63,7 +63,7 @@ class Embedding(nn.Module):
         """
         if self.word_dropout>0 and self.training:
             mask = torch.ones_like(words).float() * self.word_dropout
-            mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are 1
+            mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are 1
             words = words.masked_fill(mask, self.unk_index)
         words = self.embed(words)
         return self.dropout(words)
@@ -135,7 +135,7 @@ class TokenEmbedding(nn.Module):
         """
         if self.word_dropout > 0 and self.training:
             mask = torch.ones_like(words).float() * self.word_dropout
-            mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions are 1
+            mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions are 1
             words = words.masked_fill(mask, self._word_unk_index)
         return words

diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index 29487864..bead09fc 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -150,7 +150,7 @@ class GraphParser(BaseModel):
         """
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
-        flip_mask = (mask == 0).byte()
+        flip_mask = mask.eq(0)
         matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
         _, heads = torch.max(matrix, dim=2)
         if mask is not None:
diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py
index b7a7547f..9f19afef 100644
--- a/fastNLP/modules/decoder/crf.py
+++ b/fastNLP/modules/decoder/crf.py
@@ -210,7 +210,7 @@ class ConditionalRandomField(nn.Module):
             trans_score = self.trans_m.view(1, n_tags, n_tags)
             tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
             alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
-                    alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
+                    alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)

         if self.include_start_end_trans:
             alpha = alpha + self.end_scores.view(1, -1)
@@ -230,7 +230,7 @@ class ConditionalRandomField(nn.Module):
         seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

         # trans_socre [L-1, B]
-        mask = mask.byte()
+        mask = mask.eq(1)
         flip_mask = mask.eq(0)
         trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
         # emit_score [L, B]
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module):
         """
         batch_size, seq_len, n_tags = logits.size()
         logits = logits.transpose(0, 1).data  # L, B, H
-        mask = mask.transpose(0, 1).data.byte()  # L, B
+        mask = mask.transpose(0, 1).data.eq(1)  # L, B

         # dp
         vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py
index 9e773336..3d5ac3f8 100644
--- a/fastNLP/modules/decoder/utils.py
+++ b/fastNLP/modules/decoder/utils.py
@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
                                                  "compatible."
     logits = logits.transpose(0, 1).data  # L, B, H
     if mask is not None:
-        mask = mask.transpose(0, 1).data.byte()  # L, B
+        mask = mask.transpose(0, 1).data.eq(1)  # L, B
     else:
         mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
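
Note (not part of the patch): every hunk applies the same pattern, replacing a `.byte()` cast with a comparison such as `.eq(1)` or `.eq(0)`, so the mask handed to masked_fill is a torch.bool tensor on PyTorch >= 1.2 while remaining a uint8 tensor (still accepted there) on older releases. The snippet below only illustrates that behaviour; the toy tensors, word_dropout value and unk_index are made-up stand-ins, not values taken from fastNLP.

    import torch

    words = torch.tensor([[3, 7, 0, 0],
                          [5, 2, 9, 0]])           # toy word-id batch, 0 = padding
    word_dropout = 0.3                             # assumed dropout probability
    unk_index = 1                                  # assumed <unk> id

    # Same pattern as the embedding hunks: a comparison instead of .byte()
    drop_prob = torch.ones_like(words).float() * word_dropout
    drop_mask = torch.bernoulli(drop_prob).eq(1)   # torch.bool on >= 1.2, uint8 before
    dropped = words.masked_fill(drop_mask, unk_index)

    # Same idea as the parser/CRF hunks: build the flip mask with .eq(0)
    pad_mask = words.ne(0)                         # True where a real token exists
    flip_mask = pad_mask.eq(0)                     # True on padding positions
    scores = torch.randn(2, 4).masked_fill(flip_mask, float('-inf'))

    print(drop_mask.dtype)                         # torch.bool on PyTorch >= 1.2
    print(dropped, scores, sep='\n')

Comparisons are presumably preferred over an explicit `.bool()` cast because Tensor.bool() only exists from PyTorch 1.2 onwards; `.eq(1)` yields the right mask dtype on both old and new versions without deprecation warnings.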