From 39f1185e1d5cd93734c7d0e03301188c212984fb Mon Sep 17 00:00:00 2001 From: yh Date: Thu, 28 Nov 2019 17:37:18 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DBiLSTM=E4=B8=ADht,=20ct?= =?UTF-8?q?=E7=8A=B6=E6=80=81=E5=9C=A8=E6=9C=89seq=5Flen=E6=97=B6=E9=A1=BA?= =?UTF-8?q?=E5=BA=8F=E9=94=99=E4=B9=B1=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/encoder/lstm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 06b437ef..1ae61ec0 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -56,8 +56,8 @@ class LSTM(nn.Module): :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` - :return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 - 和 [batch, hidden_size*num_direction] 最后时刻隐状态. + :return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] 输出序列 + 和 ht,ct: [num_layers*num_direction, batch, hidden_size] 最后时刻隐状态. """ batch_size, max_len, _ = x.size() if h0 is not None and c0 is not None: @@ -78,6 +78,7 @@ class LSTM(nn.Module): output = output[unsort_idx] else: output = output[:, unsort_idx] + hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx] else: output, hx = self.lstm(x, hx) return output, hx