|
|
@@ -3,13 +3,90 @@ import unittest |
|
|
|
from fastNLP.embeddings import StaticEmbedding |
|
|
|
from fastNLP import Vocabulary |
|
|
|
import torch |
|
|
|
import os |
|
|
|
|
|
|
|
class TestRandomSameEntry(unittest.TestCase): |
|
|
|
def test_same_vector(self): |
|
|
|
vocab = Vocabulary().add_word_lst(["The", "the", "THE"]) |
|
|
|
vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) |
|
|
|
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True) |
|
|
|
words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]]) |
|
|
|
words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]]) |
|
|
|
words = embed(words) |
|
|
|
embed_0 = words[0, 0] |
|
|
|
for i in range(1, words.size(1)): |
|
|
|
for i in range(1, 3): |
|
|
|
assert torch.sum(embed_0==words[0, i]).eq(len(embed_0)) |
|
|
|
embed_0 = words[0, 3] |
|
|
|
for i in range(3, 5): |
|
|
|
assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0)) |
|
|
|
|
|
|
|
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") |
|
|
|
def test_same_vector2(self): |
|
|
|
vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"]) |
|
|
|
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt', |
|
|
|
lower=True) |
|
|
|
words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]]) |
|
|
|
words = embed(words) |
|
|
|
embed_0 = words[0, 0] |
|
|
|
for i in range(1, 3): |
|
|
|
assert torch.sum(embed_0==words[0, i]).eq(len(embed_0)) |
|
|
|
embed_0 = words[0, 3] |
|
|
|
for i in range(3, 5): |
|
|
|
assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0)) |
|
|
|
|
|
|
|
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") |
|
|
|
def test_same_vector3(self): |
|
|
|
word_lst = ["The", "the"] |
|
|
|
no_create_word_lst = ['of', 'Of', 'With', 'with'] |
|
|
|
vocab = Vocabulary().add_word_lst(word_lst) |
|
|
|
vocab.add_word_lst(no_create_word_lst, no_create_entry=True) |
|
|
|
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', |
|
|
|
lower=True) |
|
|
|
words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]]) |
|
|
|
words = embed(words) |
|
|
|
|
|
|
|
lowered_word_lst = [word.lower() for word in word_lst] |
|
|
|
lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst] |
|
|
|
lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst) |
|
|
|
lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True) |
|
|
|
lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', |
|
|
|
lower=False) |
|
|
|
lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]]) |
|
|
|
lowered_words = lowered_embed(lowered_words) |
|
|
|
|
|
|
|
all_words = word_lst + no_create_word_lst |
|
|
|
|
|
|
|
for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])): |
|
|
|
with self.subTest(idx=idx, word=all_words[idx]): |
|
|
|
assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size) |
|
|
|
|
|
|
|
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") |
|
|
|
def test_same_vector4(self): |
|
|
|
# words = [] |
|
|
|
# create_word_lst = [] # 需要创建 |
|
|
|
# no_create_word_lst = [] |
|
|
|
# ignore_word_lst = [] |
|
|
|
# with open('/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', 'r', encoding='utf-8') as f: |
|
|
|
# for line in f: |
|
|
|
# words |
|
|
|
word_lst = ["The", "the", "the", "The", "a", "A"] |
|
|
|
no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with'] |
|
|
|
all_words = word_lst[:-2] + no_create_word_lst[:-2] |
|
|
|
vocab = Vocabulary(min_freq=2).add_word_lst(word_lst) |
|
|
|
vocab.add_word_lst(no_create_word_lst, no_create_entry=True) |
|
|
|
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', |
|
|
|
lower=True) |
|
|
|
words = torch.LongTensor([[vocab.to_index(word) for word in all_words]]) |
|
|
|
words = embed(words) |
|
|
|
|
|
|
|
lowered_word_lst = [word.lower() for word in word_lst] |
|
|
|
lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst] |
|
|
|
lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst) |
|
|
|
lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True) |
|
|
|
lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', |
|
|
|
lower=False) |
|
|
|
lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]]) |
|
|
|
lowered_words = lowered_embed(lowered_words) |
|
|
|
|
|
|
|
for idx in range(len(all_words)): |
|
|
|
word_i, word_j = words[0, idx], lowered_words[0, idx] |
|
|
|
with self.subTest(idx=idx, word=all_words[idx]): |
|
|
|
assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size) |