diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py
index 0c8fc739..6fd33072 100644
--- a/test/embeddings/test_static_embedding.py
+++ b/test/embeddings/test_static_embedding.py
@@ -3,13 +3,93 @@ import unittest
 from fastNLP.embeddings import StaticEmbedding
 from fastNLP import Vocabulary
 import torch
+import os
 
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
-        vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
+        vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
         embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
-        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]])
+        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
         words = embed(words)
+        # with lower=True, every case variant of a word should be mapped to the same vector
         embed_0 = words[0, 0]
-        for i in range(1, words.size(1)):
+        for i in range(1, 3):
             assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
+        embed_0 = words[0, 3]
+        for i in range(3, 5):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+    def test_same_vector2(self):
+        vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
+        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt',
+                                lower=True)
+        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
+        words = embed(words)
+        # "The"/"the"/"THE" should share one vector
+        embed_0 = words[0, 0]
+        for i in range(1, 3):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+        # 'b'/'B' should share one vector
+        embed_0 = words[0, 3]
+        for i in range(3, 5):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+        # 'a'/'A' should share one vector
+        embed_0 = words[0, 5]
+        for i in range(5, 7):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+    def test_same_vector3(self):
+        # lower=True on a mixed-case vocab should be equivalent to building the
+        # embedding from an already-lowered vocab with lower=False
+        word_lst = ["The", "the"]
+        no_create_word_lst = ['of', 'Of', 'With', 'with']
+        vocab = Vocabulary().add_word_lst(word_lst)
+        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                lower=True)
+        words = torch.LongTensor([[vocab.to_index(word) for word in word_lst + no_create_word_lst]])
+        words = embed(words)
+
+        lowered_word_lst = [word.lower() for word in word_lst]
+        lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
+        lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
+        lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
+        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                        lower=False)
+        lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst + lowered_no_create_word_lst]])
+        lowered_words = lowered_embed(lowered_words)
+
+        all_words = word_lst + no_create_word_lst
+
+        for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])):
+            with self.subTest(idx=idx, word=all_words[idx]):
+                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
+
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+    def test_same_vector4(self):
+        word_lst = ["The", "the", "the", "The", "a", "A"]
+        no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
+        # 'a'/'A' and 'With'/'with' fall below min_freq=2, so they are left out of the comparison
+        all_words = word_lst[:-2] + no_create_word_lst[:-2]
+        vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
+        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                lower=True)
+        words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
+        words = embed(words)
+
+        lowered_word_lst = [word.lower() for word in word_lst]
+        lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
+        lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
+        lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
+        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                        lower=False)
+        lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
+        lowered_words = lowered_embed(lowered_words)
+
+        for idx in range(len(all_words)):
+            word_i, word_j = words[0, idx], lowered_words[0, idx]
+            with self.subTest(idx=idx, word=all_words[idx]):
+                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
\ No newline at end of file