|
@@ -251,7 +251,7 @@ class LstmbiLm(nn.Module): |
|
|
def forward(self, inputs, seq_len): |
|
|
def forward(self, inputs, seq_len): |
|
|
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) |
|
|
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) |
|
|
inputs = inputs[sort_idx] |
|
|
inputs = inputs[sort_idx] |
|
|
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first) |
|
|
|
|
|
|
|
|
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=self.batch_first) |
|
|
output, hx = self.encoder(inputs, None) # -> [N,L,C] |
|
|
output, hx = self.encoder(inputs, None) # -> [N,L,C] |
|
|
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) |
|
|
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) |
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
@@ -316,7 +316,7 @@ class ElmobiLm(torch.nn.Module): |
|
|
max_len = inputs.size(1) |
|
|
max_len = inputs.size(1) |
|
|
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) |
|
|
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) |
|
|
inputs = inputs[sort_idx] |
|
|
inputs = inputs[sort_idx] |
|
|
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True) |
|
|
|
|
|
|
|
|
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=True) |
|
|
output, _ = self._lstm_forward(inputs, None) |
|
|
output, _ = self._lstm_forward(inputs, None) |
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
|
output = output[:, unsort_idx] |
|
|
output = output[:, unsort_idx] |
|
|