From 6a498bbdf26220622b226d37ce50a3a29bd699c6 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sat, 23 Mar 2019 15:44:23 +0800
Subject: [PATCH] =?UTF-8?q?*=20=E7=BB=99vocabulary=E6=B7=BB=E5=8A=A0?=
 =?UTF-8?q?=E9=81=8D=E5=8E=86=E6=96=B9=E6=B3=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/vocabulary.py   | 3 +++
 test/core/test_vocabulary.py | 9 +++++++++
 2 files changed, 12 insertions(+)

diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 1e0857f3..a1c8e678 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -199,3 +199,6 @@ class Vocabulary(object):
 
     def __repr__(self):
         return "Vocabulary({}...)".format(list(self.word_count.keys())[:5])
+
+    def __iter__(self):
+        return iter(list(self.word_count.keys()))
diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py
index af2c493b..2f9cd3b1 100644
--- a/test/core/test_vocabulary.py
+++ b/test/core/test_vocabulary.py
@@ -60,6 +60,15 @@ class TestIndexing(unittest.TestCase):
         vocab.update(text)
         self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
 
+    def test_iteration(self):
+        vocab = Vocabulary()
+        text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
+                "works", "well", "in", "most", "cases", "scales", "well"]
+        vocab.update(text)
+        text = set(text)
+        for word in vocab:
+            self.assertTrue(word in text)
+
 
 class TestOther(unittest.TestCase):
     def test_additional_update(self):