@@ -11,13 +11,15 @@ import torch.nn as nn
import torch.nn.utils.rnn as rnn

from ..utils import initial_parameter
from torch import autograd


class LSTM(nn.Module):
    """
    Alias: :class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`

    A lightly wrapped PyTorch LSTM module. When seq_len is provided, pack_padded_sequence is used automatically;
    the forget-gate bias is initialized to 1 by default; and the module works around the problem of using an
    LSTM inside DataParallel.

    :param input_size: feature dimension of the input `x`
    :param hidden_size: feature dimension of the hidden state `h`.

@@ -59,6 +61,7 @@ class LSTM(nn.Module):
        :return (output, ht) or output: if ``get_hidden=True``, the output sequence of shape
            [batch, seq_len, hidden_size*num_direction] and the hidden state at the last time step of shape [batch, hidden_size*num_direction].
        """
        batch_size, max_len, _ = x.size()
        if h0 is not None and c0 is not None:
            hx = (h0, c0)
        else:

@@ -77,6 +80,10 @@ class LSTM(nn.Module):
                output = output[unsort_idx]
            else:
                output = output[:, unsort_idx]
            # Work around LSTM failing under DataParallel (https://github.com/pytorch/pytorch/issues/1591): pad the output back to max_len.
            if output.size(1) < max_len:
                dummy_tensor = autograd.Variable(torch.zeros(batch_size, max_len - output.size(1), output.size(-1), dtype=output.dtype, device=output.device))
                output = torch.cat([output, dummy_tensor], 1)
        else:
            output, hx = self.lstm(x, hx)
        return output, hx
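
Usage sketch: a minimal, hedged example of how this wrapper might be called. It assumes the import path given by
the alias line in the class docstring, the documented `input_size`/`hidden_size` constructor arguments, batch-first
input (as the return shape in the docstring suggests), and a `forward(x, seq_len=None, h0=None, c0=None)` signature
implied by the code above; any other constructor arguments are omitted and may differ from the actual signature.

import torch
from fastNLP.modules import LSTM  # import path taken from the alias in the class docstring

# Batch of 3 sequences padded to length 5, each step a 10-dimensional feature vector.
x = torch.randn(3, 5, 10)
seq_len = torch.tensor([5, 3, 2])  # true lengths; triggers pack_padded_sequence internally

lstm = LSTM(input_size=10, hidden_size=20)

# With seq_len given, the output is padded back to the input's max length,
# so its shape is [batch, seq_len, hidden_size * num_direction].
output, hx = lstm(x, seq_len=seq_len)
print(output.shape)  # expected: torch.Size([3, 5, 20])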