diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 1e0857f3..a1c8e678 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -199,3 +199,6 @@ class Vocabulary(object): def __repr__(self): return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) + + def __iter__(self): + return iter(list(self.word_count.keys())) diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py index af2c493b..2f9cd3b1 100644 --- a/test/core/test_vocabulary.py +++ b/test/core/test_vocabulary.py @@ -60,6 +60,15 @@ class TestIndexing(unittest.TestCase): vocab.update(text) self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) + def test_iteration(self): + vocab = Vocabulary() + text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", + "works", "well", "in", "most", "cases", "scales", "well"] + vocab.update(text) + text = set(text) + for word in vocab: + self.assertTrue(word in text) + class TestOther(unittest.TestCase): def test_additional_update(self):