From be77533c3832afeca5557e3762e0470d467572f7 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sun, 25 Aug 2019 18:58:03 +0800
Subject: [PATCH] Fix word drop bug and add corresponding tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/embeddings/bert_embedding.py     |  4 ++--
 fastNLP/embeddings/embedding.py          |  2 +-
 test/embeddings/test_bert_embedding.py   |  9 ++++++++-
 test/embeddings/test_static_embedding.py | 11 +++++++++++
 4 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 4bd06ec3..047048d8 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -126,7 +126,7 @@ class BertEmbedding(ContextualEmbedding):
             with torch.no_grad():
                 if self._word_sep_index:  # must not drop the sep token
                     sep_mask = words.eq(self._word_sep_index)
-                mask = torch.full_like(words, fill_value=self.word_dropout)
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                 mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout, the more positions become 1
                 pad_mask = words.ne(0)
                 mask = pad_mask.__and__(mask)  # pad positions must not become unk
@@ -267,7 +267,7 @@ class BertWordPieceEncoder(nn.Module):
             with torch.no_grad():
                 if self._word_sep_index:  # must not drop the sep token
                     sep_mask = words.eq(self._wordpiece_unk_index)
-                mask = torch.full_like(words, fill_value=self.word_dropout)
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                 mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout, the more positions become 1
                 pad_mask = words.ne(self._wordpiece_pad_index)
                 mask = pad_mask.__and__(mask)  # pad positions must not become unk
diff --git a/fastNLP/embeddings/embedding.py b/fastNLP/embeddings/embedding.py
index a94985c1..5e7b9803 100644
--- a/fastNLP/embeddings/embedding.py
+++ b/fastNLP/embeddings/embedding.py
@@ -138,7 +138,7 @@ class TokenEmbedding(nn.Module):
         :return:
         """
         if self.word_dropout > 0 and self.training:
-            mask = torch.full_like(words, fill_value=self.word_dropout)
+            mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
             mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout, the more positions become 1
             pad_mask = words.ne(self._word_pad_index)
             mask = mask.__and__(pad_mask)
diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py
index 760029a3..da81c8c9 100644
--- a/test/embeddings/test_bert_embedding.py
+++ b/test/embeddings/test_bert_embedding.py
@@ -10,5 +10,12 @@ class TestDownload(unittest.TestCase):
         # import os
         vocab = Vocabulary().add_word_lst("This is a test .".split())
         embed = BertEmbedding(vocab, model_dir_or_name='en')
-        words = torch.LongTensor([[0, 1, 2]])
+        words = torch.LongTensor([[2, 3, 4, 0]])
         print(embed(words).size())
+
+    def test_word_drop(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
+        for i in range(10):
+            words = torch.LongTensor([[2, 3, 4, 0]])
+            print(embed(words).size())
\ No newline at end of file
diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py
index 83137345..c17daa0a 100644
--- a/test/embeddings/test_static_embedding.py
+++ b/test/embeddings/test_static_embedding.py
@@ -5,6 +5,7 @@ from fastNLP import Vocabulary
 import torch
 import os
 
+
 class TestLoad(unittest.TestCase):
     def test_norm1(self):
         # test that norm is only applied to words that can be found
@@ -22,6 +23,16 @@ class TestLoad(unittest.TestCase):
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
 
+    def test_dropword(self):
+        # test that word dropout runs without error
+        vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
+        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
+        for i in range(10):
+            length = torch.randint(1, 50, (1,)).item()
+            batch = torch.randint(1, 4, (1,)).item()
+            words = torch.randint(1, 200, (batch, length)).long()
+            embed(words)
+
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
         vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
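
Note (not part of the patch): a minimal standalone sketch of the word-dropout masking
path these hunks fix, assuming a toy batch of token ids where 0 is the pad index and 1
the unk index (both illustrative; the real indices come from the Vocabulary). Without
dtype=torch.float, full_like inherits the Long dtype of `words`, so the fractional
dropout probability cannot be stored as a probability and torch.bernoulli does not
receive the floating-point tensor it expects.

import torch

words = torch.LongTensor([[2, 3, 4, 0]])   # toy padded batch, 0 = pad (assumption)
word_dropout = 0.2                         # mirrors self.word_dropout in the diff
pad_index, unk_index = 0, 1                # illustrative indices, not fastNLP defaults

# The fix: force a float tensor of per-position drop probabilities.
mask = torch.full_like(words, fill_value=word_dropout, dtype=torch.float, device=words.device)
mask = torch.bernoulli(mask).eq(1)         # the larger word_dropout, the more positions are True
pad_mask = words.ne(pad_index)
mask = mask.__and__(pad_mask)              # never turn pad positions into unk
dropped = words.masked_fill(mask, unk_index)
print(dropped)                             # roughly 20% of non-pad ids replaced by unk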