Browse Source

1.修复部分测试; 2.修复StaticEmbedding中未找到词初始化bug

tags/v0.4.10
yh_cc 5 years ago
parent
commit
85f01f01d1
4 changed files with 35 additions and 14 deletions
  1. +7
    -2
      fastNLP/embeddings/static_embedding.py
  2. +1
    -1
      test/embeddings/test_bert_embedding.py
  3. +24
    -7
      test/embeddings/test_static_embedding.py
  4. +3
    -4
      test/test_tutorials.py

+ 7
- 2
fastNLP/embeddings/static_embedding.py View File

@@ -106,6 +106,7 @@ class StaticEmbedding(TokenEmbedding):
print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
vocab = truncated_vocab

self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
# 读取embedding
if lower:
lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
@@ -142,7 +143,7 @@ class StaticEmbedding(TokenEmbedding):
else:
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
if normalize:
if not self.only_norm_found_vector and normalize:
embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)

if truncate_vocab:
@@ -233,6 +234,7 @@ class StaticEmbedding(TokenEmbedding):
if vocab.unknown:
matrix[vocab.unknown_idx] = torch.zeros(dim)
found_count = 0
found_unknown = False
for idx, line in enumerate(f, start_idx):
try:
parts = line.strip().split()
@@ -243,9 +245,12 @@ class StaticEmbedding(TokenEmbedding):
word = vocab.padding
elif word == unknown and vocab.unknown is not None:
word = vocab.unknown
found_unknown = True
if word in vocab:
index = vocab.to_index(word)
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
if self.only_norm_found_vector:
matrix[index] = matrix[index]/np.linalg.norm(matrix[index])
found_count += 1
except Exception as e:
if error == 'ignore':
@@ -256,7 +261,7 @@ class StaticEmbedding(TokenEmbedding):
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
for word, index in vocab:
if index not in matrix and not vocab._is_word_no_create_entry(word):
if vocab.unknown_idx in matrix: # 如果有unkonwn,用unknown初始化
if found_unknown: # 如果有unkonwn,用unknown初始化
matrix[index] = matrix[vocab.unknown_idx]
else:
matrix[index] = None


+ 1
- 1
test/embeddings/test_bert_embedding.py View File

@@ -9,6 +9,6 @@ class TestDownload(unittest.TestCase):
def test_download(self):
# import os
vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = BertEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/embedding/bert-base-cased')
embed = BertEmbedding(vocab, model_dir_or_name='en')
words = torch.LongTensor([[0, 1, 2]])
print(embed(words).size())

+ 24
- 7
test/embeddings/test_static_embedding.py View File

@@ -5,6 +5,23 @@ from fastNLP import Vocabulary
import torch
import os

class TestLoad(unittest.TestCase):
def test_norm1(self):
# 测试只对可以找到的norm
vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
only_norm_found_vector=True)
self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)

def test_norm2(self):
# 测试对所有都norm
vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
normalize=True)
self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)

class TestRandomSameEntry(unittest.TestCase):
def test_same_vector(self):
vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
@@ -21,7 +38,7 @@ class TestRandomSameEntry(unittest.TestCase):
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_same_vector2(self):
vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt',
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
lower=True)
words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
words = embed(words)
@@ -39,7 +56,7 @@ class TestRandomSameEntry(unittest.TestCase):
no_create_word_lst = ['of', 'Of', 'With', 'with']
vocab = Vocabulary().add_word_lst(word_lst)
vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
lower=True)
words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]])
words = embed(words)
@@ -48,7 +65,7 @@ class TestRandomSameEntry(unittest.TestCase):
lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
lower=False)
lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]])
lowered_words = lowered_embed(lowered_words)
@@ -67,7 +84,7 @@ class TestRandomSameEntry(unittest.TestCase):
all_words = word_lst[:-2] + no_create_word_lst[:-2]
vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
lower=True)
words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
words = embed(words)
@@ -76,7 +93,7 @@ class TestRandomSameEntry(unittest.TestCase):
lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
lower=False)
lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
lowered_words = lowered_embed(lowered_words)
@@ -94,14 +111,14 @@ class TestRandomSameEntry(unittest.TestCase):
all_words = word_lst[:-2] + no_create_word_lst[:-2]
vocab = Vocabulary().add_word_lst(word_lst)
vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
lower=False, min_freq=2)
words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
words = embed(words)

min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d',
lower=False)
min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
min_freq_words = min_freq_embed(min_freq_words)


+ 3
- 4
test/test_tutorials.py View File

@@ -5,14 +5,13 @@ from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.io.loader import CSVLoader

class TestTutorial(unittest.TestCase):
def test_fastnlp_10min_tutorial(self):
# 从csv读取数据到DataSet
sample_path = "test/data_for_tests/tutorial_sample_dataset.csv"
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
sep='\t')
dataset = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(sample_path)
print(len(dataset))
print(dataset[0])
print(dataset[-3])
@@ -110,7 +109,7 @@ class TestTutorial(unittest.TestCase):
def test_fastnlp_1min_tutorial(self):
# tutorials/fastnlp_1min_tutorial.ipynb
data_path = "test/data_for_tests/tutorial_sample_dataset.csv"
ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t')
ds = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(data_path)
print(ds[1])

# 将所有数字转为小写


Loading…
Cancel
Save