From 52b1b18a76d3620f413d59967f1b9cb2f4ec650e Mon Sep 17 00:00:00 2001
From: yunfan
Date: Tue, 4 Dec 2018 17:04:31 +0800
Subject: [PATCH] fix bugs in vocab

---
 fastNLP/core/vocabulary.py   | 49 +++++++++++----------------------
 test/core/test_trainer.py    | 52 +++++++++++++++++++-----------------
 test/core/test_vocabulary.py | 20 +++++++-------
 3 files changed, 53 insertions(+), 68 deletions(-)

diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index ca6b4ebf..14577635 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -1,11 +1,4 @@
 from collections import Counter
-from copy import deepcopy
-
-DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
-DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
-
-DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1}
-
 
 def isiterable(p_object):
     try:
@@ -57,22 +50,16 @@ class Vocabulary(object):
         vocab.to_word(5)
     """

-    def __init__(self, need_default=True, max_size=None, min_freq=None):
+    def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'):
         """
-        :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.
         :param int max_size: set the max number of words in Vocabulary. Default: None
         :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
         """
         self.max_size = max_size
         self.min_freq = min_freq
         self.word_count = Counter()
-        self.has_default = need_default
-        if self.has_default:
-            self.padding_label = DEFAULT_PADDING_LABEL
-            self.unknown_label = DEFAULT_UNKNOWN_LABEL
-        else:
-            self.padding_label = None
-            self.unknown_label = None
+        self.unknown = unknown
+        self.padding = padding
         self.word2idx = None
         self.idx2word = None
         self.rebuild = True
@@ -113,17 +100,18 @@ class Vocabulary(object):
         """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`.

""" - if self.has_default: - self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) - self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL) - self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL) - else: - self.word2idx = {} + self.word2idx = {} + if self.padding is not None: + self.word2idx[self.padding] = 0 + if self.unknown is not None: + self.word2idx[self.unknown] = 1 max_size = min(self.max_size, len(self.word_count)) if self.max_size else None words = self.word_count.most_common(max_size) if self.min_freq is not None: words = filter(lambda kv: kv[1] >= self.min_freq, words) + if self.word2idx is not None: + words = filter(lambda kv: kv[0] not in self.word2idx, words) start_idx = len(self.word2idx) self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) self.build_reverse_vocab() @@ -159,8 +147,8 @@ class Vocabulary(object): """ if w in self.word2idx: return self.word2idx[w] - elif self.has_default: - return self.word2idx[self.unknown_label] + if self.unknown is not None: + return self.word2idx[self.unknown] else: raise ValueError("word {} not in vocabulary".format(w)) @@ -175,21 +163,16 @@ class Vocabulary(object): @property @check_build_vocab def unknown_idx(self): - if self.unknown_label is None: + if self.unknown is None: return None - return self.word2idx[self.unknown_label] - - def __setattr__(self, name, val): - self.__dict__[name] = val - if name in ["unknown_label", "padding_label"]: - self.word2idx = None + return self.word2idx[self.unknown] @property @check_build_vocab def padding_idx(self): - if self.padding_label is None: + if self.padding is None: return None - return self.word2idx[self.padding_label] + return self.word2idx[self.padding] @check_build_vocab def to_word(self, idx): diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index 7903b403..1b578eae 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -4,6 +4,7 @@ import numpy as np import torch.nn.functional as F from torch import nn +from fastNLP.core.utils import CheckError from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance from fastNLP.core.losses import BCELoss @@ -56,7 +57,8 @@ class TrainerTestGround(unittest.TestCase): dev_data=dev_set, optimizer=SGD(lr=0.1), check_code_level=2, - use_tqdm=True) + use_tqdm=True, + save_path=None) trainer.train() """ # 应该正确运行 @@ -145,16 +147,14 @@ class TrainerTestGround(unittest.TestCase): return {'wrong_loss_key': loss} model = Model() - trainer = Trainer( - train_data=dataset, - model=model, - use_tqdm=False, - print_every=2 - ) - trainer.train() - """ - # 应该正确运行 - """ + with self.assertRaises(NameError): + trainer = Trainer( + train_data=dataset, + model=model, + use_tqdm=False, + print_every=2 + ) + trainer.train() def test_trainer_suggestion4(self): # 检查报错提示能否正确提醒用户 @@ -173,12 +173,13 @@ class TrainerTestGround(unittest.TestCase): return {'loss': loss} model = Model() - trainer = Trainer( - train_data=dataset, - model=model, - use_tqdm=False, - print_every=2 - ) + with self.assertRaises(NameError): + trainer = Trainer( + train_data=dataset, + model=model, + use_tqdm=False, + print_every=2 + ) def test_trainer_suggestion5(self): # 检查报错提示能否正确提醒用户 @@ -225,14 +226,15 @@ class TrainerTestGround(unittest.TestCase): return {'pred': x} model = Model() - trainer = Trainer( - train_data=dataset, - model=model, - dev_data=dataset, - metrics=AccuracyMetric(), - use_tqdm=False, - print_every=2 - ) + with self.assertRaises(NameError): + trainer = Trainer( + 
+                train_data=dataset,
+                model=model,
+                dev_data=dataset,
+                metrics=AccuracyMetric(),
+                use_tqdm=False,
+                print_every=2
+            )

     def test_case2(self):
         # check metrics Wrong
diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py
index e453e935..af2c493b 100644
--- a/test/core/test_vocabulary.py
+++ b/test/core/test_vocabulary.py
@@ -10,36 +10,36 @@ counter = Counter(text)

 class TestAdd(unittest.TestCase):
     def test_add(self):
-        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None)
         for word in text:
             vocab.add(word)
         self.assertEqual(vocab.word_count, counter)

     def test_add_word(self):
-        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None)
         for word in text:
             vocab.add_word(word)
         self.assertEqual(vocab.word_count, counter)

     def test_add_word_lst(self):
-        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None)
         vocab.add_word_lst(text)
         self.assertEqual(vocab.word_count, counter)

     def test_update(self):
-        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None)
         vocab.update(text)
         self.assertEqual(vocab.word_count, counter)


 class TestIndexing(unittest.TestCase):
     def test_len(self):
-        vocab = Vocabulary(need_default=False, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
         vocab.update(text)
         self.assertEqual(len(vocab), len(counter))

     def test_contains(self):
-        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
         vocab.update(text)
         self.assertTrue(text[-1] in vocab)
         self.assertFalse("~!@#" in vocab)
@@ -47,7 +47,7 @@ class TestIndexing(unittest.TestCase):
         self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))

     def test_index(self):
-        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None)
         vocab.update(text)
         res = [vocab[w] for w in set(text)]
         self.assertEqual(len(res), len(set(res)))
@@ -56,14 +56,14 @@ class TestIndexing(unittest.TestCase):
         self.assertEqual(len(res), len(set(res)))

     def test_to_word(self):
-        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None)
         vocab.update(text)
         self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])


 class TestOther(unittest.TestCase):
     def test_additional_update(self):
-        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab = Vocabulary(max_size=None, min_freq=None)
         vocab.update(text)

         _ = vocab["well"]
@@ -77,7 +77,7 @@ class TestOther(unittest.TestCase):
         self.assertTrue("hahaha" in vocab)

     def test_warning(self):
-        vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None)
+        vocab = Vocabulary(max_size=len(set(text)), min_freq=None)
         vocab.update(text)
         self.assertEqual(vocab.rebuild, True)
         print(len(vocab))
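
Reviewer note (not part of the patch): the sketch below is a minimal illustration of the constructor change introduced above, replacing need_default with explicit unknown/padding tokens. It assumes the import path fastNLP.core.vocabulary implied by the file path, the '<unk>'/'<pad>' defaults, and that the lookup hunk at @@ -159,8 belongs to __getitem__; treat it as an example written against these hunks, not as code taken from the commit.

    from fastNLP.core.vocabulary import Vocabulary

    # Default tokens: '<pad>' is reserved at index 0 and '<unk>' at index 1,
    # and a word that was never added falls back to the unknown index.
    vocab = Vocabulary(max_size=None, min_freq=None)
    vocab.update(["FastNLP", "works", "well", "works"])
    assert vocab.padding_idx == 0 and vocab.unknown_idx == 1
    assert vocab["never-seen"] == vocab.unknown_idx

    # unknown=None / padding=None replaces the old need_default=False:
    # no special tokens are reserved, the *_idx properties return None,
    # and looking up an unseen word raises ValueError instead.
    bare = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
    bare.update(["FastNLP", "works", "well", "works"])
    assert bare.padding_idx is None and bare.unknown_idx is None
    assert len(bare) == 3  # only the three distinct words, no reserved entries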