@@ -126,7 +126,7 @@ class BertEmbedding(ContextualEmbedding):
             with torch.no_grad():
                 if self._word_sep_index:  # the sep token must not be dropped
                     sep_mask = words.eq(self._word_sep_index)
-                mask = torch.full_like(words, fill_value=self.word_dropout)
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                 mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
                 pad_mask = words.ne(0)
                 mask = pad_mask.__and__(mask)  # pad positions are not turned into unk
@@ -267,7 +267,7 @@ class BertWordPieceEncoder(nn.Module):
             with torch.no_grad():
                 if self._word_sep_index:  # the sep token must not be dropped
                     sep_mask = words.eq(self._wordpiece_unk_index)
-                mask = torch.full_like(words, fill_value=self.word_dropout)
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                 mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
                 pad_mask = words.ne(self._wordpiece_pad_index)
                 mask = pad_mask.__and__(mask)  # pad positions are not turned into unk
@@ -138,7 +138,7 @@ class TokenEmbedding(nn.Module):
         :return:
         """
         if self.word_dropout > 0 and self.training:
-            mask = torch.full_like(words, fill_value=self.word_dropout)
+            mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
             mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
             pad_mask = words.ne(self._word_pad_index)
             mask = mask.__and__(pad_mask)
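The three hunks above make the same fix. `torch.full_like` inherits the dtype of its input, so filling a `LongTensor` of word ids with a fractional `word_dropout` value truncates the probability toward 0 (and, depending on the PyTorch version, `torch.bernoulli` may reject an integer tensor outright); passing `dtype=torch.float` along with the input's device keeps the drop probability intact. A minimal self-contained sketch of the corrected mask construction, using a made-up drop probability of 0.2 and assuming pad index 0:

```python
import torch

words = torch.LongTensor([[2, 3, 4, 0]])   # toy batch; 0 assumed to be the pad index
p = 0.2                                    # stand-in for self.word_dropout

# full_like would keep the Long dtype of `words` and truncate p to 0;
# forcing a float tensor preserves p and gives bernoulli a valid input.
probs = torch.full_like(words, fill_value=p, dtype=torch.float, device=words.device)
drop_mask = torch.bernoulli(probs).eq(1)   # True where a word will be replaced
pad_mask = words.ne(0)                     # never drop padding positions
drop_mask = drop_mask & pad_mask
print(drop_mask)
```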
@@ -10,5 +10,12 @@ class TestDownload(unittest.TestCase):
         # import os
         vocab = Vocabulary().add_word_lst("This is a test .".split())
         embed = BertEmbedding(vocab, model_dir_or_name='en')
-        words = torch.LongTensor([[0, 1, 2]])
+        words = torch.LongTensor([[2, 3, 4, 0]])
         print(embed(words).size())
+
+    def test_word_drop(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
+        for i in range(10):
+            words = torch.LongTensor([[2, 3, 4, 0]])
+            print(embed(words).size())
@@ -5,6 +5,7 @@ from fastNLP import Vocabulary
 import torch
 import os
 
+
 class TestLoad(unittest.TestCase):
     def test_norm1(self):
         # the norm is only applied to words that can actually be found
@@ -22,6 +23,16 @@ class TestLoad(unittest.TestCase):
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
 
+    def test_dropword(self):
+        # test that the forward pass still works with word dropout
+        vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
+        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
+        for i in range(10):
+            length = torch.randint(1, 50, (1,)).item()
+            batch = torch.randint(1, 4, (1,)).item()
+            words = torch.randint(1, 200, (batch, length)).long()
+            embed(words)
+
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
         vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
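For context, the mask built in the first three hunks is presumably used afterwards to replace the selected ids with the unknown index; that step is outside this diff. A hedged, self-contained sketch of the whole word-dropout step, with hypothetical pad/unk indices (not fastNLP's exact code):

```python
import torch

def drop_word(words: torch.LongTensor, word_dropout: float,
              pad_index: int = 0, unk_index: int = 1) -> torch.LongTensor:
    """Randomly replace word ids with unk_index; pad positions are left alone."""
    probs = torch.full_like(words, fill_value=word_dropout,
                            dtype=torch.float, device=words.device)
    mask = torch.bernoulli(probs).eq(1)   # True with probability word_dropout
    mask = mask & words.ne(pad_index)     # keep padding untouched
    return words.masked_fill(mask, unk_index)

print(drop_word(torch.LongTensor([[2, 3, 4, 0]]), word_dropout=0.4))
```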