From 466f3c21ec7c07de98ae9979f2261e921260867a Mon Sep 17 00:00:00 2001 From: yunfan Date: Sat, 15 Sep 2018 19:53:48 +0800 Subject: [PATCH 1/5] add vocabulary --- fastNLP/data/vocabulary.py | 99 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 fastNLP/data/vocabulary.py diff --git a/fastNLP/data/vocabulary.py b/fastNLP/data/vocabulary.py new file mode 100644 index 00000000..3cff161b --- /dev/null +++ b/fastNLP/data/vocabulary.py @@ -0,0 +1,99 @@ +from copy import deepcopy + +DEFAULT_PADDING_LABEL = '' # dict index = 0 +DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 +DEFAULT_RESERVED_LABEL = ['', + '', + ''] # dict index = 2~4 + +DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, + DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, + DEFAULT_RESERVED_LABEL[2]: 4} + +def isiterable(p_object): + try: + it = iter(p_object) + except TypeError: + return False + return True + +class Vocabulary(object): + def __init__(self): + self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) + self.padding_label = DEFAULT_PADDING_LABEL + self.unknown_label = DEFAULT_UNKNOWN_LABEL + self.idx2word = None + + def __len__(self): + return len(self.word2idx) + + def update(self, word): + """add word or list of words into Vocabulary + """ + if not isinstance(word, str) and isiterable(word): + # it's a nested list + for w in word: + self.update(w) + else: + # it's a word to be added + self.word2idx[word] = len(self) + if self.idx2word is not None: + self.idx2word = None + + + def __getitem__(self, w): + """ like to_index(w) function, turn a word to the index + if w is not in Vocabulary, return the unknown label + """ + if w in self.word2idx: + return self.word2idx[w] + else: + return self.word2idx[DEFAULT_UNKNOWN_LABEL] + + def unknown_idx(self): + return self.word2idx[self.unknown_label] + + def padding_idx(self): + return self.word2idx[self.padding_label] + + def build_reverse_vocab(self): + self.idx2word = {self.word2idx[w] : w for w in self.word2idx} + + def to_word(self, idx): + """given a word's index, return the word itself + """ + if self.idx2word is None: + self.build_reverse_vocab() + return self.idx2word[idx] + + def __getstate__(self): + """use to prepare data for pickle + """ + state = self.__dict__.copy() + # no need to pickle idx2word as it can be constructed from word2idx + del state['idx2word'] + return state + + def __setstate__(self, state): + """use to restore state from pickle + """ + self.__dict__.update(state) + self.idx2word = None + +if __name__ == '__main__': + import _pickle as pickle + vocab = Vocabulary() + filename = 'vocab' + vocab.update(filename) + vocab.update([filename, ['a'], [['b']], ['c']]) + idx = vocab[filename] + print('{} {}'.format(vocab.to_word(idx), vocab[filename])) + + with open(filename, 'wb') as f: + pickle.dump(vocab, f) + with open(filename, 'rb') as f: + vocab = pickle.load(f) + + print('{} {}'.format(vocab.to_word(idx), vocab[filename])) + print(vocab.word2idx) + \ No newline at end of file From 3f4544759ddfd4569be034de811b366f1a6bb3cf Mon Sep 17 00:00:00 2001 From: yunfan Date: Sat, 15 Sep 2018 20:39:51 +0800 Subject: [PATCH 2/5] add unittest of data, fix bug --- fastNLP/{data => core}/vocabulary.py | 22 ++------- test/core/test_field.py | 69 ++++++++++++++++++++++++++++ test/core/test_vocab.py | 35 ++++++++++++++ 3 files changed, 108 insertions(+), 18 deletions(-) rename fastNLP/{data => core}/vocabulary.py (79%) create mode 100644 test/core/test_field.py create mode 100644 test/core/test_vocab.py diff --git a/fastNLP/data/vocabulary.py b/fastNLP/core/vocabulary.py similarity index 79% rename from fastNLP/data/vocabulary.py rename to fastNLP/core/vocabulary.py index 3cff161b..baae3753 100644 --- a/fastNLP/data/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -36,9 +36,10 @@ class Vocabulary(object): self.update(w) else: # it's a word to be added - self.word2idx[word] = len(self) - if self.idx2word is not None: - self.idx2word = None + if word not in self.word2idx: + self.word2idx[word] = len(self) + if self.idx2word is not None: + self.idx2word = None def __getitem__(self, w): @@ -80,20 +81,5 @@ class Vocabulary(object): self.__dict__.update(state) self.idx2word = None -if __name__ == '__main__': - import _pickle as pickle - vocab = Vocabulary() - filename = 'vocab' - vocab.update(filename) - vocab.update([filename, ['a'], [['b']], ['c']]) - idx = vocab[filename] - print('{} {}'.format(vocab.to_word(idx), vocab[filename])) - with open(filename, 'wb') as f: - pickle.dump(vocab, f) - with open(filename, 'rb') as f: - vocab = pickle.load(f) - - print('{} {}'.format(vocab.to_word(idx), vocab[filename])) - print(vocab.word2idx) \ No newline at end of file diff --git a/test/core/test_field.py b/test/core/test_field.py new file mode 100644 index 00000000..7c1b6343 --- /dev/null +++ b/test/core/test_field.py @@ -0,0 +1,69 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) + +import unittest +import torch +from fastNLP.data.field import TextField, LabelField +from fastNLP.data.instance import Instance +from fastNLP.data.dataset import DataSet +from fastNLP.data.batch import Batch + + + +class TestField(unittest.TestCase): + def check_batched_data_equal(self, data1, data2): + self.assertEqual(len(data1), len(data2)) + for i in range(len(data1)): + self.assertTrue(data1[i].keys(), data2[i].keys()) + for i in range(len(data1)): + for t1, t2 in zip(data1[i].values(), data2[i].values()): + self.assertTrue(torch.equal(t1, t2)) + + def test_batchiter(self): + texts = [ + "i am a cat", + "this is a test of new batch", + "haha" + ] + labels = [0, 1, 0] + + # prepare vocabulary + vocab = {} + for text in texts: + for tokens in text.split(): + if tokens not in vocab: + vocab[tokens] = len(vocab) + + # prepare input dataset + data = DataSet() + for text, label in zip(texts, labels): + x = TextField(text.split(), False) + y = LabelField(label, is_target=True) + ins = Instance(text=x, label=y) + data.append(ins) + + # use vocabulary to index data + data.index_field("text", vocab) + + # define naive sampler for batch class + class SeqSampler: + def __call__(self, dataset): + return list(range(len(dataset))) + + # use bacth to iterate dataset + batcher = Batch(data, SeqSampler(), 2) + TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}] + TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}] + for epoch in range(3): + test_x, test_y = [], [] + for batch_x, batch_y in batcher: + test_x.append(batch_x) + test_y.append(batch_y) + self.check_batched_data_equal(TRUE_X, test_x) + self.check_batched_data_equal(TRUE_Y, test_y) + + +if __name__ == "__main__": + unittest.main() + \ No newline at end of file diff --git a/test/core/test_vocab.py b/test/core/test_vocab.py new file mode 100644 index 00000000..dd51c197 --- /dev/null +++ b/test/core/test_vocab.py @@ -0,0 +1,35 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) + +import unittest +from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX + +class TestVocabulary(unittest.TestCase): + def test_vocab(self): + import _pickle as pickle + import os + vocab = Vocabulary() + filename = 'vocab' + vocab.update(filename) + vocab.update([filename, ['a'], [['b']], ['c']]) + idx = vocab[filename] + before_pic = (vocab.to_word(idx), vocab[filename]) + + with open(filename, 'wb') as f: + pickle.dump(vocab, f) + with open(filename, 'rb') as f: + vocab = pickle.load(f) + os.remove(filename) + + vocab.build_reverse_vocab() + after_pic = (vocab.to_word(idx), vocab[filename]) + TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} + TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) + TRUE_IDXDICT = {0: '', 1: '', 2: '', 3: '', 4: '', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} + self.assertEqual(before_pic, after_pic) + self.assertDictEqual(TRUE_DICT, vocab.word2idx) + self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 9c7f3cf26125234a785d9383839df1ee6779905f Mon Sep 17 00:00:00 2001 From: yunfan Date: Tue, 18 Sep 2018 16:43:56 +0800 Subject: [PATCH 3/5] add vocabulary into preprocessor --- fastNLP/core/preprocess.py | 80 ++++++++++++-------------------------- fastNLP/core/vocabulary.py | 29 +++++++++++++- test/core/test_field.py | 69 -------------------------------- test/core/test_vocab.py | 6 +-- 4 files changed, 53 insertions(+), 131 deletions(-) delete mode 100644 test/core/test_field.py diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py index b5d348e6..2671b4f4 100644 --- a/fastNLP/core/preprocess.py +++ b/fastNLP/core/preprocess.py @@ -6,16 +6,7 @@ import numpy as np from fastNLP.core.dataset import DataSet from fastNLP.core.field import TextField, LabelField from fastNLP.core.instance import Instance - -DEFAULT_PADDING_LABEL = '' # dict index = 0 -DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 -DEFAULT_RESERVED_LABEL = ['', - '', - ''] # dict index = 2~4 - -DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, - DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, - DEFAULT_RESERVED_LABEL[2]: 4} +from fastNLP.core.vocabulary import Vocabulary # the first vocab in dict with the index = 5 @@ -68,24 +59,22 @@ class BasePreprocess(object): - "word2id.pkl", a mapping from words(tokens) to indices - "id2word.pkl", a reversed dictionary - - "label2id.pkl", a dictionary on labels - - "id2label.pkl", a reversed dictionary on labels These four pickle files are expected to be saved in the given pickle directory once they are constructed. Preprocessors will check if those files are already in the directory and will reuse them in future calls. """ def __init__(self): - self.word2index = None - self.label2index = None + self.data_vocab = Vocabulary() + self.label_vocab = Vocabulary() @property def vocab_size(self): - return len(self.word2index) + return len(self.data_vocab) @property def num_classes(self): - return len(self.label2index) + return len(self.label_vocab) def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): """Main pre-processing pipeline. @@ -102,20 +91,14 @@ class BasePreprocess(object): """ if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): - self.word2index = load_pickle(pickle_path, "word2id.pkl") - self.label2index = load_pickle(pickle_path, "class2id.pkl") + self.data_vocab = load_pickle(pickle_path, "word2id.pkl") + self.label_vocab = load_pickle(pickle_path, "class2id.pkl") else: - self.word2index, self.label2index = self.build_dict(train_dev_data) - save_pickle(self.word2index, pickle_path, "word2id.pkl") - save_pickle(self.label2index, pickle_path, "class2id.pkl") - - if not pickle_exist(pickle_path, "id2word.pkl"): - index2word = self.build_reverse_dict(self.word2index) - save_pickle(index2word, pickle_path, "id2word.pkl") + self.data_vocab, self.label_vocab = self.build_dict(train_dev_data) + save_pickle(self.data_vocab, pickle_path, "word2id.pkl") + save_pickle(self.label_vocab, pickle_path, "class2id.pkl") - if not pickle_exist(pickle_path, "id2class.pkl"): - index2label = self.build_reverse_dict(self.label2index) - save_pickle(index2label, pickle_path, "id2class.pkl") + self.build_reverse_dict() train_set = [] dev_set = [] @@ -125,13 +108,13 @@ class BasePreprocess(object): split = int(len(train_dev_data) * train_dev_split) data_dev = train_dev_data[: split] data_train = train_dev_data[split:] - train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index) - dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index) + train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab) + dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab) save_pickle(dev_set, pickle_path, "data_dev.pkl") print("{} of the training data is split for validation. ".format(train_dev_split)) else: - train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index) + train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab) save_pickle(train_set, pickle_path, "data_train.pkl") else: train_set = load_pickle(pickle_path, "data_train.pkl") @@ -143,8 +126,8 @@ class BasePreprocess(object): # cross validation data_cv = self.cv_split(train_dev_data, n_fold) for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): - data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index) - data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index) + data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab) + data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab) save_pickle( data_train_cv, pickle_path, "data_train_{}.pkl".format(i)) @@ -165,7 +148,7 @@ class BasePreprocess(object): test_set = [] if test_data is not None: if not pickle_exist(pickle_path, "data_test.pkl"): - test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index) + test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab) save_pickle(test_set, pickle_path, "data_test.pkl") # return preprocessed results @@ -180,28 +163,15 @@ class BasePreprocess(object): return tuple(results) def build_dict(self, data): - label2index = DEFAULT_WORD_TO_INDEX.copy() - word2index = DEFAULT_WORD_TO_INDEX.copy() for example in data: - for word in example[0]: - if word not in word2index: - word2index[word] = len(word2index) - label = example[1] - if isinstance(label, str): - # label is a string - if label not in label2index: - label2index[label] = len(label2index) - elif isinstance(label, list): - # label is a list of strings - for single_label in label: - if single_label not in label2index: - label2index[single_label] = len(label2index) - return word2index, label2index - - - def build_reverse_dict(self, word_dict): - id2word = {word_dict[w]: w for w in word_dict} - return id2word + word, label = example + self.data_vocab.update(word) + self.label_vocab.update(label) + return self.data_vocab, self.label_vocab + + def build_reverse_dict(self): + self.data_vocab.build_reverse_vocab() + self.label_vocab.build_reverse_vocab() def data_split(self, data, train_dev_split): """Split data into train and dev set.""" diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index baae3753..79b70939 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -18,6 +18,16 @@ def isiterable(p_object): return True class Vocabulary(object): + """Use for word and index one to one mapping + + Example:: + + vocab = Vocabulary() + word_list = "this is a word list".split() + vocab.update(word_list) + vocab["word"] + vocab.to_word(5) + """ def __init__(self): self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) self.padding_label = DEFAULT_PADDING_LABEL @@ -29,6 +39,8 @@ class Vocabulary(object): def update(self, word): """add word or list of words into Vocabulary + + :param word: a list of str or str """ if not isinstance(word, str) and isiterable(word): # it's a nested list @@ -43,13 +55,22 @@ class Vocabulary(object): def __getitem__(self, w): - """ like to_index(w) function, turn a word to the index - if w is not in Vocabulary, return the unknown label + """To support usage like:: + + vocab[w] """ if w in self.word2idx: return self.word2idx[w] else: return self.word2idx[DEFAULT_UNKNOWN_LABEL] + + def to_index(self, w): + """ like to_index(w) function, turn a word to the index + if w is not in Vocabulary, return the unknown label + + :param str w: + """ + return self[w] def unknown_idx(self): return self.word2idx[self.unknown_label] @@ -58,10 +79,14 @@ class Vocabulary(object): return self.word2idx[self.padding_label] def build_reverse_vocab(self): + """build 'index to word' dict based on 'word to index' dict + """ self.idx2word = {self.word2idx[w] : w for w in self.word2idx} def to_word(self, idx): """given a word's index, return the word itself + + :param int idx: """ if self.idx2word is None: self.build_reverse_vocab() diff --git a/test/core/test_field.py b/test/core/test_field.py deleted file mode 100644 index 7c1b6343..00000000 --- a/test/core/test_field.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) - -import unittest -import torch -from fastNLP.data.field import TextField, LabelField -from fastNLP.data.instance import Instance -from fastNLP.data.dataset import DataSet -from fastNLP.data.batch import Batch - - - -class TestField(unittest.TestCase): - def check_batched_data_equal(self, data1, data2): - self.assertEqual(len(data1), len(data2)) - for i in range(len(data1)): - self.assertTrue(data1[i].keys(), data2[i].keys()) - for i in range(len(data1)): - for t1, t2 in zip(data1[i].values(), data2[i].values()): - self.assertTrue(torch.equal(t1, t2)) - - def test_batchiter(self): - texts = [ - "i am a cat", - "this is a test of new batch", - "haha" - ] - labels = [0, 1, 0] - - # prepare vocabulary - vocab = {} - for text in texts: - for tokens in text.split(): - if tokens not in vocab: - vocab[tokens] = len(vocab) - - # prepare input dataset - data = DataSet() - for text, label in zip(texts, labels): - x = TextField(text.split(), False) - y = LabelField(label, is_target=True) - ins = Instance(text=x, label=y) - data.append(ins) - - # use vocabulary to index data - data.index_field("text", vocab) - - # define naive sampler for batch class - class SeqSampler: - def __call__(self, dataset): - return list(range(len(dataset))) - - # use bacth to iterate dataset - batcher = Batch(data, SeqSampler(), 2) - TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}] - TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}] - for epoch in range(3): - test_x, test_y = [], [] - for batch_x, batch_y in batcher: - test_x.append(batch_x) - test_y.append(batch_y) - self.check_batched_data_equal(TRUE_X, test_x) - self.check_batched_data_equal(TRUE_Y, test_y) - - -if __name__ == "__main__": - unittest.main() - \ No newline at end of file diff --git a/test/core/test_vocab.py b/test/core/test_vocab.py index dd51c197..89b0691a 100644 --- a/test/core/test_vocab.py +++ b/test/core/test_vocab.py @@ -1,9 +1,5 @@ -import os -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) - import unittest -from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX +from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX class TestVocabulary(unittest.TestCase): def test_vocab(self): From 819c8f05bed3d47d1c85ae4e44643ab24432f240 Mon Sep 17 00:00:00 2001 From: yunfan Date: Wed, 19 Sep 2018 14:49:10 +0800 Subject: [PATCH 4/5] fix vocab --- fastNLP/core/predictor.py | 10 +++++----- fastNLP/fastnlp.py | 10 +++++----- reproduction/chinese_word_segment/run.py | 4 ++-- reproduction/pos_tag_model/train_pos_tag.py | 4 ++-- test/core/test_predictor.py | 9 +++++++-- test/model/seq_labeling.py | 2 +- test/model/test_cws.py | 2 +- 7 files changed, 23 insertions(+), 18 deletions(-) diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index 802661ef..c83b2069 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -27,8 +27,8 @@ class Predictor(object): self.batch_output = [] self.pickle_path = pickle_path self._task = task # one of ("seq_label", "text_classify") - self.index2label = load_pickle(self.pickle_path, "id2class.pkl") - self.word2index = load_pickle(self.pickle_path, "word2id.pkl") + self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl") + self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl") def predict(self, network, data): """Perform inference using the trained model. @@ -82,7 +82,7 @@ class Predictor(object): :return data_set: a DataSet instance. """ assert isinstance(data, list) - return create_dataset_from_lists(data, self.word2index, has_target=False) + return create_dataset_from_lists(data, self.word_vocab, has_target=False) def prepare_output(self, data): """Transform list of batch outputs into strings.""" @@ -97,14 +97,14 @@ class Predictor(object): results = [] for batch in batch_outputs: for example in np.array(batch): - results.append([self.index2label[int(x)] for x in example]) + results.append([self.label_vocab.to_word(int(x)) for x in example]) return results def _text_classify_prepare_output(self, batch_outputs): results = [] for batch_out in batch_outputs: idx = np.argmax(batch_out.detach().numpy(), axis=-1) - results.extend([self.index2label[i] for i in idx]) + results.extend([self.label_vocab.to_word(i) for i in idx]) return results diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py index c76e6681..e683950d 100644 --- a/fastNLP/fastnlp.py +++ b/fastNLP/fastnlp.py @@ -69,7 +69,7 @@ class FastNLP(object): :param model_dir: this directory should contain the following files: 1. a pre-trained model 2. a config file - 3. "id2class.pkl" + 3. "class2id.pkl" 4. "word2id.pkl" """ self.model_dir = model_dir @@ -99,10 +99,10 @@ class FastNLP(object): print("Restore model hyper-parameters {}".format(str(model_args.data))) # fetch dictionary size and number of labels from pickle files - word2index = load_pickle(self.model_dir, "word2id.pkl") - model_args["vocab_size"] = len(word2index) - index2label = load_pickle(self.model_dir, "id2class.pkl") - model_args["num_classes"] = len(index2label) + word_vocab = load_pickle(self.model_dir, "word2id.pkl") + model_args["vocab_size"] = len(word_vocab) + label_vocab = load_pickle(self.model_dir, "class2id.pkl") + model_args["num_classes"] = len(label_vocab) # Construct the model model = model_class(model_args) diff --git a/reproduction/chinese_word_segment/run.py b/reproduction/chinese_word_segment/run.py index d0a22e84..0d5ae8c1 100644 --- a/reproduction/chinese_word_segment/run.py +++ b/reproduction/chinese_word_segment/run.py @@ -32,7 +32,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) @@ -105,7 +105,7 @@ def test(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # load dev data diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py index 87a9f7e8..15164130 100644 --- a/reproduction/pos_tag_model/train_pos_tag.py +++ b/reproduction/pos_tag_model/train_pos_tag.py @@ -33,7 +33,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model @@ -105,7 +105,7 @@ def test(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # load dev data diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py index c7ad65d7..411f636e 100644 --- a/test/core/test_predictor.py +++ b/test/core/test_predictor.py @@ -4,6 +4,7 @@ import unittest from fastNLP.core.predictor import Predictor from fastNLP.core.preprocess import save_pickle from fastNLP.models.sequence_modeling import SeqLabeling +from fastNLP.core.vocabulary import Vocabulary class TestPredictor(unittest.TestCase): @@ -23,10 +24,14 @@ class TestPredictor(unittest.TestCase): ['a', 'b', 'c', 'd', '$'], ['!', 'b', 'c', 'd', 'e'] ] - vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} + + vocab = Vocabulary() + vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} + class_vocab = Vocabulary() + class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4} os.system("mkdir save") - save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl") + save_pickle(class_vocab, "./save/", "class2id.pkl") save_pickle(vocab, "./save/", "word2id.pkl") model = SeqLabeling(model_args) diff --git a/test/model/seq_labeling.py b/test/model/seq_labeling.py index d7750b17..cd011c0d 100644 --- a/test/model/seq_labeling.py +++ b/test/model/seq_labeling.py @@ -38,7 +38,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model diff --git a/test/model/test_cws.py b/test/model/test_cws.py index 802d97ba..70716c3a 100644 --- a/test/model/test_cws.py +++ b/test/model/test_cws.py @@ -27,7 +27,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model From e8cc702737f93afc285a90051f29788b40847523 Mon Sep 17 00:00:00 2001 From: yunfan Date: Thu, 20 Sep 2018 15:11:01 +0800 Subject: [PATCH 5/5] add default switch --- fastNLP/core/vocabulary.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 79b70939..ad618ff9 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -28,15 +28,25 @@ class Vocabulary(object): vocab["word"] vocab.to_word(5) """ - def __init__(self): - self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) - self.padding_label = DEFAULT_PADDING_LABEL - self.unknown_label = DEFAULT_UNKNOWN_LABEL + def __init__(self, need_default=True): + """ + :param bool need_default: set if the Vocabulary has default labels reserved. + """ + if need_default: + self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) + self.padding_label = DEFAULT_PADDING_LABEL + self.unknown_label = DEFAULT_UNKNOWN_LABEL + else: + self.word2idx = {} + self.padding_label = None + self.unknown_label = None + + self.has_default = need_default self.idx2word = None def __len__(self): return len(self.word2idx) - + def update(self, word): """add word or list of words into Vocabulary @@ -73,9 +83,13 @@ class Vocabulary(object): return self[w] def unknown_idx(self): + if self.unknown_label is None: + return None return self.word2idx[self.unknown_label] def padding_idx(self): + if self.padding_label is None: + return None return self.word2idx[self.padding_label] def build_reverse_vocab(self):