
Fix ELMo not working on CUDA

tags/v1.0.0alpha
yh_cc committed 3 years ago
parent commit: e78fc82a6e
2 changed files with 3 additions and 3 deletions
  1. fastNLP/models/biaffine_parser.py (+1, -1)
  2. fastNLP/modules/encoder/_elmo.py (+2, -2)

fastNLP/models/biaffine_parser.py (+1, -1)

@@ -376,7 +376,7 @@ class BiaffineParser(GraphParser):
         if self.encoder_name.endswith('lstm'):
             sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
             x = x[sort_idx]
-            x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
+            x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True)
             feat, _ = self.encoder(x)  # -> [N,L,C]
             feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
             _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)


fastNLP/modules/encoder/_elmo.py (+2, -2)

@@ -251,7 +251,7 @@ class LstmbiLm(nn.Module):
     def forward(self, inputs, seq_len):
         sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
         inputs = inputs[sort_idx]
-        inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first)
+        inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=self.batch_first)
         output, hx = self.encoder(inputs, None)  # -> [N,L,C]
         output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
         _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
@@ -316,7 +316,7 @@ class ElmobiLm(torch.nn.Module):
         max_len = inputs.size(1)
         sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
         inputs = inputs[sort_idx]
-        inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True)
+        inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=True)
         output, _ = self._lstm_forward(inputs, None)
         _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
         output = output[:, unsort_idx]
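All three changes apply the same pattern: newer PyTorch releases require the lengths argument of nn.utils.rnn.pack_padded_sequence to be a CPU tensor (or a plain Python list), even when the inputs themselves live on the GPU, so a seq_len tensor that sits on CUDA has to be moved with .cpu() before packing. Below is a minimal, self-contained sketch of that pattern, independent of fastNLP; the helper name pack_sorted is illustrative only and not part of the fastNLP API.

    import torch
    import torch.nn as nn

    def pack_sorted(inputs, seq_len, batch_first=True):
        # Sort the padded batch by length (longest first), as pack_padded_sequence expects.
        sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
        inputs = inputs[sort_idx]
        # Lengths must live on the CPU; the inputs themselves may stay on CUDA.
        packed = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=batch_first)
        return packed, sort_idx

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(4, 7, 16, device=device)             # [N, L, C] padded batch
    seq_len = torch.tensor([7, 5, 3, 2], device=device)  # true lengths, possibly on CUDA
    packed, sort_idx = pack_sorted(x, seq_len)
    lstm = nn.LSTM(16, 32, batch_first=True).to(device)
    feat, _ = lstm(packed)
    feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
    _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
    feat = feat[unsort_idx]                               # restore the original batch order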

