@@ -36,9 +36,10 @@ class Vocabulary(object): | |||
self.update(w) | |||
else: | |||
# it's a word to be added | |||
self.word2idx[word] = len(self) | |||
if self.idx2word is not None: | |||
self.idx2word = None | |||
if word not in self.word2idx: | |||
self.word2idx[word] = len(self) | |||
if self.idx2word is not None: | |||
self.idx2word = None | |||
def __getitem__(self, w): | |||
@@ -80,20 +81,5 @@ class Vocabulary(object): | |||
self.__dict__.update(state) | |||
self.idx2word = None | |||
if __name__ == '__main__':
    # Smoke test for Vocabulary: add words (including nested lists),
    # round-trip through pickle, and confirm indices survive serialization.
    import _pickle as pickle
    import os

    vocab = Vocabulary()
    # NOTE: 'filename' doubles as both a word added to the vocab and the
    # path of the temporary pickle file.
    filename = 'vocab'
    vocab.update(filename)
    vocab.update([filename, ['a'], [['b']], ['c']])
    idx = vocab[filename]
    print('{} {}'.format(vocab.to_word(idx), vocab[filename]))
    with open(filename, 'wb') as f:
        pickle.dump(vocab, f)
    with open(filename, 'rb') as f:
        vocab = pickle.load(f)
    # FIX: remove the temporary pickle file — the original leaked a file
    # named 'vocab' in the working directory on every run.
    os.remove(filename)
    print('{} {}'.format(vocab.to_word(idx), vocab[filename]))
    print(vocab.word2idx)
@@ -0,0 +1,69 @@ | |||
import os | |||
import sys | |||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
import unittest | |||
import torch | |||
from fastNLP.data.field import TextField, LabelField | |||
from fastNLP.data.instance import Instance | |||
from fastNLP.data.dataset import DataSet | |||
from fastNLP.data.batch import Batch | |||
class TestField(unittest.TestCase): | |||
def check_batched_data_equal(self, data1, data2): | |||
self.assertEqual(len(data1), len(data2)) | |||
for i in range(len(data1)): | |||
self.assertTrue(data1[i].keys(), data2[i].keys()) | |||
for i in range(len(data1)): | |||
for t1, t2 in zip(data1[i].values(), data2[i].values()): | |||
self.assertTrue(torch.equal(t1, t2)) | |||
def test_batchiter(self): | |||
texts = [ | |||
"i am a cat", | |||
"this is a test of new batch", | |||
"haha" | |||
] | |||
labels = [0, 1, 0] | |||
# prepare vocabulary | |||
vocab = {} | |||
for text in texts: | |||
for tokens in text.split(): | |||
if tokens not in vocab: | |||
vocab[tokens] = len(vocab) | |||
# prepare input dataset | |||
data = DataSet() | |||
for text, label in zip(texts, labels): | |||
x = TextField(text.split(), False) | |||
y = LabelField(label, is_target=True) | |||
ins = Instance(text=x, label=y) | |||
data.append(ins) | |||
# use vocabulary to index data | |||
data.index_field("text", vocab) | |||
# define naive sampler for batch class | |||
class SeqSampler: | |||
def __call__(self, dataset): | |||
return list(range(len(dataset))) | |||
# use bacth to iterate dataset | |||
batcher = Batch(data, SeqSampler(), 2) | |||
TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}] | |||
TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}] | |||
for epoch in range(3): | |||
test_x, test_y = [], [] | |||
for batch_x, batch_y in batcher: | |||
test_x.append(batch_x) | |||
test_y.append(batch_y) | |||
self.check_batched_data_equal(TRUE_X, test_x) | |||
self.check_batched_data_equal(TRUE_Y, test_y) | |||
# Allow running this test module directly, outside a test runner.
if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,35 @@ | |||
import os | |||
import sys | |||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
import unittest | |||
from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||
class TestVocabulary(unittest.TestCase):
    """Checks that a Vocabulary survives a pickle round trip intact."""

    def test_vocab(self):
        import _pickle as pickle
        import os

        voc = Vocabulary()
        # This string doubles as a word added to the vocab and as the
        # path of the temporary pickle file written below.
        filename = 'vocab'
        voc.update(filename)
        # update() flattens nested lists; the repeated word is a no-op.
        voc.update([filename, ['a'], [['b']], ['c']])
        idx = voc[filename]
        snapshot = (voc.to_word(idx), voc[filename])

        # Serialize, reload, and clean up the temporary file.
        with open(filename, 'wb') as out:
            pickle.dump(voc, out)
        with open(filename, 'rb') as inp:
            voc = pickle.load(inp)
        os.remove(filename)

        voc.build_reverse_vocab()
        restored = (voc.to_word(idx), voc[filename])

        # Expected word->index map: the four added words after the defaults.
        expected_w2i = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
        expected_w2i.update(DEFAULT_WORD_TO_INDEX)
        expected_i2w = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}

        self.assertEqual(snapshot, restored)
        self.assertDictEqual(expected_w2i, voc.word2idx)
        self.assertDictEqual(expected_i2w, voc.idx2word)
# Allow running this test module directly, outside a test runner.
if __name__ == '__main__':
    unittest.main()