diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 8c3050ba..5d843ffa 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -91,47 +91,84 @@ class Vocabulary(object): self.idx2word = None self.rebuild = True # 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 - self._no_create_word = defaultdict(int) + self._no_create_word = Counter() @_check_build_status - def update(self, word_lst): + def update(self, word_lst, no_create_entry=False): """依次增加序列中词在词典中的出现频率 :param list word_lst: a list of strings - """ + :param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 + 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 + 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 + 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 + 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, + 则这个词将认为是需要创建单独的vector的。 + """ + self._add_no_create_entry(word_lst, no_create_entry) self.word_count.update(word_lst) @_check_build_status - def add(self, word): + def add(self, word, no_create_entry=False): """ 增加一个新词在词典中的出现频率 :param str word: 新词 - """ + :param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 + 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 + 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 + 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 + 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, + 则这个词将认为是需要创建单独的vector的。 + """ + self._add_no_create_entry(word, no_create_entry) self.word_count[word] += 1 - + + def _add_no_create_entry(self, word, no_create_entry): + """ + 在新加入word时,检查_no_create_word的设置。 + + :param str, List[str] word: + :param bool no_create_entry: + :return: + """ + if isinstance(word, str): + word = [word] + for w in word: + if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): + self._no_create_word[w] += 1 + elif not no_create_entry and w in self._no_create_word: + self._no_create_word.pop(w) + @_check_build_status - def add_word(self, word): + def add_word(self, word, no_create_entry=False): """ 增加一个新词在词典中的出现频率 :param str word: 新词 - """ - if word in self._no_create_word: - self._no_create_word.pop(word) - self.add(word) + :param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 + 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 + 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 + 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 + 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, + 则这个词将认为是需要创建单独的vector的。 + """ + self.add(word, no_create_entry=no_create_entry) @_check_build_status - def add_word_lst(self, word_lst): + def add_word_lst(self, word_lst, no_create_entry=False): """ 依次增加序列中词在词典中的出现频率 :param list[str] word_lst: 词的序列 - """ - for word in word_lst: - if word in self._no_create_word: - self._no_create_word.pop(word) - self.update(word_lst) + :param bool no_create_entry: 在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 + 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 + 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 + 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 + 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, + 则这个词将认为是需要创建单独的vector的。 + """ + self.update(word_lst, no_create_entry=no_create_entry) def build_vocab(self): """ @@ -283,23 +320,17 @@ class Vocabulary(object): for fn in field_name: field = ins[fn] if isinstance(field, str): - if no_create_entry and field not in self.word_count: - self._no_create_word[field] += 1 - self.add_word(field) + self.add_word(field, no_create_entry=no_create_entry) elif isinstance(field, (list, np.ndarray)): if not isinstance(field[0], (list, np.ndarray)): for word in field: - if no_create_entry and word not in self.word_count: - self._no_create_word[word] += 1 - self.add_word(word) + self.add_word(word, no_create_entry=no_create_entry) else: if isinstance(field[0][0], (list, np.ndarray)): raise RuntimeError("Only support field with 2 dimensions.") for words in field: for word in words: - if no_create_entry and word not in self.word_count: - self._no_create_word[word] += 1 - self.add_word(word) + self.add_word(word, no_create_entry=no_create_entry) for idx, dataset in enumerate(datasets): if isinstance(dataset, DataSet): diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py index b3326f6a..5d8d4269 100644 --- a/test/core/test_vocabulary.py +++ b/test/core/test_vocabulary.py @@ -88,6 +88,27 @@ class TestAdd(unittest.TestCase): for i in range(num_samples): self.assertEqual(True, vocab._is_word_no_create_entry(chr(start_char + i)+chr(start_char + i))) + def test_no_entry(self): + # 先建立vocabulary,然后变化no_create_entry, 测试能否正确识别 + text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", + "works", "well", "in", "most", "cases", "scales", "well"] + vocab = Vocabulary() + vocab.add_word_lst(text) + + self.assertFalse(vocab._is_word_no_create_entry('FastNLP')) + vocab.add_word('FastNLP', no_create_entry=True) + self.assertFalse(vocab._is_word_no_create_entry('FastNLP')) + + vocab.add_word('fastnlp', no_create_entry=True) + self.assertTrue(vocab._is_word_no_create_entry('fastnlp')) + vocab.add_word('fastnlp', no_create_entry=False) + self.assertFalse(vocab._is_word_no_create_entry('fastnlp')) + + vocab.add_word_lst(['1']*10, no_create_entry=True) + self.assertTrue(vocab._is_word_no_create_entry('1')) + vocab.add_word('1') + self.assertFalse(vocab._is_word_no_create_entry('1')) + class TestIndexing(unittest.TestCase): def test_len(self): @@ -127,6 +148,21 @@ class TestIndexing(unittest.TestCase): self.assertTrue(word in text) self.assertTrue(idx < len(vocab)) + def test_rebuild(self): + # 测试build之后新加入词,原来的词顺序不变 + vocab = Vocabulary() + text = [str(idx) for idx in range(10)] + vocab.update(text) + for i in text: + self.assertEqual(int(i)+2, vocab.to_index(i)) + indexes = [] + for word, index in vocab: + indexes.append((word, index)) + vocab.add_word_lst([str(idx) for idx in range(10, 13)]) + for idx, pair in enumerate(indexes): + self.assertEqual(pair[1], vocab.to_index(pair[0])) + for i in range(13): + self.assertEqual(int(i)+2, vocab.to_index(str(i))) class TestOther(unittest.TestCase): def test_additional_update(self):