From 91f3d97acef6ddcd334c73f5017337f849e0d204 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Sun, 30 Sep 2018 21:24:05 +0800
Subject: [PATCH] Update to new version of framework

---
 fastNLP/core/batch.py              |  2 +-
 fastNLP/core/dataset.py            |  9 ++++++
 fastNLP/core/metrics.py            | 13 ++++++---
 fastNLP/core/predictor.py          |  2 +-
 fastNLP/core/preprocess.py         |  3 ++
 fastNLP/fastnlp.py                 | 42 +++++++++++++++++++++++----
 fastNLP/saver/config_saver.py      |  2 +-
 test/core/test_batch.py            |  4 ++-
 test/core/test_metrics.py          | 29 +++++++++++++++++--
 test/core/test_predictor.py        | 40 ++++++++++++++++++++------
 test/core/test_tester.py           | 14 +++++----
 test/core/test_trainer.py          | 21 ++++++++------
 test/loader/config                 | 46 +++++++++++++++++++++++++++++-
 test/loader/test_config_loader.py  |  2 +-
 test/loader/test_dataset_loader.py | 17 +++++------
 test/model/test_cws.py             |  6 ++--
 test/saver/test_config_saver.py    | 12 ++++----
 test/test_fastNLP.py               |  6 ++--
 18 files changed, 208 insertions(+), 62 deletions(-)

diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 0f8a0615..bf837d0f 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -69,6 +69,6 @@ class Batch(object):
             else:
                 batch[name] = torch.stack(tensor_list, dim=0)
 
-        self.curidx += endidx
+        self.curidx = endidx
 
         return batch_x, batch_y
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 90f10a77..b2a4af59 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -144,6 +144,15 @@ class DataSet(list):
         else:
             self.convert(raw_data)
 
+    def load_raw(self, raw_data, vocabs):
+        """
+
+        :param raw_data:
+        :param vocabs:
+        :return:
+        """
+        self.convert_for_infer(raw_data, vocabs)
+
     def split(self, ratio, shuffle=True):
         """Train/dev splitting
 
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 8d7dafa0..75401194 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -38,14 +38,19 @@ class SeqLabelEvaluator(Evaluator):
 
     def __call__(self, predict, truth):
         """
 
-        :param predict: list of tensors, the network outputs from all batches.
+        :param predict: list of List, the network outputs from all batches.
        :param truth: list of dict, the ground truths from all batch_y.
        :return accuracy:
        """
        truth = [item["truth"] for item in truth]
-        truth = torch.cat(truth).view(-1, )
-        results = torch.Tensor(predict).view(-1, )
-        accuracy = torch.sum(results.to(truth) == truth).to(torch.float) / results.shape[0]
+        total_correct, total_count= 0., 0.
+        for x, y in zip(predict, truth):
+            mask = torch.Tensor(x).ge(1)
+            correct = torch.sum(torch.Tensor(x) * mask.float() == (y * mask.long()).float())
+            correct -= torch.sum(torch.Tensor(x).le(0))
+            total_correct += float(correct)
+            total_count += float(torch.sum(mask))
+        accuracy = total_correct / total_count
        return {"accuracy": float(accuracy)}
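The rewritten SeqLabelEvaluator above scores each predicted sequence against its truth under a mask, so positions whose predicted label id is 0 (padding, judging by the .ge(1)/.le(0) tests) are excluded from the accuracy. A rough standalone sketch of the same computation on made-up toy values, for illustration only:

import torch

def masked_accuracy(predict, truth):
    # Mirrors the evaluator above: positions predicted as 0 are treated as
    # padding and excluded from both the numerator and the denominator.
    total_correct, total_count = 0.0, 0.0
    for x, y in zip(predict, truth):
        x = torch.Tensor(x)
        mask = x.ge(1)                      # keep positions with label id >= 1
        equal = torch.sum(x * mask.float() == (y * mask.long()).float())
        equal -= torch.sum(x.le(0))         # drop the trivially "equal" padded positions
        total_correct += float(equal)
        total_count += float(torch.sum(mask))
    return total_correct / total_count

pred = [[1, 2, 3, 4, 5, 0, 0]]                    # toy prediction, 0 = padding
gold = [torch.LongTensor([1, 2, 3, 3, 5, 0, 0])]  # toy truth
print(masked_accuracy(pred, gold))                # 4 of 5 unpadded positions match -> 0.8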
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index c564bab0..14c4e8c1 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -34,7 +34,7 @@ class Predictor(object):
         """Perform inference using the trained model.
 
         :param network: a PyTorch model (cpu)
-        :param data: list of list of strings, [num_examples, seq_len]
+        :param data: a DataSet object.
         :return: list of list of strings, [num_examples, tag_seq_length]
         """
         # transform strings into DataSet object
diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py
index 913600ab..b0032a3c 100644
--- a/fastNLP/core/preprocess.py
+++ b/fastNLP/core/preprocess.py
@@ -18,6 +18,9 @@ def save_pickle(obj, pickle_path, file_name):
     :param pickle_path: str, the directory where the pickle file is to be saved
     :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
     """
+    if not os.path.exists(pickle_path):
+        os.mkdir(pickle_path)
+        print("make dir {} before saving pickle file".format(pickle_path))
     with open(os.path.join(pickle_path, file_name), "wb") as f:
         _pickle.dump(obj, f)
     print("{} saved in {}".format(file_name, pickle_path))
diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py
index 4643c247..1fab5cda 100644
--- a/fastNLP/fastnlp.py
+++ b/fastNLP/fastnlp.py
@@ -4,6 +4,8 @@ from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
 from fastNLP.core.preprocess import load_pickle
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
+from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
+
 
 """
 mapping from model name to [URL, file_name.class_name, model_pickle_name]
@@ -76,6 +78,8 @@ class FastNLP(object):
         self.model_dir = model_dir
         self.model = None
         self.infer_type = None  # "seq_label"/"text_class"
+        self.word_vocab = None
+        self.label_vocab = None
 
     def load(self, model_name, config_file="config", section_name="model"):
         """
@@ -100,10 +104,10 @@ class FastNLP(object):
         print("Restore model hyper-parameters {}".format(str(model_args.data)))
 
         # fetch dictionary size and number of labels from pickle files
-        word_vocab = load_pickle(self.model_dir, "word2id.pkl")
-        model_args["vocab_size"] = len(word_vocab)
-        label_vocab = load_pickle(self.model_dir, "class2id.pkl")
-        model_args["num_classes"] = len(label_vocab)
+        self.word_vocab = load_pickle(self.model_dir, "word2id.pkl")
+        model_args["vocab_size"] = len(self.word_vocab)
+        self.label_vocab = load_pickle(self.model_dir, "label2id.pkl")
+        model_args["num_classes"] = len(self.label_vocab)
 
         # Construct the model
         model = model_class(model_args)
@@ -130,8 +134,11 @@
         # tokenize: list of string ---> 2-D list of string
         infer_input = self.tokenize(raw_input, language="zh")
 
-        # 2-D list of string ---> 2-D list of tags
-        results = infer.predict(self.model, infer_input)
+        # create DataSet: 2-D list of strings ----> DataSet
+        infer_data = self._create_data_set(infer_input)
+
+        # DataSet ---> 2-D list of tags
+        results = infer.predict(self.model, infer_data)
 
         # 2-D list of tags ---> list of final answers
         outputs = self._make_output(results, infer_input)
@@ -154,6 +161,11 @@
         return module
 
     def _create_inference(self, model_dir):
+        """Specify which task to perform.
+
+        :param model_dir:
+        :return:
+        """
         if self.infer_type == "seq_label":
             return SeqLabelInfer(model_dir)
         elif self.infer_type == "text_class":
@@ -161,6 +173,24 @@
             return ClassificationInfer(model_dir)
         else:
             raise ValueError("fail to create inference instance")
 
+    def _create_data_set(self, infer_input):
+        """Create a DataSet object given the raw inputs.
+
+        :param infer_input: 2-D lists of strings
+        :return data_set: a DataSet object
+        """
+        if self.infer_type == "seq_label":
+            data_set = SeqLabelDataSet()
+            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
+            return data_set
+        elif self.infer_type == "text_class":
+            data_set = TextClassifyDataSet()
+            data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
+            return data_set
+        else:
+            raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
+
+
     def _load(self, model_dir, model_name):
         # To do
         return 0
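Taken together, the fastnlp.py changes make inference flow raw text -> tokens -> DataSet -> tags, with the word2id.pkl / label2id.pkl vocabularies loaded in load() reused to index the inference data in _create_data_set(). A hedged usage sketch follows; the model name, directory and sentence are placeholders, and the constructor and run() signatures are assumed rather than shown in this patch:

from fastNLP.fastnlp import FastNLP

# "cws_basic_model" and "./save/" are illustrative placeholders: the name must
# exist in the model collection and the directory must hold the pickled vocabs.
nlp = FastNLP(model_dir="./save/")
nlp.load("cws_basic_model", config_file="config", section_name="model")

# Internally: tokenize() -> _create_data_set() -> infer.predict() -> _make_output()
print(nlp.run(["这是一个测试句子。"]))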
diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py
index e05e865e..8d5f08d1 100644
--- a/fastNLP/saver/config_saver.py
+++ b/fastNLP/saver/config_saver.py
@@ -18,7 +18,7 @@ class ConfigSaver(object):
         :return: The section.
         """
         sect = ConfigSection()
-        ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect})
+        ConfigLoader().load_config(self.file_path, {sect_name: sect})
         return sect
 
     def _read_section(self):
diff --git a/test/core/test_batch.py b/test/core/test_batch.py
index 395aeb2b..5de91da8 100644
--- a/test/core/test_batch.py
+++ b/test/core/test_batch.py
@@ -43,8 +43,10 @@ class TestCase1(unittest.TestCase):
 
         # use batch to iterate dataset
         data_iterator = Batch(data, 2, SeqSampler(), False)
+        total_data = 0
         for batch_x, batch_y in data_iterator:
-            self.assertEqual(len(batch_x), 2)
+            total_data += batch_x["text"].size(0)
+            self.assertTrue(batch_x["text"].size(0) == 2 or total_data == len(raw_texts))
             self.assertTrue(isinstance(batch_x, dict))
             self.assertTrue(isinstance(batch_x["text"], torch.LongTensor))
             self.assertTrue(isinstance(batch_y, dict))
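The relaxed assertion above lets the final batch be smaller than batch_size, which only works because Batch now sets its cursor to the absolute end index (the one-line batch.py fix at the top of this patch); with the old self.curidx += endidx the cursor overshot and examples after the first batch could be skipped. A toy sketch of the cursor arithmetic, in plain Python rather than the actual Batch class:

def iter_batches(n_examples, batch_size):
    # endidx is an absolute position, so the cursor is assigned to it,
    # not incremented by it.
    curidx = 0
    while curidx < n_examples:
        endidx = min(curidx + batch_size, n_examples)
        yield list(range(curidx, endidx))
        curidx = endidx  # "curidx += endidx" would jump past remaining examples

print(list(iter_batches(5, 2)))  # [[0, 1], [2, 3], [4]] -- last batch may be smaller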
diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py
index c8d48162..806d1032 100644
--- a/test/core/test_metrics.py
+++ b/test/core/test_metrics.py
@@ -1,20 +1,42 @@
-import sys, os
+import os
+import sys
+
 sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path
 
 from fastNLP.core import metrics
 # from sklearn import metrics as skmetrics
 import unittest
-import numpy as np
 from numpy import random
+from fastNLP.core.metrics import SeqLabelEvaluator
+import torch
+
 
 def generate_fake_label(low, high, size):
     return random.randint(low, high, size), random.randint(low, high, size)
 
+
+class TestEvaluator(unittest.TestCase):
+    def test_a(self):
+        evaluator = SeqLabelEvaluator()
+        pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]
+        truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}]
+        ans = evaluator(pred, truth)
+        print(ans)
+
+    def test_b(self):
+        evaluator = SeqLabelEvaluator()
+        pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]]
+        truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}]
+        ans = evaluator(pred, truth)
+        print(ans)
+
+
 class TestMetrics(unittest.TestCase):
     delta = 1e-5
     # test for binary, multiclass, multilabel
     data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)]
     fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types]
+
     def test_accuracy_score(self):
         for y_true, y_pred in self.fake_data:
             for normalize in [True, False]:
@@ -22,7 +44,7 @@ class TestMetrics(unittest.TestCase):
                 test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
                 # ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight)
                 # self.assertAlmostEqual(test, ans, delta=self.delta)
-
+
     def test_recall_score(self):
         for y_true, y_pred in self.fake_data:
             # print(y_true.shape)
@@ -73,5 +95,6 @@ class TestMetrics(unittest.TestCase):
                 # ans = skmetrics.f1_score(y_true, y_pred)
                 # self.assertAlmostEqual(ans, test, delta=self.delta)
 
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py
index 411f636e..b4a05df0 100644
--- a/test/core/test_predictor.py
+++ b/test/core/test_predictor.py
@@ -2,9 +2,12 @@ import os
 import unittest
 
 from fastNLP.core.predictor import Predictor
+from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
 from fastNLP.core.preprocess import save_pickle
-from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.loader.base_loader import BaseLoader
+from fastNLP.models.sequence_modeling import SeqLabeling
+from fastNLP.models.cnn_text_classification import CNNText
 
 
 class TestPredictor(unittest.TestCase):
@@ -28,23 +31,44 @@
         vocab = Vocabulary()
         vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         class_vocab = Vocabulary()
-        class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4}
+        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
 
         os.system("mkdir save")
-        save_pickle(class_vocab, "./save/", "class2id.pkl")
+        save_pickle(class_vocab, "./save/", "label2id.pkl")
         save_pickle(vocab, "./save/", "word2id.pkl")
 
-        model = SeqLabeling(model_args)
-        predictor = Predictor("./save/", task="seq_label")
+        model = CNNText(model_args)
+        import fastNLP.core.predictor as pre
+        predictor = Predictor("./save/", pre.text_classify_post_processor)
 
-        results = predictor.predict(network=model, data=infer_data)
+        # Load infer data
+        infer_data_set = TextClassifyDataSet(loader=BaseLoader())
+        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
+
+        results = predictor.predict(network=model, data=infer_data_set)
 
         self.assertTrue(isinstance(results, list))
         self.assertGreater(len(results), 0)
+        self.assertEqual(len(results), len(infer_data))
         for res in results:
+            self.assertTrue(isinstance(res, str))
+            self.assertTrue(res in class_vocab.word2idx)
+
+        del model, predictor, infer_data_set
+
+        model = SeqLabeling(model_args)
+        predictor = Predictor("./save/", pre.seq_label_post_processor)
+
+        infer_data_set = SeqLabelDataSet(loader=BaseLoader())
+        infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
+
+        results = predictor.predict(network=model, data=infer_data_set)
+        self.assertTrue(isinstance(results, list))
+        self.assertEqual(len(results), len(infer_data))
+        for i in range(len(infer_data)):
+            res = results[i]
             self.assertTrue(isinstance(res, list))
-            self.assertEqual(len(res), 5)
-            self.assertTrue(isinstance(res[0], str))
+            self.assertEqual(len(res), len(infer_data[i]))
 
         os.system("rm -rf save")
         print("pickle path deleted")
diff --git a/test/core/test_tester.py b/test/core/test_tester.py
index aa277b9a..1118f284 100644
--- a/test/core/test_tester.py
+++ b/test/core/test_tester.py
@@ -1,8 +1,9 @@
 import os
 import unittest
 
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.field import TextField
+from fastNLP.core.dataset import SeqLabelDataSet
+from fastNLP.core.metrics import SeqLabelEvaluator
+from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import SeqLabeling
@@ -21,7 +22,7 @@ class TestTester(unittest.TestCase):
         }
         valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                       "save_loss": True, "batch_size": 2, "pickle_path": "./save/",
-                      "use_cuda": False, "print_every_step": 1}
+                      "use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()}
 
         train_data = [
             [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
@@ -34,16 +35,17 @@ class TestTester(unittest.TestCase):
         vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}
 
-        data_set = DataSet()
+        data_set = SeqLabelDataSet()
         for example in train_data:
             text, label = example[0], example[1]
             x = TextField(text, False)
+            x_len = LabelField(len(text), is_target=False)
             y = TextField(label, is_target=True)
-            ins = Instance(word_seq=x, label_seq=y)
+            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
             data_set.append(ins)
         data_set.index_field("word_seq", vocab)
-        data_set.index_field("label_seq", label_vocab)
+        data_set.index_field("truth", label_vocab)
 
         model = SeqLabeling(model_args)
diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py
index c71cd695..b4a9178f 100644
--- a/test/core/test_trainer.py
+++ b/test/core/test_trainer.py
@@ -1,8 +1,9 @@
 import os
 import unittest
 
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.field import TextField
+from fastNLP.core.dataset import SeqLabelDataSet
+from fastNLP.core.metrics import SeqLabelEvaluator
+from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
 from fastNLP.core.loss import Loss
 from fastNLP.core.optimizer import Optimizer
@@ -12,14 +13,15 @@ from fastNLP.models.sequence_modeling import SeqLabeling
 
 class TestTrainer(unittest.TestCase):
     def test_case_1(self):
-        args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/",
+        args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
                 "save_best_dev": True, "model_name": "default_model_name.pkl",
-                "loss": Loss(None),
+                "loss": Loss("cross_entropy"),
                 "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
                 "vocab_size": 10,
                 "word_emb_dim": 100,
                 "rnn_hidden_units": 100,
-                "num_classes": 5
+                "num_classes": 5,
+                "evaluator": SeqLabelEvaluator()
                 }
         trainer = SeqLabelTrainer(**args)
 
@@ -34,16 +36,17 @@ class TestTrainer(unittest.TestCase):
         vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
         label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}
 
-        data_set = DataSet()
+        data_set = SeqLabelDataSet()
         for example in train_data:
             text, label = example[0], example[1]
             x = TextField(text, False)
+            x_len = LabelField(len(text), is_target=False)
-            y = TextField(label, is_target=True)
+            y = TextField(label, is_target=False)
-            ins = Instance(word_seq=x, label_seq=y)
+            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
             data_set.append(ins)
         data_set.index_field("word_seq", vocab)
-        data_set.index_field("label_seq", label_vocab)
+        data_set.index_field("truth", label_vocab)
 
         model = SeqLabeling(args)
diff --git a/test/loader/config b/test/loader/config
index b91e750d..5ff9eacf 100644
--- a/test/loader/config
+++ b/test/loader/config
@@ -9,10 +9,54 @@ input = [1,2,3]
 
 text = "this is text"
 
-doubles = 0.5
+doubles = 0.8
+
+tt = 0.5
+
+test = 105
+
+str = "this is a str"
+
+double = 0.5
+
 
 [t]
 x = "this is an test section"
+
+
 
 [test-case-2]
 double = 0.5
+
+doubles = 0.8
+
+tt = 0.5
+
+test = 105
+
+str = "this is a str"
+
+[another-test]
+doubles = 0.8
+
+tt = 0.5
+
+test = 105
+
+str = "this is a str"
+
+double = 0.5
+
+
+[one-another-test]
+doubles = 0.8
+
+tt = 0.5
+
+test = 105
+
+str = "this is a str"
+
+double = 0.5
+
+
diff --git a/test/loader/test_config_loader.py b/test/loader/test_config_loader.py
index 485eed3c..ef274b50 100644
--- a/test/loader/test_config_loader.py
+++ b/test/loader/test_config_loader.py
@@ -31,7 +31,7 @@ class TestConfigLoader(unittest.TestCase):
             return dict
 
         test_arg = ConfigSection()
-        ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
+        ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
 
         section = read_section_from_config(os.path.join("./test/loader", "config"), "test")
diff --git a/test/loader/test_dataset_loader.py b/test/loader/test_dataset_loader.py
index 1bb070e0..94a7fa71 100644
--- a/test/loader/test_dataset_loader.py
+++ b/test/loader/test_dataset_loader.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 
 from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
@@ -14,28 +15,28 @@ class TestDatasetLoader(unittest.TestCase):
 
     def test_case_TokenizeDatasetLoader(self):
         loader = TokenizeDataSetLoader()
-        data = loader.load("test/data_for_tests/", max_seq_len=32)
+        data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32)
         print("pass TokenizeDataSetLoader test!")
 
     def test_case_POSDatasetLoader(self):
         loader = POSDataSetLoader()
-        data = loader.load()
-        datas = loader.load_lines()
+        data = loader.load("./test/data_for_tests/people.txt")
+        datas = loader.load_lines("./test/data_for_tests/people.txt")
         print("pass POSDataSetLoader test!")
 
     def test_case_LMDatasetLoader(self):
         loader = LMDataSetLoader()
-        data = loader.load()
-        datas = loader.load_lines()
+        data = loader.load("./test/data_for_tests/charlm.txt")
+        datas = loader.load_lines("./test/data_for_tests/charlm.txt")
         print("pass TokenizeDataSetLoader test!")
 
     def test_PeopleDailyCorpusLoader(self):
         loader = PeopleDailyCorpusLoader()
-        _, _ = loader.load()
+        _, _ = loader.load("./test/data_for_tests/people_daily_raw.txt")
 
     def test_ConllLoader(self):
-        loader = ConllLoader("./test/data_for_tests/conll_example.txt")
-        _ = loader.load()
+        loader = ConllLoader()
+        _ = loader.load("./test/data_for_tests/conll_example.txt")
 
 
 if __name__ == '__main__':
diff --git a/test/model/test_cws.py b/test/model/test_cws.py
index ba1a9c03..9baa8820 100644
--- a/test/model/test_cws.py
+++ b/test/model/test_cws.py
@@ -13,10 +13,10 @@ from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.saver.model_saver import ModelSaver
 
 data_name = "pku_training.utf8"
-cws_data_path = "test/data_for_tests/cws_pku_utf_8"
+cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
 pickle_path = "./save/"
-data_infer_path = "test/data_for_tests/people_infer.txt"
-config_path = "test/data_for_tests/config"
+data_infer_path = "./test/data_for_tests/people_infer.txt"
+config_path = "./test/data_for_tests/config"
 
 def infer():
     # Load infer configuration, the same as test
diff --git a/test/saver/test_config_saver.py b/test/saver/test_config_saver.py
index 45daf0c6..c032f4dc 100644
--- a/test/saver/test_config_saver.py
+++ b/test/saver/test_config_saver.py
@@ -21,7 +21,7 @@ class TestConfigSaver(unittest.TestCase):
         standard_section = ConfigSection()
         t_section = ConfigSection()
-        ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section})
+        ConfigLoader().load_config(config_file_path, {"test": standard_section, "t": t_section})
 
         config_saver = ConfigSaver(config_file_path)
@@ -48,11 +48,11 @@ class TestConfigSaver(unittest.TestCase):
         one_another_test_section = ConfigSection()
         a_test_case_2_section = ConfigSection()
-        ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section,
-                                                                      "another-test": another_test_section,
-                                                                      "t": at_section,
-                                                                      "one-another-test": one_another_test_section,
-                                                                      "test-case-2": a_test_case_2_section})
+        ConfigLoader().load_config(config_file_path, {"test": test_section,
+                                                      "another-test": another_test_section,
+                                                      "t": at_section,
+                                                      "one-another-test": one_another_test_section,
+                                                      "test-case-2": a_test_case_2_section})
 
         assert test_section == standard_section
         assert at_section == t_section
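The ConfigLoader changes in config_saver.py, test_config_loader.py and test_config_saver.py above are all the same API shift: the loader is now constructed without arguments and the file path goes only to load_config(). A small hedged sketch of the new call pattern, using the file and section names from test/loader/config:

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

test_section = ConfigSection()
# New style: no path in the constructor; load_config() takes the file and a
# {section name: ConfigSection} mapping and fills the sections in place.
ConfigLoader().load_config("./test/loader/config", {"test": test_section})
print(test_section["test"])  # should print 105, per the keys added to that config above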
diff --git a/test/test_fastNLP.py b/test/test_fastNLP.py
index a40a0cf4..1180adef 100644
--- a/test/test_fastNLP.py
+++ b/test/test_fastNLP.py
@@ -54,7 +54,7 @@ def mock_cws():
     class2id = Vocabulary(need_default=False)
     label_list = ['B', 'M', 'E', 'S']
     class2id.update(label_list)
-    save_pickle(class2id, "./mock/", "class2id.pkl")
+    save_pickle(class2id, "./mock/", "label2id.pkl")
 
     model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)}
     config_file = """
@@ -115,7 +115,7 @@ def mock_pos_tag():
     idx2label = Vocabulary(need_default=False)
     label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv']
     idx2label.update(label_list)
-    save_pickle(idx2label, "./mock/", "class2id.pkl")
+    save_pickle(idx2label, "./mock/", "label2id.pkl")
 
     model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
     config_file = """
@@ -163,7 +163,7 @@ def mock_text_classify():
     idx2label = Vocabulary(need_default=False)
     label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F']
     idx2label.update(label_list)
-    save_pickle(idx2label, "./mock/", "class2id.pkl")
+    save_pickle(idx2label, "./mock/", "label2id.pkl")
 
     model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
     config_file = """