
Add tests for static_embed

tags/v0.4.10
yh committed 6 years ago
commit c9fba2ae96
1 changed file with 80 additions and 3 deletions:
  test/embeddings/test_static_embedding.py  (+80, -3)

@@ -3,13 +3,90 @@ import unittest
 from fastNLP.embeddings import StaticEmbedding
 from fastNLP import Vocabulary
 import torch
+import os
 
 class TestRandomSameEntry(unittest.TestCase):
     def test_same_vector(self):
-        vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
+        vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
         embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
-        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]])
+        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
         words = embed(words)
         embed_0 = words[0, 0]
-        for i in range(1, words.size(1)):
+        for i in range(1, 3):
             assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
+        embed_0 = words[0, 3]
+        for i in range(3, 5):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+    def test_same_vector2(self):
+        vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
+        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt',
+                                lower=True)
+        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
+        words = embed(words)
+        embed_0 = words[0, 0]
+        for i in range(1, 3):
+            assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
+        embed_0 = words[0, 3]
+        for i in range(3, 5):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+    def test_same_vector3(self):
+        word_lst = ["The", "the"]
+        no_create_word_lst = ['of', 'Of', 'With', 'with']
+        vocab = Vocabulary().add_word_lst(word_lst)
+        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                lower=True)
+        words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]])
+        words = embed(words)
+
+        lowered_word_lst = [word.lower() for word in word_lst]
+        lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
+        lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
+        lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
+        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                        lower=False)
+        lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]])
+        lowered_words = lowered_embed(lowered_words)
+
+        all_words = word_lst + no_create_word_lst
+
+        for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])):
+            with self.subTest(idx=idx, word=all_words[idx]):
+                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
+
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+    def test_same_vector4(self):
+        # words = []
+        # create_word_lst = []  # words that need to be created
+        # no_create_word_lst = []
+        # ignore_word_lst = []
+        # with open('/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', 'r', encoding='utf-8') as f:
+        #     for line in f:
+        #         words
+        word_lst = ["The", "the", "the", "The", "a", "A"]
+        no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
+        all_words = word_lst[:-2] + no_create_word_lst[:-2]
+        vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
+        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                lower=True)
+        words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
+        words = embed(words)
+
+        lowered_word_lst = [word.lower() for word in word_lst]
+        lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
+        lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
+        lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
+        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                        lower=False)
+        lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
+        lowered_words = lowered_embed(lowered_words)
+
+        for idx in range(len(all_words)):
+            word_i, word_j = words[0, idx], lowered_words[0, idx]
+            with self.subTest(idx=idx, word=all_words[idx]):
+                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)

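For context on what these tests assert: with lower=True, StaticEmbedding is expected to fold case variants of a word onto a single vector, and entries added via add_word_lst(..., no_create_entry=True) are expected to reuse pretrained vectors rather than receive freshly created ones. Below is a minimal sketch of the case-folding property, using only API calls that already appear in the tests above; unlike test_same_vector2 through test_same_vector4 it needs no pretrained GloVe file, so it runs anywhere fastNLP and torch are installed.

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

# Randomly initialized 5-dim embedding; lower=True folds "The"/"the"/"THE" together.
vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)

words = torch.LongTensor([[vocab.to_index(w) for w in ["The", "the", "THE"]]])
vectors = embed(words)  # shape: (1, 3, 5)

# All three case variants should come out as the identical vector.
assert torch.equal(vectors[0, 0], vectors[0, 1])
assert torch.equal(vectors[0, 0], vectors[0, 2])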