| @@ -27,8 +27,8 @@ class Predictor(object): | |||||
| self.batch_output = [] | self.batch_output = [] | ||||
| self.pickle_path = pickle_path | self.pickle_path = pickle_path | ||||
| self._task = task # one of ("seq_label", "text_classify") | self._task = task # one of ("seq_label", "text_classify") | ||||
| self.index2label = load_pickle(self.pickle_path, "id2class.pkl") | |||||
| self.word2index = load_pickle(self.pickle_path, "word2id.pkl") | |||||
| self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl") | |||||
| self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl") | |||||
| def predict(self, network, data): | def predict(self, network, data): | ||||
| """Perform inference using the trained model. | """Perform inference using the trained model. | ||||
| @@ -82,7 +82,7 @@ class Predictor(object): | |||||
| :return data_set: a DataSet instance. | :return data_set: a DataSet instance. | ||||
| """ | """ | ||||
| assert isinstance(data, list) | assert isinstance(data, list) | ||||
| return create_dataset_from_lists(data, self.word2index, has_target=False) | |||||
| return create_dataset_from_lists(data, self.word_vocab, has_target=False) | |||||
| def prepare_output(self, data): | def prepare_output(self, data): | ||||
| """Transform list of batch outputs into strings.""" | """Transform list of batch outputs into strings.""" | ||||
| @@ -97,14 +97,14 @@ class Predictor(object): | |||||
| results = [] | results = [] | ||||
| for batch in batch_outputs: | for batch in batch_outputs: | ||||
| for example in np.array(batch): | for example in np.array(batch): | ||||
| results.append([self.index2label[int(x)] for x in example]) | |||||
| results.append([self.label_vocab.to_word(int(x)) for x in example]) | |||||
| return results | return results | ||||
| def _text_classify_prepare_output(self, batch_outputs): | def _text_classify_prepare_output(self, batch_outputs): | ||||
| results = [] | results = [] | ||||
| for batch_out in batch_outputs: | for batch_out in batch_outputs: | ||||
| idx = np.argmax(batch_out.detach().numpy(), axis=-1) | idx = np.argmax(batch_out.detach().numpy(), axis=-1) | ||||
| results.extend([self.index2label[i] for i in idx]) | |||||
| results.extend([self.label_vocab.to_word(i) for i in idx]) | |||||
| return results | return results | ||||
| @@ -69,7 +69,7 @@ class FastNLP(object): | |||||
| :param model_dir: this directory should contain the following files: | :param model_dir: this directory should contain the following files: | ||||
| 1. a pre-trained model | 1. a pre-trained model | ||||
| 2. a config file | 2. a config file | ||||
| 3. "id2class.pkl" | |||||
| 3. "class2id.pkl" | |||||
| 4. "word2id.pkl" | 4. "word2id.pkl" | ||||
| """ | """ | ||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| @@ -99,10 +99,10 @@ class FastNLP(object): | |||||
| print("Restore model hyper-parameters {}".format(str(model_args.data))) | print("Restore model hyper-parameters {}".format(str(model_args.data))) | ||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(self.model_dir, "word2id.pkl") | |||||
| model_args["vocab_size"] = len(word2index) | |||||
| index2label = load_pickle(self.model_dir, "id2class.pkl") | |||||
| model_args["num_classes"] = len(index2label) | |||||
| word_vocab = load_pickle(self.model_dir, "word2id.pkl") | |||||
| model_args["vocab_size"] = len(word_vocab) | |||||
| label_vocab = load_pickle(self.model_dir, "class2id.pkl") | |||||
| model_args["num_classes"] = len(label_vocab) | |||||
| # Construct the model | # Construct the model | ||||
| model = model_class(model_args) | model = model_class(model_args) | ||||
| @@ -32,7 +32,7 @@ def infer(): | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| @@ -105,7 +105,7 @@ def test(): | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| # load dev data | # load dev data | ||||
| @@ -33,7 +33,7 @@ def infer(): | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| # Define the same model | # Define the same model | ||||
| @@ -105,7 +105,7 @@ def test(): | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| # load dev data | # load dev data | ||||
| @@ -4,6 +4,7 @@ import unittest | |||||
| from fastNLP.core.predictor import Predictor | from fastNLP.core.predictor import Predictor | ||||
| from fastNLP.core.preprocess import save_pickle | from fastNLP.core.preprocess import save_pickle | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| from fastNLP.core.vocabulary import Vocabulary | |||||
| class TestPredictor(unittest.TestCase): | class TestPredictor(unittest.TestCase): | ||||
| @@ -23,10 +24,14 @@ class TestPredictor(unittest.TestCase): | |||||
| ['a', 'b', 'c', 'd', '$'], | ['a', 'b', 'c', 'd', '$'], | ||||
| ['!', 'b', 'c', 'd', 'e'] | ['!', 'b', 'c', 'd', 'e'] | ||||
| ] | ] | ||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
| vocab = Vocabulary() | |||||
| vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
| class_vocab = Vocabulary() | |||||
| class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4} | |||||
| os.system("mkdir save") | os.system("mkdir save") | ||||
| save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl") | |||||
| save_pickle(class_vocab, "./save/", "class2id.pkl") | |||||
| save_pickle(vocab, "./save/", "word2id.pkl") | save_pickle(vocab, "./save/", "word2id.pkl") | ||||
| model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
| @@ -38,7 +38,7 @@ def infer(): | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| # Define the same model | # Define the same model | ||||
| @@ -27,7 +27,7 @@ def infer(): | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| # Define the same model | # Define the same model | ||||