@@ -5,6 +5,23 @@ from fastNLP import Vocabulary
 import torch
 import os
 
 
+class TestLoad(unittest.TestCase):
+    def test_norm1(self):
+        # Normalize only the vectors that can be found in the pretrained file.
+        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
+                                only_norm_found_vector=True)
+        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
+        self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)
+
+    def test_norm2(self):
+        # Normalize all vectors, whether found in the pretrained file or not.
+        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
+        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
+                                normalize=True)
+        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
+        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
+
+
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
         vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
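
For reference, a minimal sketch of the behaviour the two new tests pin down (a toy stand-in, not fastNLP's actual implementation): only_norm_found_vector=True L2-normalizes only the rows loaded from the pretrained file, while normalize=True normalizes every row. With the default padding/unknown entries at indices 0 and 1, index 2 is 'the' (present in the GloVe test file) and index 4 is 'notinfile' (randomly initialized).

    import torch

    # Toy stand-in for StaticEmbedding's weight matrix: row 2 plays the role
    # of 'the' (found in the pretrained file), row 4 of 'notinfile'.
    weight = torch.randn(5, 50)
    found_rows = [2]

    # only_norm_found_vector=True: L2-normalize just the rows from the file.
    weight[found_rows] = weight[found_rows] / weight[found_rows].norm(dim=-1, keepdim=True)
    assert round(weight[2].norm().item(), 4) == 1   # found word has unit norm
    assert round(weight[4].norm().item(), 4) != 1   # random row left untouched

    # normalize=True: L2-normalize every row, found or not.
    weight = weight / weight.norm(dim=-1, keepdim=True)
    assert round(weight[4].norm().item(), 4) == 1
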
@@ -21,7 +38,7 @@ class TestRandomSameEntry(unittest.TestCase):
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_same_vector2(self):
         vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
-        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                 lower=True)
         words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
         words = embed(words)
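
The behaviour this test exercises: with lower=True, StaticEmbedding folds case when looking up pretrained vectors, so 'The', 'the' and 'THE' should come back identical. A hedged sketch of that invariant (assumes 'en-glove-6B-100d' can be downloaded or is already cached):

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    # With lower=True, differently-cased forms of one word share one vector.
    vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d', lower=True)
    vecs = embed(torch.LongTensor([[vocab.to_index(w) for w in ["The", "the", "THE"]]]))
    assert torch.allclose(vecs[0, 0], vecs[0, 1])
    assert torch.allclose(vecs[0, 1], vecs[0, 2])
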
@@ -39,7 +56,7 @@ class TestRandomSameEntry(unittest.TestCase):
         no_create_word_lst = ['of', 'Of', 'With', 'with']
         vocab = Vocabulary().add_word_lst(word_lst)
         vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                 lower=True)
         words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]])
         words = embed(words)
@@ -48,7 +65,7 @@ class TestRandomSameEntry(unittest.TestCase):
         lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
         lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
         lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
-        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
                                         lower=False)
         lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]])
         lowered_words = lowered_embed(lowered_words)
@@ -67,7 +84,7 @@ class TestRandomSameEntry(unittest.TestCase):
         all_words = word_lst[:-2] + no_create_word_lst[:-2]
         vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
         vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                 lower=True)
         words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
         words = embed(words)
@@ -76,7 +93,7 @@ class TestRandomSameEntry(unittest.TestCase):
         lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
         lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
         lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
-        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
                                         lower=False)
         lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
         lowered_words = lowered_embed(lowered_words)
@@ -94,14 +111,14 @@ class TestRandomSameEntry(unittest.TestCase):
         all_words = word_lst[:-2] + no_create_word_lst[:-2]
         vocab = Vocabulary().add_word_lst(word_lst)
         vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
-        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                 lower=False, min_freq=2)
         words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
         words = embed(words)
 
         min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
         min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
-        min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+        min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d',
                                          lower=False)
         min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
         min_freq_words = min_freq_embed(min_freq_words)
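
A sketch of the no_create_entry semantics these vocabularies rely on, as I understand them: a word registered with no_create_entry=True gets no trainable row of its own, so if it is also absent from the pretrained file its lookup falls back to the unknown vector. Illustrative only, reusing the small GloVe test file from the first hunk; vocab.unknown_idx is assumed to expose the unknown token's index.

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    # 'notinfile' is marked no_create_entry and is missing from the file,
    # so it should resolve to the shared unknown vector.
    vocab = Vocabulary().add_word_lst(['the'])
    vocab.add_word_lst(['notinfile'], no_create_entry=True)
    embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt')
    missing = embed(torch.LongTensor([[vocab.to_index('notinfile')]]))
    unk = embed(torch.LongTensor([[vocab.unknown_idx]]))
    assert torch.allclose(missing, unk)
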