@@ -1,4 +1,4 @@
-from fastNLP.core.inference import Inference
+from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
@@ -10,14 +10,28 @@ Example:
     "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
 """
 FastNLP_MODEL_COLLECTION = {
-    "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
+    "seq_label_model": {
+        "url": "www.fudan.edu.cn",
+        "class": "sequence_modeling.SeqLabeling",
+        "pickle": "seq_label_model.pkl",
+        "type": "seq_label"
+    },
+    "text_class_model": {
+        "url": "www.fudan.edu.cn",
+        "class": "cnn_text_classification.CNNText",
+        "pickle": "text_class_model.pkl",
+        "type": "text_class"
+    }
 }
+
+CONFIG_FILE_NAME = "config"
+SECTION_NAME = "text_class_model"


 class FastNLP(object):
     """
     High-level interface for direct model inference.
-    Usage:
+    Example Usage:
         fastnlp = FastNLP()
         fastnlp.load("zh_pos_tag_model")
         text = "这是最好的基于深度学习的中文分词系统。"
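
Side note on the new registry shape: models are now looked up by named fields instead of list positions. A hypothetical extra entry (the key, pickle name, and comments below are illustrative, not part of this change; only the four field names come from the diff) would look like:

```python
# Hypothetical registry entry -- key and pickle file name are invented.
FastNLP_MODEL_COLLECTION["my_word_seg_model"] = {
    "url": "www.fudan.edu.cn",                 # downloaded by load() when the model is missing locally
    "class": "sequence_modeling.SeqLabeling",  # resolved via _get_model_class()
    "pickle": "my_word_seg_model.pkl",         # weights file passed to ModelLoader.load_pytorch()
    "type": "seq_label"                        # selects SeqLabelInfer in _create_inference()
}
```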
@@ -35,6 +49,7 @@ class FastNLP(object):
         """
         self.model_dir = model_dir
         self.model = None
+        self.infer_type = None  # "seq_label"/"text_class"

     def load(self, model_name):
         """
@@ -46,21 +61,21 @@ class FastNLP(object):
             raise ValueError("No FastNLP model named {}.".format(model_name))

         if not self.model_exist(model_dir=self.model_dir):
-            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0])
+            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])

-        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1])
+        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])

         model_args = ConfigSection()
-        # To do: customized config file for model init parameters
-        ConfigLoader.load_config(self.model_dir + "config", {"POS_infer": model_args})
+        ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})

         # Construct the model
         model = model_class(model_args)

         # To do: framework independent
-        ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name][2])
+        ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"])

         self.model = model
+        self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]

         print("Model loaded. ")
@@ -71,12 +86,16 @@ class FastNLP(object):
         :return results:
         """
-        infer = Inference(self.model_dir)
+        infer = self._create_inference(self.model_dir)

+        # string ---> 2-D list of string
         infer_input = self.string_to_list(raw_input)

+        # 2-D list of string ---> list of strings
         results = infer.predict(self.model, infer_input)

-        outputs = self.make_output(results)
+        # list of strings ---> final answers
+        outputs = self._make_output(results, infer_input)
         return outputs

     @staticmethod
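
Putting the load()/run() changes together, a minimal usage sketch looks like the following. The import path and the availability of a downloadable "seq_label_model" are assumptions; the call pattern itself mirrors the class docstring above:

```python
# Minimal sketch, assuming this file is importable as fastNLP.fastnlp and that
# the "seq_label_model" entry can actually be downloaded.
from fastNLP.fastnlp import FastNLP

nlp = FastNLP()
nlp.load("seq_label_model")   # resolves "url"/"class"/"pickle" and records "type"
print(nlp.infer_type)         # -> "seq_label", later used by _create_inference()/_make_output()
result = nlp.run("这是最好的基于深度学习的中文分词系统。")
print(result)
```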
@@ -95,6 +114,14 @@ class FastNLP(object):
             module = getattr(module, sub)
         return module

+    def _create_inference(self, model_dir):
+        if self.infer_type == "seq_label":
+            return SeqLabelInfer(model_dir)
+        elif self.infer_type == "text_class":
+            return ClassificationInfer(model_dir)
+        else:
+            raise ValueError("fail to create inference instance")
+
     def _load(self, model_dir, model_name):
         # To do
         return 0
@@ -117,7 +144,6 @@ class FastNLP(object):

     def string_to_list(self, text, delimiter="\n"):
         """
-        For word seg only, currently.
         This function is used to transform raw input to lists, which is done by DatasetLoader in training.
         Split text string into three-level lists.
         [
@@ -127,7 +153,7 @@ class FastNLP(object):
         ]
         :param text: string
         :param delimiter: str, character used to split text into sentences.
-        :return data: three-level lists
+        :return data: two-level lists
         """
         data = []
         sents = text.strip().split(delimiter)
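
To make the corrected "two-level lists" return description concrete, here is a standalone illustration (input string invented) of the structure string_to_list builds, equivalent to the loop in the method body:

```python
# Standalone illustration of string_to_list's output shape (invented input):
# one inner list of characters per delimiter-separated sentence.
text = "这是第一句。\n这是第二句。"
data = [list(sent) for sent in text.strip().split("\n")]
print(data)
# [['这', '是', '第', '一', '句', '。'], ['这', '是', '第', '二', '句', '。']]
```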
@@ -136,38 +162,61 @@ class FastNLP(object):
             for ch in sent:
                 characters.append(ch)
             data.append(characters)
-        # To refactor: this is used in make_output
-        self.data = data
         return data

-    def make_output(self, results):
-        """
-        Transform model output into user-friendly contents.
-        Example: In CWS, convert <BMES> labeling into segmented text.
-        :param results:
-        :return:
-        """
-        outputs = []
-        for sent_char, sent_label in zip(self.data, results):
-            words = []
-            word = ""
-            for char, label in zip(sent_char, sent_label):
-                if label[0] == "B":
-                    if word != "":
-                        words.append(word)
-                    word = char
-                elif label[0] == "M":
-                    word += char
-                elif label[0] == "E":
-                    word += char
-                    words.append(word)
-                    word = ""
-                elif label[0] == "S":
-                    if word != "":
-                        words.append(word)
-                    word = ""
-                    words.append(char)
-                else:
-                    raise ValueError("invalid label")
-            outputs.append(" ".join(words))
+    def _make_output(self, results, infer_input):
+        if self.infer_type == "seq_label":
+            outputs = make_seq_label_output(results, infer_input)
+        elif self.infer_type == "text_class":
+            outputs = make_class_output(results, infer_input)
+        else:
+            raise ValueError("fail to make outputs with infer type {}".format(self.infer_type))
         return outputs
+
+
+def make_seq_label_output(result, infer_input):
+    """
+    Transform model output into user-friendly contents.
+    :param result: 1-D list of strings. (model output)
+    :param infer_input: 2-D list of string (model input)
+    :return outputs:
+    """
+    return result
+
+
+def make_class_output(result, infer_input):
+    return result
+
+
+def interpret_word_seg_results(infer_input, results):
+    """
+    Transform model output into user-friendly contents.
+    Example: In CWS, convert <BMES> labeling into segmented text.
+    :param results: list of strings. (model output)
+    :param infer_input: 2-D list of string (model input)
+    :return output: list of strings
+    """
+    outputs = []
+    for sent_char, sent_label in zip(infer_input, results):
+        words = []
+        word = ""
+        for char, label in zip(sent_char, sent_label):
+            if label[0] == "B":
+                if word != "":
+                    words.append(word)
+                word = char
+            elif label[0] == "M":
+                word += char
+            elif label[0] == "E":
+                word += char
+                words.append(word)
+                word = ""
+            elif label[0] == "S":
+                if word != "":
+                    words.append(word)
+                word = ""
+                words.append(char)
+            else:
+                raise ValueError("invalid label")
+        outputs.append(" ".join(words))
+    return outputs
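
As a sanity check on the BMES decoding reused in interpret_word_seg_results, the standalone sketch below (characters and labels invented) runs one sentence through the same branch logic and prints the segmented result:

```python
# Standalone BMES decoding check (invented sentence and labels), mirroring the
# branch logic of interpret_word_seg_results for a single sentence.
chars = ["这", "是", "分", "词", "系", "统", "。"]
labels = ["S", "S", "B", "E", "B", "E", "S"]

words, word = [], ""
for char, label in zip(chars, labels):
    if label == "B":        # begin a multi-character word
        if word != "":
            words.append(word)
        word = char
    elif label == "M":      # middle of a word
        word += char
    elif label == "E":      # end of a word
        word += char
        words.append(word)
        word = ""
    elif label == "S":      # single-character word
        if word != "":
            words.append(word)
            word = ""
        words.append(char)

print(" ".join(words))      # -> 这 是 分词 系统 。
```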