
Merge pull request #246 from fastnlp/dev0.5.0

Fix a bug in the workaround for the pytorch 1.3.0 bug
tags/v0.5.5
yhcc GitHub 4 years ago
commit fb1ad7a8fa
17 changed files with 34 additions and 34 deletions
  1. fastNLP/core/losses.py (+2, -2)
  2. fastNLP/core/metrics.py (+3, -3)
  3. fastNLP/embeddings/bert_embedding.py (+3, -3)
  4. fastNLP/embeddings/char_embedding.py (+3, -3)
  5. fastNLP/models/biaffine_parser.py (+2, -2)
  6. fastNLP/models/snli.py (+1, -1)
  7. fastNLP/modules/decoder/crf.py (+6, -6)
  8. fastNLP/modules/decoder/utils.py (+1, -1)
  9. fastNLP/modules/encoder/star_transformer.py (+1, -1)
  10. fastNLP/modules/encoder/transformer.py (+1, -1)
  11. reproduction/Summarization/Baseline/model/Loss.py (+1, -1)
  12. reproduction/Summarization/Baseline/model/Metric.py (+1, -1)
  13. reproduction/Summarization/BertSum/model.py (+2, -2)
  14. reproduction/joint_cws_parse/models/CharParser.py (+1, -1)
  15. reproduction/matching/model/esim.py (+1, -1)
  16. reproduction/matching/model/mwan.py (+4, -4)
  17. reproduction/multi-criteria-cws/transformer.py (+1, -1)

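The same one-token fix repeats across all 17 files: masks that are now torch.bool tensors get compared against Python booleans (.eq(False), .eq(True)) instead of integers (.eq(0), .eq(1), == 0), because PyTorch 1.3.0's stricter bool/type-promotion rules can reject or warn on the integer form. Below is a minimal sketch of the before/after pattern, using a hand-built mask rather than fastNLP's seq_len_to_mask; names and shapes are illustrative only, not the repo's code.

import torch

# Stand-in for what seq_len_to_mask produces: True marks real tokens, False marks padding.
seq_len = torch.tensor([3, 1])
mask = torch.arange(4)[None, :] < seq_len[:, None]   # dtype=torch.bool, shape (2, 4)

# Old spelling (pre-PR): compare the bool mask against the integer 0.
# flip_mask = mask.eq(0)

# New spelling (this PR): compare against a bool, so both operands stay boolean.
flip_mask = mask.eq(False)                           # same result as ~mask

scores = torch.randn(2, 4)
print(scores.masked_fill(flip_mask, float('-inf')))  # padding positions become -inf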
fastNLP/core/losses.py (+2, -2)

@@ -229,7 +229,7 @@ class CrossEntropyLoss(LossBase):
def get_loss(self, pred, target, seq_len=None):
if seq_len is not None and target.dim()>1:
- mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0)
+ mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False)
target = target.masked_fill(mask, self.padding_idx)

if pred.dim() > 2:
@@ -374,7 +374,7 @@ class CMRC2018Loss(LossBase):
:return:
"""
batch_size, max_len = pred_end.size()
- mask = seq_len_to_mask(context_len, max_len).eq(0)
+ mask = seq_len_to_mask(context_len, max_len).eq(False)

pred_start = pred_start.masked_fill(mask, float('-inf'))
pred_end = pred_end.masked_fill(mask, float('-inf'))


fastNLP/core/metrics.py (+3, -3)

@@ -358,7 +358,7 @@ class AccuracyMetric(MetricBase):

target = target.to(pred)
if masks is not None:
- self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
+ self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(False), 0)).item()
self.total += torch.sum(masks).item()
else:
self.acc_count += torch.sum(torch.eq(pred, target)).item()
@@ -465,7 +465,7 @@ class ClassifyFPreRecMetric(MetricBase):
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
else:
masks = torch.ones_like(target).long().to(target.device)
- masks = masks.eq(0)
+ masks = masks.eq(False)

if pred.dim() == target.dim():
pass
@@ -1017,7 +1017,7 @@ class CMRC2018Metric(MetricBase):
:return:
"""
batch_size, max_len = pred_start.size()
- context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(0)
+ context_mask = seq_len_to_mask(context_len, max_len=max_len).eq(False)
pred_start.masked_fill_(context_mask, float('-inf'))
pred_end.masked_fill_(context_mask, float('-inf'))
max_pred_start, pred_start_index = pred_start.max(dim=-1, keepdim=True) # batch_size,


fastNLP/embeddings/bert_embedding.py (+3, -3)

@@ -325,7 +325,7 @@ class _WordBertModel(nn.Module):
batch_size, max_word_len = words.size()
word_mask = words.ne(self._word_pad_index) # positions equal to 1 contain a real word
seq_len = word_mask.sum(dim=-1)
- batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0),
+ batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(False),
0) # batch_size x max_len
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # length of the word-piece sequence (including padding)
@@ -403,12 +403,12 @@ class _WordBertModel(nn.Module):
truncate_output_layer = output_layer[:, 1:-1] # drop [CLS] and [SEP]; batch_size x len x hidden_size
if self.pool_method == 'first':
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+ tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp

elif self.pool_method == 'last':
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length]
- tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0)
+ tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(False), 0)
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp
elif self.pool_method == 'max':
for i in range(batch_size):


fastNLP/embeddings/char_embedding.py (+3, -3)

@@ -148,7 +148,7 @@ class CNNCharEmbedding(TokenEmbedding):
chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters)
else:
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
- chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
+ chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
chars = self.fc(chars)
return self.dropout(chars)

@@ -266,7 +266,7 @@ class LSTMCharEmbedding(TokenEmbedding):
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
chars = self.dropout(chars)
reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
- char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
+ char_seq_len = chars_masks.eq(False).sum(dim=-1).reshape(batch_size * max_len)
lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
# B x M x M x H
@@ -276,7 +276,7 @@ class LSTMCharEmbedding(TokenEmbedding):
chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H
else:
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
- chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
+ chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(False).sum(dim=-1, keepdim=True).float()
chars = self.fc(chars)


fastNLP/models/biaffine_parser.py (+2, -2)

@@ -148,7 +148,7 @@ class GraphParser(BaseModel):
"""
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
- flip_mask = mask.eq(0)
+ flip_mask = mask.eq(False)
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if mask is not None:
@@ -441,7 +441,7 @@ class BiaffineParser(GraphParser):
batch_size, length, _ = pred1.shape
mask = seq_len_to_mask(seq_len, max_len=length)
- flip_mask = (mask == 0)
+ flip_mask = (mask.eq(False))
_arc_pred = pred1.clone()
_arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
arc_logits = F.log_softmax(_arc_pred, dim=2)


fastNLP/models/snli.py (+1, -1)

@@ -152,7 +152,7 @@ class BiRNN(nn.Module):

def forward(self, x, x_mask):
# Sort x
- lengths = x_mask.data.eq(1).long().sum(1)
+ lengths = x_mask.data.eq(True).long().sum(1)
_, idx_sort = torch.sort(lengths, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)
lengths = list(lengths[idx_sort])


fastNLP/modules/decoder/crf.py (+6, -6)

@@ -217,14 +217,14 @@ class ConditionalRandomField(nn.Module):
if self.include_start_end_trans:
alpha = alpha + self.start_scores.view(1, -1)

- flip_mask = mask.eq(0)
+ flip_mask = mask.eq(False)

for i in range(1, seq_len):
emit_score = logits[i].view(batch_size, 1, n_tags)
trans_score = self.trans_m.view(1, n_tags, n_tags)
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
- alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)
+ alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0)

if self.include_start_end_trans:
alpha = alpha + self.end_scores.view(1, -1)
@@ -244,8 +244,8 @@ class ConditionalRandomField(nn.Module):
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

# trans_score [L-1, B]
- mask = mask.eq(1)
- flip_mask = mask.eq(0)
+ mask = mask.eq(True)
+ flip_mask = mask.eq(False)
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
# emit_score [L, B]
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
@@ -292,7 +292,7 @@ class ConditionalRandomField(nn.Module):
"""
batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data # L, B, H
- mask = mask.transpose(0, 1).data.eq(1) # L, B
+ mask = mask.transpose(0, 1).data.eq(True) # L, B

# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -311,7 +311,7 @@ class ConditionalRandomField(nn.Module):
score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1)
vpath[i] = best_dst
- vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
+ vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

if self.include_start_end_trans:
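The crf.py hunks also show the carry-over trick used in the forward and Viterbi recursions: the freshly computed score is zeroed at padded positions, the previous score is zeroed at real positions, and the two are summed so finished sequences simply keep their old value. A toy illustration of that trick with the .eq(False)/.eq(True) spelling this PR settles on (the tensors are made up for illustration):

import torch

old_alpha = torch.tensor([[1.0, 2.0],
                          [3.0, 4.0]])
new_alpha = torch.tensor([[5.0, 6.0],
                          [7.0, 8.0]])
mask_i = torch.tensor([True, False])                 # second sequence is already padded at this step

flip_mask_i = mask_i.eq(False)
alpha = new_alpha.masked_fill(flip_mask_i.view(-1, 1), 0) + \
        old_alpha.masked_fill(mask_i.eq(True).view(-1, 1), 0)
print(alpha)                                         # row 0 is updated, row 1 carries over the old value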


fastNLP/modules/decoder/utils.py (+1, -1)

@@ -27,7 +27,7 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
"compatible."
logits = logits.transpose(0, 1).data # L, B, H
if mask is not None:
- mask = mask.transpose(0, 1).data.eq(1) # L, B
+ mask = mask.transpose(0, 1).data.eq(True) # L, B
else:
mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)



fastNLP/modules/encoder/star_transformer.py (+1, -1)

@@ -65,7 +65,7 @@ class StarTransformer(nn.Module):
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

B, L, H = data.size()
- mask = (mask == 0) # flip the mask for masked_fill_
+ mask = (mask.eq(False)) # flip the mask for masked_fill_
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)

embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1


fastNLP/modules/encoder/transformer.py (+1, -1)

@@ -71,7 +71,7 @@ class TransformerEncoder(nn.Module):
if seq_mask is None:
atte_mask_out = None
else:
- atte_mask_out = (seq_mask == 0)[:, None, :]
+ atte_mask_out = (seq_mask.eq(False))[:, None, :]
seq_mask = seq_mask[:, :, None]
for layer in self.layers:
output = layer(output, seq_mask, atte_mask_out)


reproduction/Summarization/Baseline/model/Loss.py (+1, -1)

@@ -47,7 +47,7 @@ class MyCrossEntropyLoss(LossBase):
loss = F.cross_entropy(input=pred, target=target,
ignore_index=self.padding_idx, reduction=self.reduce)
loss = loss.view(batch, -1)
- loss = loss.masked_fill(mask.eq(0), 0)
+ loss = loss.masked_fill(mask.eq(False), 0)
loss = loss.sum(1).mean()
logger.debug("loss %f", loss)
return loss


reproduction/Summarization/Baseline/model/Metric.py (+1, -1)

@@ -57,7 +57,7 @@ class LossMetric(MetricBase):
loss = F.cross_entropy(input=pred, target=target,
ignore_index=self.padding_idx, reduction=self.reduce)
loss = loss.view(batch, -1)
- loss = loss.masked_fill(mask.eq(0), 0)
+ loss = loss.masked_fill(mask.eq(False), 0)
loss = loss.sum(1).mean()
self.loss += loss
self.iteration += 1


reproduction/Summarization/BertSum/model.py (+2, -2)

@@ -33,8 +33,8 @@ class BertSum(nn.Module):
# print(segment_id.device)
# print(cls_id.device)

- input_mask = 1 - (article == 0)
- mask_cls = 1 - (cls_id == -1)
+ input_mask = 1 - (article == 0).long()
+ mask_cls = 1 - (cls_id == -1).long()
assert input_mask.size() == article.size()
assert mask_cls.size() == cls_id.size()

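The BertSum change is the one place where .long() is added rather than an integer swapped for a bool: `article == 0` yields a bool tensor, and recent PyTorch versions refuse to subtract a bool tensor from a number, so the comparison is cast to long before the `1 - …`. A hedged sketch with made-up tensors:

import torch

article = torch.tensor([[101, 2054, 102, 0, 0]])     # made-up token ids, 0 = padding
cls_id = torch.tensor([[0, -1, -1]])                  # made-up [CLS] positions, -1 = padding

# `article == 0` is torch.bool; `1 - bool_tensor` raises in recent PyTorch,
# hence the `.long()` cast added in this PR.
input_mask = 1 - (article == 0).long()
mask_cls = 1 - (cls_id == -1).long()

# An equivalent spelling without the subtraction would be boolean inversion:
# input_mask = (article != 0).long()
print(input_mask, mask_cls)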


reproduction/joint_cws_parse/models/CharParser.py (+1, -1)

@@ -223,7 +223,7 @@ class CharBiaffineParser(BiaffineParser):
"""

batch_size, seq_len, _ = arc_pred.shape
- flip_mask = (mask == 0)
+ flip_mask = (mask.eq(False))
# _arc_pred = arc_pred.clone()
_arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))



reproduction/matching/model/esim.py (+1, -1)

@@ -119,7 +119,7 @@ class BiRNN(nn.Module):

def forward(self, x, x_mask):
# Sort x
- lengths = x_mask.data.eq(1).long().sum(1)
+ lengths = x_mask.data.eq(True).long().sum(1)
_, idx_sort = torch.sort(lengths, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)
lengths = list(lengths[idx_sort])


reproduction/matching/model/mwan.py (+4, -4)

@@ -91,7 +91,7 @@ class ConcatAttention_Param(nn.Module):

s = self.v(tc.tanh(self.ln(tc.cat([h,vq],-1)))).squeeze(-1) # (batch_size, len)
- s = s - ((mask == 0).float() * 10000)
+ s = s - ((mask.eq(False)).float() * 10000)
a = tc.softmax(s, dim=1)

r = a.unsqueeze(-1) * h # (batch_size, len, input_size)
@@ -121,7 +121,7 @@ def Attention(hq, hp, mask_hq, mask_hp, my_method):

s = my_method(hq_mat, hp_mat) # (batch_size, len_q, len_p)

- s = s - ((mask_mat == 0).float() * 10000)
+ s = s - ((mask_mat.eq(False)).float() * 10000)
a = tc.softmax(s, dim=1)

q = a.unsqueeze(-1) * hq_mat #(batch_size, len_q, len_p, input_size)
@@ -242,7 +242,7 @@ class BiLinearAttention(nn.Module):

s = self.my_method(hq, hp, mask_hp) # (batch_size, len_q, len_p)

- s = s - ((mask_mat == 0).float() * 10000)
+ s = s - ((mask_mat.eq(False)).float() * 10000)
a = tc.softmax(s, dim=1)

hq_mat = hq.unsqueeze(2).expand(standard_size)
@@ -285,7 +285,7 @@ class AggAttention(nn.Module):

s = self.v(tc.tanh(self.ln(tc.cat([hs,vq],-1)))).squeeze(-1)# (4, batch_size, len_q)

- s = s - ((mask.unsqueeze(0) == 0).float() * 10000)
+ s = s - ((mask.unsqueeze(0).eq(False)).float() * 10000)
a = tc.softmax(s, dim=0)

x = a.unsqueeze(-1) * hs
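The mwan.py hunks use the mask additively instead of through masked_fill: padded positions have a large constant subtracted from their scores before softmax, which drives their attention weights towards zero. A small sketch of that idiom (the mask and scores are made up; tc is the torch alias mwan.py uses):

import torch as tc

mask = tc.tensor([[True, True, False]])              # made-up attention mask
s = tc.randn(1, 3)                                   # raw attention scores

s = s - (mask.eq(False).float() * 10000)             # push padded positions far below the rest
a = tc.softmax(s, dim=1)
print(a)                                             # the padded position gets ~0 weight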


reproduction/multi-criteria-cws/transformer.py (+1, -1)

@@ -27,7 +27,7 @@ def attention(query, key, value, mask=None, dropout=None):
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
# print(scores.size(),mask.size()) # [bsz,1,1,len]
- scores = scores.masked_fill(mask == 0, -1e9)
+ scores = scores.masked_fill(mask.eq(False), -1e9)
p_attn = F.softmax(scores, dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)

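The final hunk is the standard scaled dot-product attention masking, now written with a boolean mask. A self-contained sketch of the pattern (shapes and names are illustrative, not the repo's):

import math
import torch
import torch.nn.functional as F

d_k = 8
query = torch.randn(2, 5, d_k)                       # made-up (batch, len, d_k) tensors
key = torch.randn(2, 5, d_k)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 0, 0, 0]]).bool()        # True = real token

scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
scores = scores.masked_fill(mask.eq(False).unsqueeze(1), -1e9)  # mask broadcast over query positions
p_attn = F.softmax(scores, dim=-1)
print(p_attn.shape)                                  # torch.Size([2, 5, 5])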
