@@ -36,9 +36,10 @@ class Vocabulary(object): | |||||
self.update(w) | self.update(w) | ||||
else: | else: | ||||
# it's a word to be added | # it's a word to be added | ||||
self.word2idx[word] = len(self) | |||||
if self.idx2word is not None: | |||||
self.idx2word = None | |||||
if word not in self.word2idx: | |||||
self.word2idx[word] = len(self) | |||||
if self.idx2word is not None: | |||||
self.idx2word = None | |||||
def __getitem__(self, w): | def __getitem__(self, w): | ||||
@@ -80,20 +81,5 @@ class Vocabulary(object): | |||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
self.idx2word = None | self.idx2word = None | ||||
if __name__ == '__main__':
    # Smoke test for Vocabulary: build a small vocab, look a word up,
    # round-trip it through pickle, and show the mapping is preserved.
    import _pickle as pickle

    fname = 'vocab'
    voc = Vocabulary()
    voc.update(fname)
    # update() accepts arbitrarily nested lists of words as well
    voc.update([fname, ['a'], [['b']], ['c']])
    word_idx = voc[fname]
    print('{} {}'.format(voc.to_word(word_idx), voc[fname]))

    # serialize to a file named 'vocab', then reload
    with open(fname, 'wb') as dump_f:
        pickle.dump(voc, dump_f)
    with open(fname, 'rb') as load_f:
        voc = pickle.load(load_f)

    # same index -> same word after the round-trip
    print('{} {}'.format(voc.to_word(word_idx), voc[fname]))
    print(voc.word2idx)
@@ -0,0 +1,69 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
import unittest | |||||
import torch | |||||
from fastNLP.data.field import TextField, LabelField | |||||
from fastNLP.data.instance import Instance | |||||
from fastNLP.data.dataset import DataSet | |||||
from fastNLP.data.batch import Batch | |||||
class TestField(unittest.TestCase): | |||||
def check_batched_data_equal(self, data1, data2): | |||||
self.assertEqual(len(data1), len(data2)) | |||||
for i in range(len(data1)): | |||||
self.assertTrue(data1[i].keys(), data2[i].keys()) | |||||
for i in range(len(data1)): | |||||
for t1, t2 in zip(data1[i].values(), data2[i].values()): | |||||
self.assertTrue(torch.equal(t1, t2)) | |||||
def test_batchiter(self): | |||||
texts = [ | |||||
"i am a cat", | |||||
"this is a test of new batch", | |||||
"haha" | |||||
] | |||||
labels = [0, 1, 0] | |||||
# prepare vocabulary | |||||
vocab = {} | |||||
for text in texts: | |||||
for tokens in text.split(): | |||||
if tokens not in vocab: | |||||
vocab[tokens] = len(vocab) | |||||
# prepare input dataset | |||||
data = DataSet() | |||||
for text, label in zip(texts, labels): | |||||
x = TextField(text.split(), False) | |||||
y = LabelField(label, is_target=True) | |||||
ins = Instance(text=x, label=y) | |||||
data.append(ins) | |||||
# use vocabulary to index data | |||||
data.index_field("text", vocab) | |||||
# define naive sampler for batch class | |||||
class SeqSampler: | |||||
def __call__(self, dataset): | |||||
return list(range(len(dataset))) | |||||
# use bacth to iterate dataset | |||||
batcher = Batch(data, SeqSampler(), 2) | |||||
TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}] | |||||
TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}] | |||||
for epoch in range(3): | |||||
test_x, test_y = [], [] | |||||
for batch_x, batch_y in batcher: | |||||
test_x.append(batch_x) | |||||
test_y.append(batch_y) | |||||
self.check_batched_data_equal(TRUE_X, test_x) | |||||
self.check_batched_data_equal(TRUE_Y, test_y) | |||||
if __name__ == "__main__": | |||||
unittest.main() | |||||
@@ -0,0 +1,35 @@ | |||||
import os | |||||
import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
import unittest | |||||
from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||||
class TestVocabulary(unittest.TestCase):
    """Tests for Vocabulary construction and pickle serialization."""

    def test_vocab(self):
        """A word's index and mapping survive a pickle round-trip."""
        import _pickle as pickle
        import os

        fname = 'vocab'
        voc = Vocabulary()
        voc.update(fname)
        voc.update([fname, ['a'], [['b']], ['c']])  # nested input is flattened
        word_idx = voc[fname]
        before_pic = (voc.to_word(word_idx), voc[fname])

        # round-trip through pickle using a temp file named 'vocab'
        with open(fname, 'wb') as dump_f:
            pickle.dump(voc, dump_f)
        with open(fname, 'rb') as load_f:
            voc = pickle.load(load_f)
        os.remove(fname)

        voc.build_reverse_vocab()
        after_pic = (voc.to_word(word_idx), voc[fname])
        self.assertEqual(before_pic, after_pic)

        # user words start at index 5, after the reserved entries
        expected_w2i = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
        expected_w2i.update(DEFAULT_WORD_TO_INDEX)
        expected_i2w = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}
        self.assertDictEqual(expected_w2i, voc.word2idx)
        self.assertDictEqual(expected_i2w, voc.idx2word)


if __name__ == '__main__':
    unittest.main()