diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py index 4079b2a2..a75ad18f 100644 --- a/fastNLP/embeddings/static_embedding.py +++ b/fastNLP/embeddings/static_embedding.py @@ -106,6 +106,7 @@ class StaticEmbedding(TokenEmbedding): print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") vocab = truncated_vocab + self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False) # 读取embedding if lower: lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown) @@ -142,7 +143,7 @@ class StaticEmbedding(TokenEmbedding): else: embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) - if normalize: + if not self.only_norm_found_vector and normalize: embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12) if truncate_vocab: @@ -233,6 +234,7 @@ class StaticEmbedding(TokenEmbedding): if vocab.unknown: matrix[vocab.unknown_idx] = torch.zeros(dim) found_count = 0 + found_unknown = False for idx, line in enumerate(f, start_idx): try: parts = line.strip().split() @@ -243,9 +245,12 @@ class StaticEmbedding(TokenEmbedding): word = vocab.padding elif word == unknown and vocab.unknown is not None: word = vocab.unknown + found_unknown = True if word in vocab: index = vocab.to_index(word) matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) + if self.only_norm_found_vector: + matrix[index] = matrix[index]/np.linalg.norm(matrix[index]) found_count += 1 except Exception as e: if error == 'ignore': @@ -256,7 +261,7 @@ class StaticEmbedding(TokenEmbedding): print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) for word, index in vocab: if index not in matrix and not vocab._is_word_no_create_entry(word): - if vocab.unknown_idx in matrix: # 如果有unkonwn,用unknown初始化 + if found_unknown: # 如果有unkonwn,用unknown初始化 matrix[index] = matrix[vocab.unknown_idx] else: matrix[index] = None diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py index c27ebd40..760029a3 100644 --- a/test/embeddings/test_bert_embedding.py +++ b/test/embeddings/test_bert_embedding.py @@ -9,6 +9,6 @@ class TestDownload(unittest.TestCase): def test_download(self): # import os vocab = Vocabulary().add_word_lst("This is a test .".split()) - embed = BertEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/embedding/bert-base-cased') + embed = BertEmbedding(vocab, model_dir_or_name='en') words = torch.LongTensor([[0, 1, 2]]) print(embed(words).size()) diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py index ca97dd75..83137345 100644 --- a/test/embeddings/test_static_embedding.py +++ b/test/embeddings/test_static_embedding.py @@ -5,6 +5,23 @@ from fastNLP import Vocabulary import torch import os +class TestLoad(unittest.TestCase): + def test_norm1(self): + # 测试只对可以找到的norm + vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile']) + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt', + only_norm_found_vector=True) + self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1) + self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1) + + def test_norm2(self): + # 测试对所有都norm + vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile']) + embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt', + normalize=True) + self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1) + self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1) + class TestRandomSameEntry(unittest.TestCase): def test_same_vector(self): vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"]) @@ -21,7 +38,7 @@ class TestRandomSameEntry(unittest.TestCase): @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") def test_same_vector2(self): vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"]) - embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt', + embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d', lower=True) words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]]) words = embed(words) @@ -39,7 +56,7 @@ class TestRandomSameEntry(unittest.TestCase): no_create_word_lst = ['of', 'Of', 'With', 'with'] vocab = Vocabulary().add_word_lst(word_lst) vocab.add_word_lst(no_create_word_lst, no_create_entry=True) - embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d', lower=True) words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]]) words = embed(words) @@ -48,7 +65,7 @@ class TestRandomSameEntry(unittest.TestCase): lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst] lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst) lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True) - lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d', lower=False) lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]]) lowered_words = lowered_embed(lowered_words) @@ -67,7 +84,7 @@ class TestRandomSameEntry(unittest.TestCase): all_words = word_lst[:-2] + no_create_word_lst[:-2] vocab = Vocabulary(min_freq=2).add_word_lst(word_lst) vocab.add_word_lst(no_create_word_lst, no_create_entry=True) - embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d', lower=True) words = torch.LongTensor([[vocab.to_index(word) for word in all_words]]) words = embed(words) @@ -76,7 +93,7 @@ class TestRandomSameEntry(unittest.TestCase): lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst] lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst) lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True) - lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d', lower=False) lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]]) lowered_words = lowered_embed(lowered_words) @@ -94,14 +111,14 @@ class TestRandomSameEntry(unittest.TestCase): all_words = word_lst[:-2] + no_create_word_lst[:-2] vocab = Vocabulary().add_word_lst(word_lst) vocab.add_word_lst(no_create_word_lst, no_create_entry=True) - embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d', lower=False, min_freq=2) words = torch.LongTensor([[vocab.to_index(word) for word in all_words]]) words = embed(words) min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst) min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True) - min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', + min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d', lower=False) min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]]) min_freq_words = min_freq_embed(min_freq_words) diff --git a/test/test_tutorials.py b/test/test_tutorials.py index 6f4a8347..3ec0e381 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -5,14 +5,13 @@ from fastNLP import Instance from fastNLP import Vocabulary from fastNLP.core.losses import CrossEntropyLoss from fastNLP.core.metrics import AccuracyMetric - +from fastNLP.io.loader import CSVLoader class TestTutorial(unittest.TestCase): def test_fastnlp_10min_tutorial(self): # 从csv读取数据到DataSet sample_path = "test/data_for_tests/tutorial_sample_dataset.csv" - dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), - sep='\t') + dataset = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(sample_path) print(len(dataset)) print(dataset[0]) print(dataset[-3]) @@ -110,7 +109,7 @@ class TestTutorial(unittest.TestCase): def test_fastnlp_1min_tutorial(self): # tutorials/fastnlp_1min_tutorial.ipynb data_path = "test/data_for_tests/tutorial_sample_dataset.csv" - ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t') + ds = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(data_path) print(ds[1]) # 将所有数字转为小写