|
|
@@ -56,8 +56,8 @@ class LSTM(nn.Module): |
|
|
|
:param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` |
|
|
|
:param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` |
|
|
|
:param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` |
|
|
|
:return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 |
|
|
|
和 [batch, hidden_size*num_direction] 最后时刻隐状态. |
|
|
|
:return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] 输出序列 |
|
|
|
和 ht,ct: [num_layers*num_direction, batch, hidden_size] 最后时刻隐状态. |
|
|
|
""" |
|
|
|
batch_size, max_len, _ = x.size() |
|
|
|
if h0 is not None and c0 is not None: |
|
|
@@ -78,6 +78,7 @@ class LSTM(nn.Module): |
|
|
|
output = output[unsort_idx] |
|
|
|
else: |
|
|
|
output = output[:, unsort_idx] |
|
|
|
hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx] |
|
|
|
else: |
|
|
|
output, hx = self.lstm(x, hx) |
|
|
|
return output, hx |