Browse Source

Prepare for CWS service:

- specify the name of the config file and the name of the corresponding section where the model init params are stored.
- fastnlp.py needs load_pickle to get the dictionary size and the number of labels
- other minor adjustments
tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
9d6b0daa99
4 changed files with 31 additions and 16 deletions
  1. +2
    -2
      fastNLP/core/preprocess.py
  2. +19
    -8
      fastNLP/fastnlp.py
  3. +3
    -3
      reproduction/chinese_word_segment/run.py
  4. +7
    -3
      test/test_fastNLP.py

+ 2
- 2
fastNLP/core/preprocess.py View File

@@ -19,13 +19,13 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
def save_pickle(obj, pickle_path, file_name): def save_pickle(obj, pickle_path, file_name):
with open(os.path.join(pickle_path, file_name), "wb") as f: with open(os.path.join(pickle_path, file_name), "wb") as f:
_pickle.dump(obj, f) _pickle.dump(obj, f)
print("{} saved. ".format(file_name))
print("{} saved in {}.".format(file_name, pickle_path))




def load_pickle(pickle_path, file_name): def load_pickle(pickle_path, file_name):
with open(os.path.join(pickle_path, file_name), "rb") as f: with open(os.path.join(pickle_path, file_name), "rb") as f:
obj = _pickle.load(f) obj = _pickle.load(f)
print("{} loaded. ".format(file_name))
print("{} loaded from {}.".format(file_name, pickle_path))
return obj return obj






+ 19
- 8
fastNLP/fastnlp.py View File

@@ -1,4 +1,5 @@
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
# from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader from fastNLP.loader.model_loader import ModelLoader


@@ -11,7 +12,9 @@ Example:
"url": "www.fudan.edu.cn", "url": "www.fudan.edu.cn",
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/ "class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/
"pickle": "seq_label_model.pkl", "pickle": "seq_label_model.pkl",
"type": "seq_label"
"type": "seq_label",
"config_file_name": "config", # the name of the config file which stores model initialization parameters
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params
}, },
"text_class_model": { "text_class_model": {
"url": "www.fudan.edu.cn", "url": "www.fudan.edu.cn",
@@ -25,13 +28,12 @@ FastNLP_MODEL_COLLECTION = {
"url": "", "url": "",
"class": "sequence_modeling.AdvSeqLabel", "class": "sequence_modeling.AdvSeqLabel",
"pickle": "cws_basic_model_v_0.pkl", "pickle": "cws_basic_model_v_0.pkl",
"type": "seq_label"
"type": "seq_label",
"config_file_name": "config",
"config_section_name": "text_class_model"
} }
} }


CONFIG_FILE_NAME = "config"
SECTION_NAME = "text_class_model"



class FastNLP(object): class FastNLP(object):
""" """
@@ -56,10 +58,13 @@ class FastNLP(object):
self.model = None self.model = None
self.infer_type = None # "seq_label"/"text_class" self.infer_type = None # "seq_label"/"text_class"


def load(self, model_name):
def load(self, model_name, config_file="config", section_name="model"):
""" """
Load a pre-trained FastNLP model together with additional data. Load a pre-trained FastNLP model together with additional data.
:param model_name: str, the name of a FastNLP model. :param model_name: str, the name of a FastNLP model.
:param config_file: str, the name of the config file which stores the initialization information of the model.
(default: "config")
:param section_name: str, the name of the corresponding section in the config file. (default: model)
""" """
assert type(model_name) is str assert type(model_name) is str
if model_name not in FastNLP_MODEL_COLLECTION: if model_name not in FastNLP_MODEL_COLLECTION:
@@ -71,7 +76,13 @@ class FastNLP(object):
model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"]) model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])


model_args = ConfigSection() model_args = ConfigSection()
ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})
ConfigLoader.load_config(self.model_dir + config_file, {section_name: model_args})

# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(self.model_dir, "word2id.pkl")
model_args["vocab_size"] = len(word2index)
index2label = load_pickle(self.model_dir, "id2class.pkl")
model_args["num_classes"] = len(index2label)


# Construct the model # Construct the model
model = model_class(model_args) model = model_class(model_args)


+ 3
- 3
reproduction/chinese_word_segment/run.py View File

@@ -1,4 +1,5 @@
import sys, os
import os
import sys


sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))


@@ -11,7 +12,6 @@ from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import AdvSeqLabel from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.inference import SeqLabelInfer from fastNLP.core.inference import SeqLabelInfer
from fastNLP.core.optimizer import SGD


# not in the file's dir # not in the file's dir
if len(os.path.dirname(__file__)) != 0: if len(os.path.dirname(__file__)) != 0:
@@ -75,7 +75,7 @@ def train():
train_args["num_classes"] = p.num_classes train_args["num_classes"] = p.num_classes


# Trainer # Trainer
trainer = SeqLabelTrainer(train_args)
trainer = SeqLabelTrainer(**train_args.data)


# Model # Model
model = AdvSeqLabel(train_args) model = AdvSeqLabel(train_args)


+ 7
- 3
test/test_fastNLP.py View File

@@ -1,9 +1,13 @@
import sys

sys.path.append("..")
from fastNLP.fastnlp import FastNLP from fastNLP.fastnlp import FastNLP


PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/data/save/"


def word_seg(): def word_seg():
nlp = FastNLP("./data_for_tests/")
nlp.load("seq_label_model")
nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES)
nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test")
text = "这是最好的基于深度学习的中文分词系统。" text = "这是最好的基于深度学习的中文分词系统。"
result = nlp.run(text) result = nlp.run(text)
print(result) print(result)
@@ -20,4 +24,4 @@ def text_class():




if __name__ == "__main__": if __name__ == "__main__":
text_class()
word_seg()

Loading…
Cancel
Save