|
- import unittest
- from collections import Counter
-
- from fastNLP import Vocabulary
- from fastNLP import DataSet
- from fastNLP import Instance
-
# Shared fixture: a short token sequence with deliberate repeats, plus its
# expected frequency table to compare against Vocabulary.word_count.
text = ("FastNLP works well in most cases and scales well in "
        "works well in most cases scales well").split()
counter = Counter(text)
-
-
class TestAdd(unittest.TestCase):
    """Checks every way of populating a Vocabulary, plus the
    no_create_entry bookkeeping."""

    def test_add(self):
        # Adding tokens one at a time via add() must reproduce the counts.
        vocab = Vocabulary()
        for token in text:
            vocab.add(token)
        self.assertEqual(counter, vocab.word_count)

    def test_add_word(self):
        # add_word() is expected to behave exactly like add().
        vocab = Vocabulary()
        for token in text:
            vocab.add_word(token)
        self.assertEqual(counter, vocab.word_count)

    def test_add_word_lst(self):
        # Bulk insertion of a whole list in one call.
        vocab = Vocabulary()
        vocab.add_word_lst(text)
        self.assertEqual(counter, vocab.word_count)

    def test_update(self):
        # update() is another bulk-insertion entry point.
        vocab = Vocabulary()
        vocab.update(text)
        self.assertEqual(counter, vocab.word_count)

    def test_from_dataset(self):
        # Build datasets whose 'char' field is a scalar, a flat list and a
        # nested list, and check from_dataset copes with every nesting depth.
        start_char, num_samples = 65, 10

        def build_dataset(make_field):
            # One instance per character 'A'..'J', field shaped by make_field.
            ds = DataSet()
            for offset in range(num_samples):
                ds.append(Instance(char=make_field(chr(start_char + offset))))
            return ds

        field_makers = [
            lambda ch: ch,                            # 0-dim: bare string
            lambda ch: [ch] * 6,                      # 1-dim: list of strings
            lambda ch: [[ch] * 6 for _ in range(6)],  # 2-dim: nested lists
        ]
        for make_field in field_makers:
            dataset = build_dataset(make_field)
            vocab = Vocabulary()
            vocab.from_dataset(dataset, field_name='char')
            # Words get indices in insertion order, after the two
            # reserved entries (padding/unknown).
            for offset in range(num_samples):
                self.assertEqual(offset + 2,
                                 vocab.to_index(chr(start_char + offset)))
            vocab.index_dataset(dataset, field_name='char')

    def test_from_dataset_no_entry(self):
        # Words seen only in no_create_entry_dataset must be flagged
        # as no-create-entry words.
        start_char, num_samples = 65, 10
        train_ds, dev_ds = DataSet(), DataSet()
        for offset in range(num_samples):
            chars = [chr(start_char + offset)] * 6
            train_ds.append(Instance(char=chars))
            # The held-out set uses doubled characters so its words never
            # appear in the training data.
            dev_ds.append(Instance(char=[c + c for c in chars]))
        vocab = Vocabulary()
        vocab.from_dataset(train_ds, field_name='char',
                           no_create_entry_dataset=dev_ds)
        vocab.index_dataset(train_ds, field_name='char')
        for offset in range(num_samples):
            doubled = chr(start_char + offset) + chr(start_char + offset)
            self.assertEqual(True, vocab._is_word_no_create_entry(doubled))

    def test_no_entry(self):
        # Build the vocabulary first, then flip no_create_entry back and
        # forth and make sure the flag tracks the latest state correctly.
        words = ("FastNLP works well in most cases and scales well in "
                 "works well in most cases scales well").split()
        vocab = Vocabulary()
        vocab.add_word_lst(words)

        # Already a real entry: a later no_create_entry add cannot demote it.
        self.assertFalse(vocab._is_word_no_create_entry('FastNLP'))
        vocab.add_word('FastNLP', no_create_entry=True)
        self.assertFalse(vocab._is_word_no_create_entry('FastNLP'))

        # A brand-new word added with no_create_entry=True is flagged,
        # and a later real add clears the flag.
        vocab.add_word('fastnlp', no_create_entry=True)
        self.assertTrue(vocab._is_word_no_create_entry('fastnlp'))
        vocab.add_word('fastnlp', no_create_entry=False)
        self.assertFalse(vocab._is_word_no_create_entry('fastnlp'))

        # Same promotion via the plain add_word()/add_word_lst() pair.
        vocab.add_word_lst(['1'] * 10, no_create_entry=True)
        self.assertTrue(vocab._is_word_no_create_entry('1'))
        vocab.add_word('1')
        self.assertFalse(vocab._is_word_no_create_entry('1'))
