
Fix word drop bug and add corresponding tests

commit be77533c38 by yh_cc, 5 years ago (tags/v0.4.10)
4 changed files with 22 additions and 4 deletions
  1. fastNLP/embeddings/bert_embedding.py (+2, -2)
  2. fastNLP/embeddings/embedding.py (+1, -1)
  3. test/embeddings/test_bert_embedding.py (+8, -1)
  4. test/embeddings/test_static_embedding.py (+11, -0)

fastNLP/embeddings/bert_embedding.py (+2, -2)

@@ -126,7 +126,7 @@ class BertEmbedding(ContextualEmbedding):
             with torch.no_grad():
                 if self._word_sep_index:  # must not drop sep
                     sep_mask = words.eq(self._word_sep_index)
-                mask = torch.full_like(words, fill_value=self.word_dropout)
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                 mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
                 pad_mask = words.ne(0)
                 mask = pad_mask.__and__(mask)  # pad positions are not set to unk
@@ -267,7 +267,7 @@ class BertWordPieceEncoder(nn.Module):
             with torch.no_grad():
                 if self._word_sep_index:  # must not drop sep
                     sep_mask = words.eq(self._wordpiece_unk_index)
-                mask = torch.full_like(words, fill_value=self.word_dropout)
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
                 mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
                 pad_mask = words.ne(self._wordpiece_pad_index)
                 mask = pad_mask.__and__(mask)  # pad positions are not set to unk

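The underlying issue: `words` is a LongTensor, so the old `torch.full_like(words, fill_value=self.word_dropout)` inherits the integer dtype and the fractional dropout probability cannot survive; forcing `dtype=torch.float` is the whole fix. A minimal standalone repro of the corrected construction (the values are illustrative, not taken from fastNLP):

    import torch

    words = torch.LongTensor([[2, 3, 4, 0]])   # token ids, integer dtype
    word_dropout = 0.2                         # illustrative probability

    # Pre-fix, full_like produced a Long tensor, so 0.2 could not be stored as-is
    # (depending on the PyTorch version it is truncated to 0, or torch.bernoulli
    # rejects the non-floating-point input). Forcing a float mask fixes it:
    mask = torch.full_like(words, fill_value=word_dropout,
                           dtype=torch.float, device=words.device)
    print(mask)                          # tensor([[0.2000, 0.2000, 0.2000, 0.2000]])
    print(torch.bernoulli(mask).eq(1))   # boolean mask of candidate positions to drop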

fastNLP/embeddings/embedding.py (+1, -1)

@@ -138,7 +138,7 @@ class TokenEmbedding(nn.Module):
         :return:
         """
         if self.word_dropout > 0 and self.training:
-            mask = torch.full_like(words, fill_value=self.word_dropout)
+            mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
            mask = torch.bernoulli(mask).eq(1)  # the larger word_dropout is, the more positions are set to 1
            pad_mask = words.ne(self._word_pad_index)
            mask = mask.__and__(pad_mask)

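For context, a self-contained sketch of the corrected word-dropout path mirroring the hunks above; `pad_index`, `unk_index`, and the final `masked_fill` step are illustrative assumptions (the replacement with unk happens outside the lines shown in the diff):

    import torch

    def drop_words(words, word_dropout, pad_index=0, unk_index=1):
        # pad_index/unk_index are illustrative defaults, not fastNLP's actual attributes
        mask = torch.full_like(words, fill_value=word_dropout,
                               dtype=torch.float, device=words.device)  # float mask is the fix
        mask = torch.bernoulli(mask).eq(1)   # the larger word_dropout, the more True positions
        pad_mask = words.ne(pad_index)
        mask = mask & pad_mask               # never turn padding into unk
        return words.masked_fill(mask, unk_index)  # assumed follow-up step, not shown in the hunk

    words = torch.LongTensor([[2, 3, 4, 0]])
    print(drop_words(words, word_dropout=0.5))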

test/embeddings/test_bert_embedding.py (+8, -1)

@@ -10,5 +10,12 @@ class TestDownload(unittest.TestCase):
         # import os
         vocab = Vocabulary().add_word_lst("This is a test .".split())
         embed = BertEmbedding(vocab, model_dir_or_name='en')
-        words = torch.LongTensor([[0, 1, 2]])
+        words = torch.LongTensor([[2, 3, 4, 0]])
         print(embed(words).size())
+
+    def test_word_drop(self):
+        vocab = Vocabulary().add_word_lst("This is a test .".split())
+        embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
+        for i in range(10):
+            words = torch.LongTensor([[2, 3, 4, 0]])
+            print(embed(words).size())

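A note on why the test input changed from [[0, 1, 2]] to [[2, 3, 4, 0]]: with fastNLP's default Vocabulary, index 0 is the padding token and index 1 is unknown, so real word indices start at 2, and the trailing 0 exercises the pad mask in the dropout path. A quick way to see the mapping (assuming the usual `to_index` API):

    from fastNLP import Vocabulary

    vocab = Vocabulary().add_word_lst("This is a test .".split())
    # With the default <pad>/<unk> tokens, real words map to indices >= 2.
    print([vocab.to_index(w) for w in "This is a test .".split()])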
test/embeddings/test_static_embedding.py (+11, -0)

@@ -5,6 +5,7 @@ from fastNLP import Vocabulary
 import torch
 import os
 
+
 class TestLoad(unittest.TestCase):
     def test_norm1(self):
         # norm is only applied to words that can be found
@@ -22,6 +23,16 @@ class TestLoad(unittest.TestCase):
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
         self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
 
+    def test_dropword(self):
+        # check that word dropout runs
+        vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
+        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
+        for i in range(10):
+            length = torch.randint(1, 50, (1,)).item()
+            batch = torch.randint(1, 4, (1,)).item()
+            words = torch.randint(1, 200, (batch, length)).long()
+            embed(words)
+
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
         vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])

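Beyond checking that the forward pass runs, a test could assert that dropout is active only in training mode. A rough sketch of that idea (the `fastNLP.embeddings` import path is assumed, and the train-mode inequality holds only with overwhelming probability, not deterministically):

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10,
                            dropout=0.1, word_dropout=0.4)
    words = torch.randint(1, 200, (2, 20)).long()

    embed.eval()                                    # dropout disabled: output is deterministic
    assert torch.equal(embed(words), embed(words))

    embed.train()                                   # dropout enabled: two passes should differ
    assert not torch.equal(embed(words), embed(words))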
