
PyTorch 1.2 adds the BoolTensor type, and masked_fill now expects a BoolTensor index rather than a ByteTensor one; update fastNLP for compatibility.
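
For context, the word-dropout pattern this commit updates looks like the sketch below. It is standalone, and the token ids, `word_dropout`, and `unk_index` values are illustrative, not taken from fastNLP:

```python
import torch

words = torch.tensor([[5, 12, 3], [7, 2, 9]])          # (batch, seq_len) token ids
word_dropout, unk_index = 0.3, 1                       # illustrative values

probs = torch.ones_like(words).float() * word_dropout  # per-position drop probability
mask = torch.bernoulli(probs).eq(1)                    # BoolTensor mask, the form masked_fill expects from 1.2 on
# mask = torch.bernoulli(probs).byte()                 # ByteTensor mask, warns as deprecated on 1.2
words = words.masked_fill(mask, unk_index)             # dropped positions become <unk>
```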

tags/v0.4.10
yh_cc committed 6 years ago · commit f18ab642d7
5 changed files with 9 additions and 9 deletions:

  1. fastNLP/embeddings/bert_embedding.py (+2 -2)
  2. fastNLP/embeddings/embedding.py (+2 -2)
  3. fastNLP/models/biaffine_parser.py (+1 -1)
  4. fastNLP/modules/decoder/crf.py (+3 -3)
  5. fastNLP/modules/decoder/utils.py (+1 -1)

fastNLP/embeddings/bert_embedding.py (+2 -2)

@@ -115,7 +115,7 @@ class BertEmbedding(ContextualEmbedding):
 if self._word_sep_index:  # must not drop sep
     sep_mask = words.eq(self._word_sep_index)
 mask = torch.ones_like(words).float() * self.word_dropout
-mask = torch.bernoulli(mask).byte()  # the larger dropout_word, the more positions are set to 1
+mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word, the more positions are set to 1
 words = words.masked_fill(mask, self._word_unk_index)
 if self._word_sep_index:
     words.masked_fill_(sep_mask, self._word_sep_index)
@@ -252,7 +252,7 @@ class BertWordPieceEncoder(nn.Module):
 if self._word_sep_index:  # must not drop sep
     sep_mask = words.eq(self._wordpiece_unk_index)
 mask = torch.ones_like(words).float() * self.word_dropout
-mask = torch.bernoulli(mask).byte()  # the larger dropout_word, the more positions are set to 1
+mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word, the more positions are set to 1
 words = words.masked_fill(mask, self._word_unk_index)
 if self._word_sep_index:
     words.masked_fill_(sep_mask, self._wordpiece_unk_index)


fastNLP/embeddings/embedding.py (+2 -2)

@@ -63,7 +63,7 @@ class Embedding(nn.Module):
"""
if self.word_dropout>0 and self.training:
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
words = words.masked_fill(mask, self.unk_index)
words = self.embed(words)
return self.dropout(words)
@@ -135,7 +135,7 @@ class TokenEmbedding(nn.Module):
"""
if self.word_dropout > 0 and self.training:
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
words = words.masked_fill(mask, self._word_unk_index)
return words



fastNLP/models/biaffine_parser.py (+1 -1)

@@ -150,7 +150,7 @@ class GraphParser(BaseModel):
"""
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
flip_mask = (mask == 0).byte()
flip_mask = mask.eq(0)
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if mask is not None:
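
The parser hunk above uses `flip_mask` to push padded columns to -inf before taking the argmax over candidate heads. A minimal sketch of that trick, with made-up scores and shapes:

```python
import torch

batch, seq_len = 2, 4
arc_matrix = torch.randn(batch, seq_len, seq_len)  # made-up arc scores
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token, 0 = padding

flip_mask = mask.eq(0)                             # BoolTensor, True at padding
arc_matrix.masked_fill_(flip_mask.unsqueeze(1), float('-inf'))  # blank out padded columns
_, heads = torch.max(arc_matrix, dim=2)            # a padded position can never be chosen as head
```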


fastNLP/modules/decoder/crf.py (+3 -3)

@@ -210,7 +210,7 @@ class ConditionalRandomField(nn.Module):
     trans_score = self.trans_m.view(1, n_tags, n_tags)
     tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
     alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
-        alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
+        alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)

     if self.include_start_end_trans:
         alpha = alpha + self.end_scores.view(1, -1)
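
The pair of complementary `masked_fill` calls in this hunk is a select-by-mask idiom: rows with a real token at step `i` take the freshly computed score, while padded rows carry the previous `alpha` forward unchanged. A standalone sketch of the idiom (the helper name `step_update` is hypothetical, not part of fastNLP):

```python
import torch

def step_update(new, old, step_mask):
    """new, old: (batch, n_tags); step_mask: (batch,) with 1 = real token, 0 = padding."""
    pad = step_mask.eq(0).view(-1, 1)   # True where this step is padding
    real = step_mask.eq(1).view(-1, 1)  # True where this step is a real token
    # zero the discarded branch in each row, then add:
    # real rows get `new`, padded rows keep `old`
    return new.masked_fill(pad, 0) + old.masked_fill(real, 0)
```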
@@ -230,7 +230,7 @@ class ConditionalRandomField(nn.Module):
     seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

     # trans_score [L-1, B]
-    mask = mask.byte()
+    mask = mask.eq(1)
     flip_mask = mask.eq(0)
     trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
     # emit_score [L, B]
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module):
"""
batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.byte() # L, B
mask = mask.transpose(0, 1).data.eq(1) # L, B

# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)


fastNLP/modules/decoder/utils.py (+1 -1)

@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
"compatible."
logits = logits.transpose(0, 1).data # L, B, H
if mask is not None:
mask = mask.transpose(0, 1).data.byte() # L, B
mask = mask.transpose(0, 1).data.eq(1) # L, B
else:
mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)


