| @@ -36,9 +36,10 @@ class Vocabulary(object): | |||
| self.update(w) | |||
| else: | |||
| # it's a word to be added | |||
| self.word2idx[word] = len(self) | |||
| if self.idx2word is not None: | |||
| self.idx2word = None | |||
| if word not in self.word2idx: | |||
| self.word2idx[word] = len(self) | |||
| if self.idx2word is not None: | |||
| self.idx2word = None | |||
| def __getitem__(self, w): | |||
| @@ -80,20 +81,5 @@ class Vocabulary(object): | |||
| self.__dict__.update(state) | |||
| self.idx2word = None | |||
if __name__ == '__main__':
    # Smoke test: build a small vocabulary, round-trip it through pickle,
    # and show that word/index lookups survive serialization.
    import _pickle as pickle

    demo_word = 'vocab'  # doubles as the lookup word and the dump file name
    v = Vocabulary()
    v.update(demo_word)
    # update() accepts arbitrarily nested word lists
    v.update([demo_word, ['a'], [['b']], ['c']])
    demo_idx = v[demo_word]
    print('{} {}'.format(v.to_word(demo_idx), v[demo_word]))
    with open(demo_word, 'wb') as dump_file:
        pickle.dump(v, dump_file)
    with open(demo_word, 'rb') as load_file:
        v = pickle.load(load_file)
    # lookups after unpickling should match the ones before
    print('{} {}'.format(v.to_word(demo_idx), v[demo_word]))
    print(v.word2idx)
| @@ -0,0 +1,69 @@ | |||
| import os | |||
| import sys | |||
| sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
| import unittest | |||
| import torch | |||
| from fastNLP.data.field import TextField, LabelField | |||
| from fastNLP.data.instance import Instance | |||
| from fastNLP.data.dataset import DataSet | |||
| from fastNLP.data.batch import Batch | |||
| class TestField(unittest.TestCase): | |||
| def check_batched_data_equal(self, data1, data2): | |||
| self.assertEqual(len(data1), len(data2)) | |||
| for i in range(len(data1)): | |||
| self.assertTrue(data1[i].keys(), data2[i].keys()) | |||
| for i in range(len(data1)): | |||
| for t1, t2 in zip(data1[i].values(), data2[i].values()): | |||
| self.assertTrue(torch.equal(t1, t2)) | |||
| def test_batchiter(self): | |||
| texts = [ | |||
| "i am a cat", | |||
| "this is a test of new batch", | |||
| "haha" | |||
| ] | |||
| labels = [0, 1, 0] | |||
| # prepare vocabulary | |||
| vocab = {} | |||
| for text in texts: | |||
| for tokens in text.split(): | |||
| if tokens not in vocab: | |||
| vocab[tokens] = len(vocab) | |||
| # prepare input dataset | |||
| data = DataSet() | |||
| for text, label in zip(texts, labels): | |||
| x = TextField(text.split(), False) | |||
| y = LabelField(label, is_target=True) | |||
| ins = Instance(text=x, label=y) | |||
| data.append(ins) | |||
| # use vocabulary to index data | |||
| data.index_field("text", vocab) | |||
| # define naive sampler for batch class | |||
| class SeqSampler: | |||
| def __call__(self, dataset): | |||
| return list(range(len(dataset))) | |||
| # use bacth to iterate dataset | |||
| batcher = Batch(data, SeqSampler(), 2) | |||
| TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}] | |||
| TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}] | |||
| for epoch in range(3): | |||
| test_x, test_y = [], [] | |||
| for batch_x, batch_y in batcher: | |||
| test_x.append(batch_x) | |||
| test_y.append(batch_y) | |||
| self.check_batched_data_equal(TRUE_X, test_x) | |||
| self.check_batched_data_equal(TRUE_Y, test_y) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -0,0 +1,35 @@ | |||
| import os | |||
| import sys | |||
| sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
| import unittest | |||
| from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||
class TestVocabulary(unittest.TestCase):
    """Tests for Vocabulary indexing, pickling, and reverse-vocab building."""

    def test_vocab(self):
        """Round-trip a Vocabulary through pickle and verify both the
        word->index and index->word mappings are preserved.
        """
        # Use the public pickle facade instead of the private _pickle
        # C module; it delegates to the same implementation.
        import pickle

        vocab = Vocabulary()
        # 'vocab' doubles as a vocabulary word and as the dump file name
        filename = 'vocab'
        vocab.update(filename)
        # update() should flatten arbitrarily nested lists of words
        vocab.update([filename, ['a'], [['b']], ['c']])
        idx = vocab[filename]
        before_pic = (vocab.to_word(idx), vocab[filename])

        try:
            with open(filename, 'wb') as f:
                pickle.dump(vocab, f)
            with open(filename, 'rb') as f:
                vocab = pickle.load(f)
        finally:
            # clean up the dump file even if (un)pickling fails
            if os.path.exists(filename):
                os.remove(filename)

        vocab.build_reverse_vocab()
        after_pic = (vocab.to_word(idx), vocab[filename])

        # user words start at index 5, after the reserved entries
        TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
        TRUE_DICT.update(DEFAULT_WORD_TO_INDEX)
        TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}
        self.assertEqual(before_pic, after_pic)
        self.assertDictEqual(TRUE_DICT, vocab.word2idx)
        self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word)


if __name__ == '__main__':
    unittest.main()