Browse Source

edit fastnlp run method, get ready for CWS

- change inputs of fastnlp from string to list of strings
- adopt flexible outputs according to different tasks
tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
625b72691b
1 changed file with 57 additions and 36 deletions
  1. +57
    -36
      fastNLP/fastnlp.py

+ 57
- 36
fastNLP/fastnlp.py View File

@@ -7,12 +7,9 @@ mapping from model name to [URL, file_name.class_name, model_pickle_name]
Notice that the class of the model should be in "models" directory. Notice that the class of the model should be in "models" directory.


Example: Example:
"zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
"""
FastNLP_MODEL_COLLECTION = {
"seq_label_model": { "seq_label_model": {
"url": "www.fudan.edu.cn", "url": "www.fudan.edu.cn",
"class": "sequence_modeling.SeqLabeling",
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/
"pickle": "seq_label_model.pkl", "pickle": "seq_label_model.pkl",
"type": "seq_label" "type": "seq_label"
}, },
@@ -22,6 +19,14 @@ FastNLP_MODEL_COLLECTION = {
"pickle": "text_class_model.pkl", "pickle": "text_class_model.pkl",
"type": "text_class" "type": "text_class"
} }
"""
FastNLP_MODEL_COLLECTION = {
"cws_basic_model": {
"url": "",
"class": "sequence_modeling.AdvSeqLabel",
"pickle": "cws_basic_model_v_0.pkl",
"type": "seq_label"
}
} }


CONFIG_FILE_NAME = "config" CONFIG_FILE_NAME = "config"
@@ -82,19 +87,19 @@ class FastNLP(object):
def run(self, raw_input): def run(self, raw_input):
""" """
Perform inference over given input using the loaded model. Perform inference over given input using the loaded model.
:param raw_input: str, raw text
:param raw_input: list of strings. Each string is an input query.
:return results: :return results:
""" """


infer = self._create_inference(self.model_dir) infer = self._create_inference(self.model_dir)


# string ---> 2-D list of string
infer_input = self.string_to_list(raw_input)
# tokenize: list of string ---> 2-D list of string
infer_input = self.tokenize(raw_input, language="zh")


# 2-D list of string ---> list of strings
# 2-D list of string ---> 2-D list of tags
results = infer.predict(self.model, infer_input) results = infer.predict(self.model, infer_input)


# list of strings ---> final answers
# 2-D list of tags ---> list of final answers
outputs = self._make_output(results, infer_input) outputs = self._make_output(results, infer_input)
return outputs return outputs


@@ -142,55 +147,71 @@ class FastNLP(object):
""" """
return True return True


def string_to_list(self, text, delimiter="\n"):
"""
This function is used to transform raw input to lists, which is done by DatasetLoader in training.
Split text string into two-level lists.
[
[word_11, word_12, ...],
[word_21, word_22, ...],
...
]
:param text: string
:param delimiter: str, character used to split text into sentences.
:return data: two-level lists
def tokenize(self, text, language):
"""Extract tokens from strings.
For English, extract words separated by space.
For Chinese, extract characters.
TODO: more complex tokenization methods

:param text: list of string
:param language: str, one of ('zh', 'en'), Chinese or English.
:return data: list of list of string, each string is a token.
""" """
data = [] data = []
sents = text.strip().split(delimiter)
for sent in sents:
characters = []
for ch in sent:
characters.append(ch)
data.append(characters)
delimiter = " " if language is "en" else ""
for sent in text:
tokens = sent.strip().split(delimiter)
data.append(tokens)
return data return data


def _make_output(self, results, infer_input): def _make_output(self, results, infer_input):
"""Transform the infer output into user-friendly output.

:param results: 1-D or 2-D list of strings.
If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length]
If self.infer_type == "text_class", it is of shape [num_examples]
:param infer_input: 2-D list of string, the input query before inference.
:return outputs: list. Each entry is a prediction.
"""
if self.infer_type == "seq_label": if self.infer_type == "seq_label":
outputs = make_seq_label_output(results, infer_input) outputs = make_seq_label_output(results, infer_input)
elif self.infer_type == "text_class": elif self.infer_type == "text_class":
outputs = make_class_output(results, infer_input) outputs = make_class_output(results, infer_input)
else: else:
raise ValueError("fail to make outputs with infer type {}".format(self.infer_type))
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
return outputs return outputs




def make_seq_label_output(result, infer_input): def make_seq_label_output(result, infer_input):
"""
Transform model output into user-friendly contents.
:param result: 1-D list of strings. (model output)
"""Transform model output into user-friendly contents.
:param result: 2-D list of strings. (model output)
:param infer_input: 2-D list of string (model input) :param infer_input: 2-D list of string (model input)
:return outputs:
:return ret: list of list of tuples
[
[(word_11, label_11), (word_12, label_12), ...],
[(word_21, label_21), (word_22, label_22), ...],
...
]
""" """
return result

ret = []
for example_x, example_y in zip(infer_input, result):
ret.append([tuple([x, y]) for x, y in zip(example_x, example_y)])
return ret


def make_class_output(result, infer_input): def make_class_output(result, infer_input):
"""Transform model output into user-friendly contents.

:param result: 1-D list of strings. (model output)
:param infer_input: 2-D list of string (model input)
:return ret: the same as result, [label_1, label_2, ...]
"""
return result return result




def interpret_word_seg_results(infer_input, results): def interpret_word_seg_results(infer_input, results):
"""
Transform model output into user-friendly contents.
"""Transform model output into user-friendly contents.
Example: In CWS, convert <BMES> labeling into segmented text. Example: In CWS, convert <BMES> labeling into segmented text.
:param results: list of strings. (model output) :param results: list of strings. (model output)
:param infer_input: 2-D list of string (model input) :param infer_input: 2-D list of string (model input)


Loading…
Cancel
Save