diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index 802661ef..c83b2069 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -27,8 +27,8 @@ class Predictor(object): self.batch_output = [] self.pickle_path = pickle_path self._task = task # one of ("seq_label", "text_classify") - self.index2label = load_pickle(self.pickle_path, "id2class.pkl") - self.word2index = load_pickle(self.pickle_path, "word2id.pkl") + self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl") + self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl") def predict(self, network, data): """Perform inference using the trained model. @@ -82,7 +82,7 @@ class Predictor(object): :return data_set: a DataSet instance. """ assert isinstance(data, list) - return create_dataset_from_lists(data, self.word2index, has_target=False) + return create_dataset_from_lists(data, self.word_vocab, has_target=False) def prepare_output(self, data): """Transform list of batch outputs into strings.""" @@ -97,14 +97,14 @@ class Predictor(object): results = [] for batch in batch_outputs: for example in np.array(batch): - results.append([self.index2label[int(x)] for x in example]) + results.append([self.label_vocab.to_word(int(x)) for x in example]) return results def _text_classify_prepare_output(self, batch_outputs): results = [] for batch_out in batch_outputs: idx = np.argmax(batch_out.detach().numpy(), axis=-1) - results.extend([self.index2label[i] for i in idx]) + results.extend([self.label_vocab.to_word(i) for i in idx]) return results diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py index b5d348e6..2671b4f4 100644 --- a/fastNLP/core/preprocess.py +++ b/fastNLP/core/preprocess.py @@ -6,16 +6,7 @@ import numpy as np from fastNLP.core.dataset import DataSet from fastNLP.core.field import TextField, LabelField from fastNLP.core.instance import Instance - -DEFAULT_PADDING_LABEL = '' # dict index = 0 -DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 -DEFAULT_RESERVED_LABEL = ['', - '', - ''] # dict index = 2~4 - -DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, - DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, - DEFAULT_RESERVED_LABEL[2]: 4} +from fastNLP.core.vocabulary import Vocabulary # the first vocab in dict with the index = 5 @@ -68,24 +59,22 @@ class BasePreprocess(object): - "word2id.pkl", a mapping from words(tokens) to indices - "id2word.pkl", a reversed dictionary - - "label2id.pkl", a dictionary on labels - - "id2label.pkl", a reversed dictionary on labels These four pickle files are expected to be saved in the given pickle directory once they are constructed. Preprocessors will check if those files are already in the directory and will reuse them in future calls. """ def __init__(self): - self.word2index = None - self.label2index = None + self.data_vocab = Vocabulary() + self.label_vocab = Vocabulary() @property def vocab_size(self): - return len(self.word2index) + return len(self.data_vocab) @property def num_classes(self): - return len(self.label2index) + return len(self.label_vocab) def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): """Main pre-processing pipeline. @@ -102,20 +91,14 @@ class BasePreprocess(object): """ if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): - self.word2index = load_pickle(pickle_path, "word2id.pkl") - self.label2index = load_pickle(pickle_path, "class2id.pkl") + self.data_vocab = load_pickle(pickle_path, "word2id.pkl") + self.label_vocab = load_pickle(pickle_path, "class2id.pkl") else: - self.word2index, self.label2index = self.build_dict(train_dev_data) - save_pickle(self.word2index, pickle_path, "word2id.pkl") - save_pickle(self.label2index, pickle_path, "class2id.pkl") - - if not pickle_exist(pickle_path, "id2word.pkl"): - index2word = self.build_reverse_dict(self.word2index) - save_pickle(index2word, pickle_path, "id2word.pkl") + self.data_vocab, self.label_vocab = self.build_dict(train_dev_data) + save_pickle(self.data_vocab, pickle_path, "word2id.pkl") + save_pickle(self.label_vocab, pickle_path, "class2id.pkl") - if not pickle_exist(pickle_path, "id2class.pkl"): - index2label = self.build_reverse_dict(self.label2index) - save_pickle(index2label, pickle_path, "id2class.pkl") + self.build_reverse_dict() train_set = [] dev_set = [] @@ -125,13 +108,13 @@ class BasePreprocess(object): split = int(len(train_dev_data) * train_dev_split) data_dev = train_dev_data[: split] data_train = train_dev_data[split:] - train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index) - dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index) + train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab) + dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab) save_pickle(dev_set, pickle_path, "data_dev.pkl") print("{} of the training data is split for validation. ".format(train_dev_split)) else: - train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index) + train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab) save_pickle(train_set, pickle_path, "data_train.pkl") else: train_set = load_pickle(pickle_path, "data_train.pkl") @@ -143,8 +126,8 @@ class BasePreprocess(object): # cross validation data_cv = self.cv_split(train_dev_data, n_fold) for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): - data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index) - data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index) + data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab) + data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab) save_pickle( data_train_cv, pickle_path, "data_train_{}.pkl".format(i)) @@ -165,7 +148,7 @@ class BasePreprocess(object): test_set = [] if test_data is not None: if not pickle_exist(pickle_path, "data_test.pkl"): - test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index) + test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab) save_pickle(test_set, pickle_path, "data_test.pkl") # return preprocessed results @@ -180,28 +163,15 @@ class BasePreprocess(object): return tuple(results) def build_dict(self, data): - label2index = DEFAULT_WORD_TO_INDEX.copy() - word2index = DEFAULT_WORD_TO_INDEX.copy() for example in data: - for word in example[0]: - if word not in word2index: - word2index[word] = len(word2index) - label = example[1] - if isinstance(label, str): - # label is a string - if label not in label2index: - label2index[label] = len(label2index) - elif isinstance(label, list): - # label is a list of strings - for single_label in label: - if single_label not in label2index: - label2index[single_label] = len(label2index) - return word2index, label2index - - - def build_reverse_dict(self, word_dict): - id2word = {word_dict[w]: w for w in word_dict} - return id2word + word, label = example + self.data_vocab.update(word) + self.label_vocab.update(label) + return self.data_vocab, self.label_vocab + + def build_reverse_dict(self): + self.data_vocab.build_reverse_vocab() + self.label_vocab.build_reverse_vocab() def data_split(self, data, train_dev_split): """Split data into train and dev set.""" diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py new file mode 100644 index 00000000..ad618ff9 --- /dev/null +++ b/fastNLP/core/vocabulary.py @@ -0,0 +1,124 @@ +from copy import deepcopy + +DEFAULT_PADDING_LABEL = '' # dict index = 0 +DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 +DEFAULT_RESERVED_LABEL = ['', + '', + ''] # dict index = 2~4 + +DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, + DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, + DEFAULT_RESERVED_LABEL[2]: 4} + +def isiterable(p_object): + try: + it = iter(p_object) + except TypeError: + return False + return True + +class Vocabulary(object): + """Use for word and index one to one mapping + + Example:: + + vocab = Vocabulary() + word_list = "this is a word list".split() + vocab.update(word_list) + vocab["word"] + vocab.to_word(5) + """ + def __init__(self, need_default=True): + """ + :param bool need_default: set if the Vocabulary has default labels reserved. + """ + if need_default: + self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) + self.padding_label = DEFAULT_PADDING_LABEL + self.unknown_label = DEFAULT_UNKNOWN_LABEL + else: + self.word2idx = {} + self.padding_label = None + self.unknown_label = None + + self.has_default = need_default + self.idx2word = None + + def __len__(self): + return len(self.word2idx) + + def update(self, word): + """add word or list of words into Vocabulary + + :param word: a list of str or str + """ + if not isinstance(word, str) and isiterable(word): + # it's a nested list + for w in word: + self.update(w) + else: + # it's a word to be added + if word not in self.word2idx: + self.word2idx[word] = len(self) + if self.idx2word is not None: + self.idx2word = None + + + def __getitem__(self, w): + """To support usage like:: + + vocab[w] + """ + if w in self.word2idx: + return self.word2idx[w] + else: + return self.word2idx[DEFAULT_UNKNOWN_LABEL] + + def to_index(self, w): + """ like to_index(w) function, turn a word to the index + if w is not in Vocabulary, return the unknown label + + :param str w: + """ + return self[w] + + def unknown_idx(self): + if self.unknown_label is None: + return None + return self.word2idx[self.unknown_label] + + def padding_idx(self): + if self.padding_label is None: + return None + return self.word2idx[self.padding_label] + + def build_reverse_vocab(self): + """build 'index to word' dict based on 'word to index' dict + """ + self.idx2word = {self.word2idx[w] : w for w in self.word2idx} + + def to_word(self, idx): + """given a word's index, return the word itself + + :param int idx: + """ + if self.idx2word is None: + self.build_reverse_vocab() + return self.idx2word[idx] + + def __getstate__(self): + """use to prepare data for pickle + """ + state = self.__dict__.copy() + # no need to pickle idx2word as it can be constructed from word2idx + del state['idx2word'] + return state + + def __setstate__(self, state): + """use to restore state from pickle + """ + self.__dict__.update(state) + self.idx2word = None + + + \ No newline at end of file diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py index c76e6681..e683950d 100644 --- a/fastNLP/fastnlp.py +++ b/fastNLP/fastnlp.py @@ -69,7 +69,7 @@ class FastNLP(object): :param model_dir: this directory should contain the following files: 1. a pre-trained model 2. a config file - 3. "id2class.pkl" + 3. "class2id.pkl" 4. "word2id.pkl" """ self.model_dir = model_dir @@ -99,10 +99,10 @@ class FastNLP(object): print("Restore model hyper-parameters {}".format(str(model_args.data))) # fetch dictionary size and number of labels from pickle files - word2index = load_pickle(self.model_dir, "word2id.pkl") - model_args["vocab_size"] = len(word2index) - index2label = load_pickle(self.model_dir, "id2class.pkl") - model_args["num_classes"] = len(index2label) + word_vocab = load_pickle(self.model_dir, "word2id.pkl") + model_args["vocab_size"] = len(word_vocab) + label_vocab = load_pickle(self.model_dir, "class2id.pkl") + model_args["num_classes"] = len(label_vocab) # Construct the model model = model_class(model_args) diff --git a/reproduction/chinese_word_segment/run.py b/reproduction/chinese_word_segment/run.py index d0a22e84..0d5ae8c1 100644 --- a/reproduction/chinese_word_segment/run.py +++ b/reproduction/chinese_word_segment/run.py @@ -32,7 +32,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) @@ -105,7 +105,7 @@ def test(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # load dev data diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py index 87a9f7e8..15164130 100644 --- a/reproduction/pos_tag_model/train_pos_tag.py +++ b/reproduction/pos_tag_model/train_pos_tag.py @@ -33,7 +33,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model @@ -105,7 +105,7 @@ def test(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # load dev data diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py index c7ad65d7..411f636e 100644 --- a/test/core/test_predictor.py +++ b/test/core/test_predictor.py @@ -4,6 +4,7 @@ import unittest from fastNLP.core.predictor import Predictor from fastNLP.core.preprocess import save_pickle from fastNLP.models.sequence_modeling import SeqLabeling +from fastNLP.core.vocabulary import Vocabulary class TestPredictor(unittest.TestCase): @@ -23,10 +24,14 @@ class TestPredictor(unittest.TestCase): ['a', 'b', 'c', 'd', '$'], ['!', 'b', 'c', 'd', 'e'] ] - vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} + + vocab = Vocabulary() + vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} + class_vocab = Vocabulary() + class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4} os.system("mkdir save") - save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl") + save_pickle(class_vocab, "./save/", "class2id.pkl") save_pickle(vocab, "./save/", "word2id.pkl") model = SeqLabeling(model_args) diff --git a/test/core/test_vocab.py b/test/core/test_vocab.py new file mode 100644 index 00000000..89b0691a --- /dev/null +++ b/test/core/test_vocab.py @@ -0,0 +1,31 @@ +import unittest +from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX + +class TestVocabulary(unittest.TestCase): + def test_vocab(self): + import _pickle as pickle + import os + vocab = Vocabulary() + filename = 'vocab' + vocab.update(filename) + vocab.update([filename, ['a'], [['b']], ['c']]) + idx = vocab[filename] + before_pic = (vocab.to_word(idx), vocab[filename]) + + with open(filename, 'wb') as f: + pickle.dump(vocab, f) + with open(filename, 'rb') as f: + vocab = pickle.load(f) + os.remove(filename) + + vocab.build_reverse_vocab() + after_pic = (vocab.to_word(idx), vocab[filename]) + TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} + TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) + TRUE_IDXDICT = {0: '', 1: '', 2: '', 3: '', 4: '', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} + self.assertEqual(before_pic, after_pic) + self.assertDictEqual(TRUE_DICT, vocab.word2idx) + self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/model/seq_labeling.py b/test/model/seq_labeling.py index d7750b17..cd011c0d 100644 --- a/test/model/seq_labeling.py +++ b/test/model/seq_labeling.py @@ -38,7 +38,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model diff --git a/test/model/test_cws.py b/test/model/test_cws.py index 802d97ba..70716c3a 100644 --- a/test/model/test_cws.py +++ b/test/model/test_cws.py @@ -27,7 +27,7 @@ def infer(): # fetch dictionary size and number of labels from pickle files word2index = load_pickle(pickle_path, "word2id.pkl") test_args["vocab_size"] = len(word2index) - index2label = load_pickle(pickle_path, "id2class.pkl") + index2label = load_pickle(pickle_path, "class2id.pkl") test_args["num_classes"] = len(index2label) # Define the same model