-
-
class TestIndexing(unittest.TestCase):
    """Covers lookup, containment, iteration and rebuild behaviour."""

    def test_len(self):
        # Without reserved padding/unknown tokens, the vocabulary size
        # equals the number of distinct words.
        vocab = Vocabulary(unknown=None, padding=None)
        vocab.update(text)
        self.assertEqual(len(counter), len(vocab))

    def test_contains(self):
        vocab = Vocabulary(unknown=None)
        vocab.update(text)
        known = text[-1]
        self.assertIn(known, vocab)
        self.assertNotIn("~!@#", vocab)
        # The `in` operator and has_word() must agree.
        self.assertEqual(vocab.has_word(known), known in vocab)
        self.assertEqual(vocab.has_word("~!@#"), "~!@#" in vocab)

    def test_index(self):
        # Every distinct word maps to a distinct index, via both
        # __getitem__ and to_index.
        vocab = Vocabulary()
        vocab.update(text)
        unique_words = set(text)
        for lookup in (vocab.__getitem__, vocab.to_index):
            indices = [lookup(word) for word in unique_words]
            self.assertEqual(len(set(indices)), len(indices))

    def test_to_word(self):
        # to_word() must invert the word -> index mapping exactly.
        vocab = Vocabulary()
        vocab.update(text)
        roundtrip = [vocab.to_word(vocab[word]) for word in text]
        self.assertEqual(text, roundtrip)

    def test_iteration(self):
        # Iterating the vocabulary yields (word, index) pairs covering
        # only known words, with in-range indices.
        vocab = Vocabulary(padding=None, unknown=None)
        words = ("FastNLP works well in most cases and scales well in "
                 "works well in most cases scales well").split()
        vocab.update(words)
        unique = set(words)
        for word, idx in vocab:
            self.assertIn(word, unique)
            self.assertLess(idx, len(vocab))

    def test_rebuild(self):
        # After the vocab is built, adding new words must keep the
        # indices of all existing words stable.
        vocab = Vocabulary()
        words = [str(n) for n in range(10)]
        vocab.update(words)
        for word in words:
            self.assertEqual(int(word) + 2, vocab.to_index(word))
        snapshot = [(word, index) for word, index in vocab]
        vocab.add_word_lst([str(n) for n in range(10, 13)])
        for word, index in snapshot:
            self.assertEqual(index, vocab.to_index(word))
        for n in range(13):
            self.assertEqual(n + 2, vocab.to_index(str(n)))
-
class TestOther(unittest.TestCase):
    """Exercises the lazy-rebuild flag of Vocabulary."""

    def test_additional_update(self):
        vocab = Vocabulary()
        vocab.update(text)

        # A lookup forces a build, clearing the rebuild flag ...
        _ = vocab["well"]
        self.assertEqual(False, vocab.rebuild)

        # ... adding a word sets it again ...
        vocab.add("hahaha")
        self.assertEqual(True, vocab.rebuild)

        # ... and the next lookup clears it once more.
        _ = vocab["hahaha"]
        self.assertEqual(False, vocab.rebuild)
        self.assertIn("hahaha", vocab)

    def test_warning(self):
        # Fill the vocabulary exactly to max_size, then overflow it: the
        # overflowing update should leave the rebuild flag set (and emit
        # a warning).
        vocab = Vocabulary(max_size=len(set(text)))
        vocab.update(text)
        self.assertEqual(True, vocab.rebuild)
        print(len(vocab))  # len() triggers a build, clearing the flag
        self.assertEqual(False, vocab.rebuild)

        overflow = ["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong",
                    "eqgfeg", "feqfw"]
        vocab.update(overflow)
        # this will print a warning
        self.assertEqual(True, vocab.rebuild)
|