From 19bbaf11b6989a1a29384d5b1516bf934ccac296 Mon Sep 17 00:00:00 2001
From: yh
Date: Tue, 27 Aug 2019 01:54:15 +0800
Subject: [PATCH] Handle parameters in embeddings in a more PyTorch-idiomatic way
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/embeddings/bert_embedding.py   |  2 +-
 fastNLP/embeddings/char_embedding.py   | 14 ++++++--------
 fastNLP/embeddings/elmo_embedding.py   |  5 ++---
 fastNLP/embeddings/static_embedding.py | 16 +++++++---------
 4 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 6a10c489..f3ef69dd 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -345,7 +345,7 @@ class _WordBertModel(nn.Module):
         self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed for generating word_piece
         print("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
         self.word_to_wordpieces = np.array(word_to_wordpieces)
-        self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False)
+        self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
         print("Successfully generate word pieces.")
 
     def forward(self, words):
diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py
index 520e85e6..ea0d4e93 100644
--- a/fastNLP/embeddings/char_embedding.py
+++ b/fastNLP/embeddings/char_embedding.py
@@ -82,10 +82,9 @@ class CNNCharEmbedding(TokenEmbedding):
         print(f"In total, there are {len(self.char_vocab)} distinct characters.")
         # index the vocab
         max_word_len = max(map(lambda x: len(x[0]), vocab))
-        self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len),
-                                                                fill_value=self.char_pad_index, dtype=torch.long),
-                                                     requires_grad=False)
-        self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
+        self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
+                                                                    fill_value=self.char_pad_index, dtype=torch.long))
+        self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
         for word, index in vocab:
             # if index!=vocab.padding_idx:  # if it is pad, it is simply pad_value. Changed to not distinguish pad, so all words share the same embed
             self.words_to_chars_embedding[index, :len(word)] = \
@@ -235,10 +234,9 @@ class LSTMCharEmbedding(TokenEmbedding):
         print(f"In total, there are {len(self.char_vocab)} distinct characters.")
         # index the vocab
         self.max_word_len = max(map(lambda x: len(x[0]), vocab))
-        self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
-                                                                fill_value=self.char_pad_index, dtype=torch.long),
-                                                     requires_grad=False)
-        self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
+        self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), self.max_word_len),
+                                                                    fill_value=self.char_pad_index, dtype=torch.long))
+        self.register_buffer('word_lengths', torch.zeros(len(vocab)).long())
         for word, index in vocab:
             # if index!=vocab.padding_idx:  # if it is pad, it is simply pad_value. Changed to not distinguish pad
             self.words_to_chars_embedding[index, :len(word)] = \
diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py
index 24cd052e..80178d21 100644
--- a/fastNLP/embeddings/elmo_embedding.py
+++ b/fastNLP/embeddings/elmo_embedding.py
@@ -240,10 +240,9 @@ class _ElmoModel(nn.Module):
 
         # build the mapping from words to chars
         max_chars = config['char_cnn']['max_characters_per_token']
-        self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars),
+        self.register_buffer('words_to_chars_embedding', torch.full((len(vocab) + 2, max_chars),
                                                                  fill_value=len(char_vocab),
-                                                                 dtype=torch.long),
-                                                      requires_grad=False)
+                                                                 dtype=torch.long))
         for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]:
             if len(word) + 2 > max_chars:
                 word = word[:max_chars - 2]
diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
index a75ad18f..b0141682 100644
--- a/fastNLP/embeddings/static_embedding.py
+++ b/fastNLP/embeddings/static_embedding.py
@@ -121,28 +121,27 @@ class StaticEmbedding(TokenEmbedding):
                 embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
             else:
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
             if lowered_vocab.unknown:
                 unknown_idx = lowered_vocab.unknown_idx
             else:
                 unknown_idx = embedding.size(0) - 1  # otherwise the last one is unknown
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
-            words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
-                                          requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
+            words_to_words = torch.full((len(vocab),), fill_value=unknown_idx).long()
             for word, index in vocab:
                 if word not in lowered_vocab:
                     word = word.lower()
                     if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
                         continue  # if no entry needs to be created, it already defaults to unknown
                 words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
-            self.words_to_words = words_to_words
+            self.register_buffer('words_to_words', words_to_words)
             self._word_unk_index = lowered_vocab.unknown_idx  # replace the index used for unknown
         else:
             if model_path:
                 embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
             else:
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
-                self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+                self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
         if not self.only_norm_found_vector and normalize:
             embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
@@ -151,7 +150,7 @@ class StaticEmbedding(TokenEmbedding):
                 index_in_truncated_vocab = truncated_words_to_words[i]
                 truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
             del self.words_to_words
-            self.words_to_words = nn.Parameter(truncated_words_to_words, requires_grad=False)
+            self.register_buffer('words_to_words', truncated_words_to_words)
         self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                       padding_idx=vocab.padding_idx,
@@ -273,8 +272,7 @@ class StaticEmbedding(TokenEmbedding):
                 vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
             else:
                 unknown_idx = vocab.unknown_idx
-            self.words_to_words = nn.Parameter(torch.full((len(vocab), ), fill_value=unknown_idx).long(),
-                                               requires_grad=False)
+            self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())
             for index, (index_in_vocab, vec) in enumerate(matrix.items()):
                 if vec is not None:
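
Note on the change: throughout the patch, tensors previously held as nn.Parameter(..., requires_grad=False) are now registered via Module.register_buffer(). A buffer is still saved in state_dict() and follows .to()/.cuda() device moves, but it is not returned by parameters(), so it is never handed to the optimizer or counted as a trainable parameter. A minimal sketch of the difference follows; the LookupTable module is illustrative only, not fastNLP code.

import torch
from torch import nn

class LookupTable(nn.Module):
    def __init__(self, num_words):
        super().__init__()
        # Old style: a frozen Parameter is still listed by parameters(),
        # so it is passed to optimizers and counted as a model parameter.
        self.as_parameter = nn.Parameter(torch.arange(num_words), requires_grad=False)
        # New style: a buffer is saved in state_dict() and moved by .to(),
        # but it does not appear in parameters().
        self.register_buffer('as_buffer', torch.arange(num_words))

m = LookupTable(5)
print([name for name, _ in m.named_parameters()])  # ['as_parameter']
print(sorted(m.state_dict().keys()))                # ['as_buffer', 'as_parameter']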