
Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.5.5
yh_cc, 5 years ago
parent commit ae8079cd56
2 changed files with 4 additions and 3 deletions:
  1. fastNLP/embeddings/bert_embedding.py (+1, -1)
  2. fastNLP/modules/encoder/lstm.py (+3, -2)

fastNLP/embeddings/bert_embedding.py (+1, -1)

@@ -259,7 +259,7 @@ class _WordBertModel(nn.Module):
         if '[sep]' in vocab:
             warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
         if "[CLS]" in vocab:
-            warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CSL] and [SEP] to the begin "
+            warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the begin "
                           "and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin"
                           " and end.")
         for word, index in vocab:
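
The corrected warning describes BertEmbedding's behaviour: it wraps the input with [CLS] and [SEP] by itself, so the vocabulary should contain only ordinary words. A minimal usage sketch, assuming the fastNLP 0.5.x API; the model shorthand and example words are illustrative:

    # Sketch only: 'en-base-uncased' is a fastNLP shorthand for a pretrained BERT;
    # a local model directory also works.
    from fastNLP import Vocabulary
    from fastNLP.embeddings import BertEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a demo sentence".split())  # plain words, no [CLS]/[SEP]

    # BertEmbedding prepends [CLS] and appends [SEP] internally, so adding them
    # to the vocabulary or the input would duplicate them.
    embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased')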


fastNLP/modules/encoder/lstm.py (+3, -2)

@@ -56,8 +56,8 @@ class LSTM(nn.Module):
         :param seq_len: [batch, ] sequence lengths; if ``None``, all inputs are treated as equally long. Default: ``None``
         :param h0: [batch, hidden_size] initial hidden state; if ``None``, an all-zero vector is used. Default: ``None``
         :param c0: [batch, hidden_size] initial cell state; if ``None``, an all-zero vector is used. Default: ``None``
-        :return (output, ht) or output: if ``get_hidden=True``, [batch, seq_len, hidden_size*num_direction] output sequence
-            [batch, hidden_size*num_direction] hidden state at the last time step.
+        :return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] output sequence
+            ht, ct: [num_layers*num_direction, batch, hidden_size] hidden and cell states at the last time step.
         """
         batch_size, max_len, _ = x.size()
         if h0 is not None and c0 is not None:
@@ -78,6 +78,7 @@ class LSTM(nn.Module):
                 output = output[unsort_idx]
             else:
                 output = output[:, unsort_idx]
+            hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx]
         else:
             output, hx = self.lstm(x, hx)
         return output, hx
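
Taken together, the two hunks make the docstring match the actual return value, (output, (ht, ct)), and unsort ht/ct back to the original batch order after the internal length-sorting used for packing (previously only output was unsorted). A minimal sketch of the resulting contract, assuming the fastNLP 0.5.x LSTM wrapper; the shapes are illustrative:

    # Sketch only: passing seq_len exercises the sort/unsort path fixed above.
    import torch
    from fastNLP.modules.encoder.lstm import LSTM

    lstm = LSTM(input_size=8, hidden_size=16, num_layers=1, bidirectional=True)
    x = torch.randn(4, 10, 8)              # [batch, seq_len, input_size]
    seq_len = torch.tensor([10, 7, 5, 3])  # true length of each sequence

    output, (ht, ct) = lstm(x, seq_len=seq_len)
    # output: [4, 10, 32] = [batch, seq_len, hidden_size*num_direction]
    # ht/ct:  [2, 4, 16]  = [num_layers*num_direction, batch, hidden_size],
    # indexed in the original (unsorted) batch order after this fix.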
