@@ -126,7 +126,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
if self._word_sep_index: # 不能drop sep | if self._word_sep_index: # 不能drop sep | ||||
sep_mask = words.eq(self._word_sep_index) | sep_mask = words.eq(self._word_sep_index) | ||||
mask = torch.full_like(words, fill_value=self.word_dropout) | |||||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | ||||
pad_mask = words.ne(0) | pad_mask = words.ne(0) | ||||
mask = pad_mask.__and__(mask) # pad的位置不为unk | mask = pad_mask.__and__(mask) # pad的位置不为unk | ||||
@@ -267,7 +267,7 @@ class BertWordPieceEncoder(nn.Module): | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
if self._word_sep_index: # 不能drop sep | if self._word_sep_index: # 不能drop sep | ||||
sep_mask = words.eq(self._wordpiece_unk_index) | sep_mask = words.eq(self._wordpiece_unk_index) | ||||
mask = torch.full_like(words, fill_value=self.word_dropout) | |||||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | ||||
pad_mask = words.ne(self._wordpiece_pad_index) | pad_mask = words.ne(self._wordpiece_pad_index) | ||||
mask = pad_mask.__and__(mask) # pad的位置不为unk | mask = pad_mask.__and__(mask) # pad的位置不为unk | ||||
@@ -138,7 +138,7 @@ class TokenEmbedding(nn.Module): | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.word_dropout > 0 and self.training: | if self.word_dropout > 0 and self.training: | ||||
mask = torch.full_like(words, fill_value=self.word_dropout) | |||||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | ||||
pad_mask = words.ne(self._word_pad_index) | pad_mask = words.ne(self._word_pad_index) | ||||
mask = mask.__and__(pad_mask) | mask = mask.__and__(pad_mask) | ||||
@@ -10,5 +10,12 @@ class TestDownload(unittest.TestCase): | |||||
# import os | # import os | ||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | vocab = Vocabulary().add_word_lst("This is a test .".split()) | ||||
embed = BertEmbedding(vocab, model_dir_or_name='en') | embed = BertEmbedding(vocab, model_dir_or_name='en') | ||||
words = torch.LongTensor([[0, 1, 2]]) | |||||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||||
print(embed(words).size()) | print(embed(words).size()) | ||||
def test_word_drop(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2) | |||||
for i in range(10): | |||||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||||
print(embed(words).size()) |
@@ -5,6 +5,7 @@ from fastNLP import Vocabulary | |||||
import torch | import torch | ||||
import os | import os | ||||
class TestLoad(unittest.TestCase): | class TestLoad(unittest.TestCase): | ||||
def test_norm1(self): | def test_norm1(self): | ||||
# 测试只对可以找到的norm | # 测试只对可以找到的norm | ||||
@@ -22,6 +23,16 @@ class TestLoad(unittest.TestCase): | |||||
self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1) | self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1) | ||||
self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1) | self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1) | ||||
def test_dropword(self): | |||||
# 测试是否可以通过drop word | |||||
vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)]) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4) | |||||
for i in range(10): | |||||
length = torch.randint(1, 50, (1,)).item() | |||||
batch = torch.randint(1, 4, (1,)).item() | |||||
words = torch.randint(1, 200, (batch, length)).long() | |||||
embed(words) | |||||
class TestRandomSameEntry(unittest.TestCase): | class TestRandomSameEntry(unittest.TestCase): | ||||
def test_same_vector(self): | def test_same_vector(self): | ||||
vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) | vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) | ||||