From 9d6b0daa9914051b90266ce58812bd9a5d2b3495 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Thu, 30 Aug 2018 11:45:47 +0800 Subject: [PATCH] Prepare for CWS service: - specify the name of the config file and the name of corresponding section where model init params store. - fastnlp.py needs load_pickle to get dictionary size and the number of labels - other minor adjustments --- fastNLP/core/preprocess.py | 4 ++-- fastNLP/fastnlp.py | 27 +++++++++++++++++------- reproduction/chinese_word_segment/run.py | 6 +++--- test/test_fastNLP.py | 10 ++++++--- 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py index dfaf3e94..1805f4eb 100644 --- a/fastNLP/core/preprocess.py +++ b/fastNLP/core/preprocess.py @@ -19,13 +19,13 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, def save_pickle(obj, pickle_path, file_name): with open(os.path.join(pickle_path, file_name), "wb") as f: _pickle.dump(obj, f) - print("{} saved. ".format(file_name)) + print("{} saved in {}.".format(file_name, pickle_path)) def load_pickle(pickle_path, file_name): with open(os.path.join(pickle_path, file_name), "rb") as f: obj = _pickle.load(f) - print("{} loaded. ".format(file_name)) + print("{} loaded from {}.".format(file_name, pickle_path)) return obj diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py index 5be5cad3..2590eaf5 100644 --- a/fastNLP/fastnlp.py +++ b/fastNLP/fastnlp.py @@ -1,4 +1,5 @@ -from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer +# from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer +from fastNLP.core.preprocess import load_pickle from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.loader.model_loader import ModelLoader @@ -11,7 +12,9 @@ Example: "url": "www.fudan.edu.cn", "class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/ "pickle": "seq_label_model.pkl", - "type": "seq_label" + "type": "seq_label", + "config_file_name": "config", # the name of the config file which stores model initialization parameters + "config_section_name": "text_class_model" # the name of the section in the config file which stores model init params }, "text_class_model": { "url": "www.fudan.edu.cn", @@ -25,13 +28,12 @@ FastNLP_MODEL_COLLECTION = { "url": "", "class": "sequence_modeling.AdvSeqLabel", "pickle": "cws_basic_model_v_0.pkl", - "type": "seq_label" + "type": "seq_label", + "config_file_name": "config", + "config_section_name": "text_class_model" } } -CONFIG_FILE_NAME = "config" -SECTION_NAME = "text_class_model" - class FastNLP(object): """ @@ -56,10 +58,13 @@ class FastNLP(object): self.model = None self.infer_type = None # "seq_label"/"text_class" - def load(self, model_name): + def load(self, model_name, config_file="config", section_name="model"): """ Load a pre-trained FastNLP model together with additional data. :param model_name: str, the name of a FastNLP model. + :param config_file: str, the name of the config file which stores the initialization information of the model. + (default: "config") + :param section_name: str, the name of the corresponding section in the config file. (default: model) """ assert type(model_name) is str if model_name not in FastNLP_MODEL_COLLECTION: @@ -71,7 +76,13 @@ class FastNLP(object): model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"]) model_args = ConfigSection() - ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args}) + ConfigLoader.load_config(self.model_dir + config_file, {section_name: model_args}) + + # fetch dictionary size and number of labels from pickle files + word2index = load_pickle(self.model_dir, "word2id.pkl") + model_args["vocab_size"] = len(word2index) + index2label = load_pickle(self.model_dir, "id2class.pkl") + model_args["num_classes"] = len(index2label) # Construct the model model = model_class(model_args) diff --git a/reproduction/chinese_word_segment/run.py b/reproduction/chinese_word_segment/run.py index 57601e8e..188d478d 100644 --- a/reproduction/chinese_word_segment/run.py +++ b/reproduction/chinese_word_segment/run.py @@ -1,4 +1,5 @@ -import sys, os +import os +import sys sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) @@ -11,7 +12,6 @@ from fastNLP.loader.model_loader import ModelLoader from fastNLP.core.tester import SeqLabelTester from fastNLP.models.sequence_modeling import AdvSeqLabel from fastNLP.core.inference import SeqLabelInfer -from fastNLP.core.optimizer import SGD # not in the file's dir if len(os.path.dirname(__file__)) != 0: @@ -75,7 +75,7 @@ def train(): train_args["num_classes"] = p.num_classes # Trainer - trainer = SeqLabelTrainer(train_args) + trainer = SeqLabelTrainer(**train_args.data) # Model model = AdvSeqLabel(train_args) diff --git a/test/test_fastNLP.py b/test/test_fastNLP.py index 0b64b261..9d0f2dee 100644 --- a/test/test_fastNLP.py +++ b/test/test_fastNLP.py @@ -1,9 +1,13 @@ +import sys + +sys.path.append("..") from fastNLP.fastnlp import FastNLP +PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/data/save/" def word_seg(): - nlp = FastNLP("./data_for_tests/") - nlp.load("seq_label_model") + nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES) + nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test") text = "这是最好的基于深度学习的中文分词系统。" result = nlp.run(text) print(result) @@ -20,4 +24,4 @@ def text_class(): if __name__ == "__main__": - text_class() + word_seg()