@@ -1,4 +1,5 @@
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
# from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
@@ -11,7 +12,9 @@ Example:
"url": "www.fudan.edu.cn",
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/
"pickle": "seq_label_model.pkl",
"type": "seq_label"
"type": "seq_label",
"config_file_name": "config", # the name of the config file which stores model initialization parameters
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params
},
"text_class_model": {
"url": "www.fudan.edu.cn",
@@ -25,13 +28,12 @@ FastNLP_MODEL_COLLECTION = {
"url": "",
"class": "sequence_modeling.AdvSeqLabel",
"pickle": "cws_basic_model_v_0.pkl",
"type": "seq_label"
"type": "seq_label",
"config_file_name": "config",
"config_section_name": "text_class_model"
}
}
CONFIG_FILE_NAME = "config"
SECTION_NAME = "text_class_model"
class FastNLP(object):
"""
@@ -56,10 +58,13 @@ class FastNLP(object):
self.model = None
self.infer_type = None # "seq_label"/"text_class"
def load(self, model_name):
def load(self, model_name, config_file="config", section_name="model" ):
"""
Load a pre-trained FastNLP model together with additional data.
:param model_name: str, the name of a FastNLP model.
:param config_file: str, the name of the config file which stores the initialization information of the model.
(default: "config")
:param section_name: str, the name of the corresponding section in the config file. (default: model)
"""
assert type(model_name) is str
if model_name not in FastNLP_MODEL_COLLECTION:
@@ -71,7 +76,13 @@ class FastNLP(object):
model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
model_args = ConfigSection()
ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})
ConfigLoader.load_config(self.model_dir + config_file, {section_name: model_args})
# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(self.model_dir, "word2id.pkl")
model_args["vocab_size"] = len(word2index)
index2label = load_pickle(self.model_dir, "id2class.pkl")
model_args["num_classes"] = len(index2label)
# Construct the model
model = model_class(model_args)