@@ -115,7 +115,7 @@ class BertEmbedding(ContextualEmbedding):
                 if self._word_sep_index:  # the sep token must not be dropped
                     sep_mask = words.eq(self._word_sep_index)
                 mask = torch.ones_like(words).float() * self.word_dropout
-                mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions become 1
+                mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
                 words = words.masked_fill(mask, self._word_unk_index)
                 if self._word_sep_index:
                     words.masked_fill_(sep_mask, self._word_sep_index)
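These embedding hunks all switch the word-dropout mask from a uint8 tensor to a bool one. A minimal, self-contained sketch of the pattern is below; the index values (unk_index=1, sep_index=3) are made-up placeholders rather than the embeddings' real attributes.

import torch

def drop_words(words, word_dropout=0.1, unk_index=1, sep_index=3):
    # words: LongTensor of token ids, shape (batch_size, seq_len)
    sep_mask = words.eq(sep_index)                    # remember sep positions first
    drop_prob = torch.ones_like(words).float() * word_dropout
    drop_mask = torch.bernoulli(drop_prob).eq(1)      # bool mask; True means "drop this token"
    words = words.masked_fill(drop_mask, unk_index)   # dropped tokens become unk
    words = words.masked_fill(sep_mask, sep_index)    # sep tokens are always restored
    return words

words = torch.tensor([[5, 7, 3, 9], [2, 3, 6, 8]])
print(drop_words(words, word_dropout=0.5))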
@@ -252,7 +252,7 @@ class BertWordPieceEncoder(nn.Module):
                 if self._word_sep_index:  # the sep token must not be dropped
                     sep_mask = words.eq(self._wordpiece_unk_index)
                 mask = torch.ones_like(words).float() * self.word_dropout
-                mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions become 1
+                mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
                 words = words.masked_fill(mask, self._word_unk_index)
                 if self._word_sep_index:
                     words.masked_fill_(sep_mask, self._wordpiece_unk_index)
@@ -63,7 +63,7 @@ class Embedding(nn.Module):
         """
         if self.word_dropout>0 and self.training:
             mask = torch.ones_like(words).float() * self.word_dropout
-            mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions become 1
+            mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
             words = words.masked_fill(mask, self.unk_index)
         words = self.embed(words)
         return self.dropout(words)
@@ -135,7 +135,7 @@ class TokenEmbedding(nn.Module):
         """
         if self.word_dropout > 0 and self.training:
             mask = torch.ones_like(words).float() * self.word_dropout
-            mask = torch.bernoulli(mask).byte()  # the larger dropout_word is, the more positions become 1
+            mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
             words = words.masked_fill(mask, self._word_unk_index)
         return words
@@ -150,7 +150,7 @@ class GraphParser(BaseModel):
         """
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
-        flip_mask = (mask == 0).byte()
+        flip_mask = mask.eq(0)
         matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
         _, heads = torch.max(matrix, dim=2)
         if mask is not None:
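A toy sketch of the flip-mask idiom used in this hunk, assuming a small random score tensor: positions where mask == 0 (padding) are filled with -inf so the argmax over candidate heads can never pick them.

import torch

scores = torch.randn(2, 4, 4)             # (batch, dependent, candidate head) arc scores
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])        # 1 = real token, 0 = padding

flip_mask = mask.eq(0)                                    # bool mask of padding positions
scores = scores.masked_fill(flip_mask.unsqueeze(1), float('-inf'))
heads = scores.max(dim=2)[1]                              # best head for every dependent
print(heads)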
@@ -210,7 +210,7 @@ class ConditionalRandomField(nn.Module):
             trans_score = self.trans_m.view(1, n_tags, n_tags)
             tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
             alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
-                alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
+                alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)
         if self.include_start_end_trans:
             alpha = alpha + self.end_scores.view(1, -1)
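The masked_fill pair in this hunk implements a select-by-mask: at padded timesteps the previous alpha is carried over, at real timesteps the freshly computed one is taken. A small sketch with toy shapes, kept separate from the CRF class:

import torch

batch_size, n_tags = 2, 3
alpha_old = torch.randn(batch_size, n_tags)   # alpha from the previous timestep
alpha_new = torch.randn(batch_size, n_tags)   # candidate alpha for timestep i
mask_i = torch.tensor([1, 0])                 # timestep i is padding for sample 1
flip_mask_i = mask_i.eq(0)

alpha = alpha_new.masked_fill(flip_mask_i.view(batch_size, 1), 0) + \
        alpha_old.masked_fill(mask_i.eq(1).view(batch_size, 1), 0)
# sample 0 takes alpha_new, sample 1 keeps alpha_old
print(alpha)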
@@ -230,7 +230,7 @@ class ConditionalRandomField(nn.Module):
         seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
         # trans_score [L-1, B]
-        mask = mask.byte()
+        mask = mask.eq(1)
         flip_mask = mask.eq(0)
         trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
         # emit_score [L, B]
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module):
         """
         batch_size, seq_len, n_tags = logits.size()
         logits = logits.transpose(0, 1).data  # L, B, H
-        mask = mask.transpose(0, 1).data.byte()  # L, B
+        mask = mask.transpose(0, 1).data.eq(1)  # L, B
         # dp
         vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
         "compatible."
     logits = logits.transpose(0, 1).data  # L, B, H
     if mask is not None:
-        mask = mask.transpose(0, 1).data.byte()  # L, B
+        mask = mask.transpose(0, 1).data.eq(1)  # L, B
     else:
         mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
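Presumably the motivation shared by all of these hunks is that masked_fill / masked_select in recent PyTorch releases expect bool masks and warn when given uint8 ones, while comparison ops such as eq() already return bool tensors, so no extra cast is needed. A quick check:

import torch

x = torch.randn(2, 3)
mask = torch.tensor([[1, 0, 1], [0, 1, 0]])

print(mask.byte().dtype)   # torch.uint8 -- the old mask dtype
print(mask.eq(1).dtype)    # torch.bool  -- what masked_fill expects
print(x.masked_fill(mask.eq(1), 0.0))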