diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 4bd06ec3..047048d8 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -126,7 +126,7 @@ class BertEmbedding(ContextualEmbedding):
             with torch.no_grad():
                 if self._word_sep_index:  # must not drop sep
                     sep_mask = words.eq(self._word_sep_index)
-                mask = torch.full_like(words, fill_value=self.word_dropout)
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                 mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
                 pad_mask = words.ne(0)
                 mask = pad_mask.__and__(mask)  # pad positions are not set to unk
@@ -267,7 +267,7 @@ class BertWordPieceEncoder(nn.Module):
             with torch.no_grad():
                 if self._word_sep_index:  # must not drop sep
                     sep_mask = words.eq(self._wordpiece_unk_index)
-                mask = torch.full_like(words, fill_value=self.word_dropout)
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                 mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
                 pad_mask = words.ne(self._wordpiece_pad_index)
                 mask = pad_mask.__and__(mask)  # pad positions are not set to unk
diff --git a/fastNLP/embeddings/embedding.py b/fastNLP/embeddings/embedding.py
index a94985c1..5e7b9803 100644
--- a/fastNLP/embeddings/embedding.py
+++ b/fastNLP/embeddings/embedding.py
@@ -138,7 +138,7 @@ class TokenEmbedding(nn.Module):
         :return:
         """
         if self.word_dropout > 0 and self.training:
-            mask = torch.full_like(words, fill_value=self.word_dropout)
+            mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
             mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
             pad_mask = words.ne(self._word_pad_index)
             mask = mask.__and__(pad_mask)
diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py
index 760029a3..da81c8c9 100644
--- a/test/embeddings/test_bert_embedding.py
+++ b/test/embeddings/test_bert_embedding.py
@@ -10,5 +10,12 @@ class TestDownload(unittest.TestCase):
         # import os
         vocab = Vocabulary().add_word_lst("This is a test .".split())
         embed = BertEmbedding(vocab, model_dir_or_name='en')
-        words = torch.LongTensor([[0, 1, 2]])
+        words = torch.LongTensor([[2, 3, 4, 0]])
         print(embed(words).size())
+
+    def test_word_drop(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
+        for i in range(10):
+            words = torch.LongTensor([[2, 3, 4, 0]])
+            print(embed(words).size())
\ No newline at end of file
diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py
index 83137345..c17daa0a 100644
--- a/test/embeddings/test_static_embedding.py
+++ b/test/embeddings/test_static_embedding.py
@@ -5,6 +5,7 @@ from fastNLP import Vocabulary
 import torch
 import os
 
+
 class TestLoad(unittest.TestCase):
     def test_norm1(self):
         # test norm only for words that can be found
@@ -22,6 +23,16 @@ class TestLoad(unittest.TestCase):
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
 
+    def test_dropword(self):
+        # test that word dropout can run
+        vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
+        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
+        for i in range(10):
+            length = torch.randint(1, 50, (1,)).item()
+            batch = torch.randint(1, 4, (1,)).item()
+            words = torch.randint(1, 200, (batch, length)).long()
+            embed(words)
+
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
         vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
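
Note on the fix above (not part of the patch; the 0.2 probability and toy tensor below are illustrative only): words is a LongTensor, so the old torch.full_like(words, fill_value=self.word_dropout) built the Bernoulli mask with an integer dtype. Depending on the torch version, the fractional dropout probability is then either truncated to 0 or rejected outright, so word dropout never takes effect. Passing dtype=torch.float builds a proper probability mask, roughly as in this sketch:

    import torch

    words = torch.LongTensor([[2, 3, 4, 0]])

    # Old behaviour: the mask inherits the Long dtype of words, so a fractional
    # probability like 0.2 is lost (truncated to 0 or rejected) and no position
    # would ever be selected for dropping.
    # mask = torch.full_like(words, fill_value=0.2)

    # Fixed behaviour: a float tensor of probabilities on the same device as words.
    mask = torch.full_like(words, fill_value=0.2, dtype=torch.float, device=words.device)
    mask = torch.bernoulli(mask).eq(1)  # True where the token will be replaced by unk
    print(mask)

The explicit device=words.device is arguably redundant, since torch.full_like already defaults to the device of its input, but it makes the intent of the mask construction explicit.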