From 625b72691b755ffaaa59a264649020c85d92372d Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Wed, 29 Aug 2018 23:26:32 +0800 Subject: [PATCH] edit fastnlp run method, get ready for CWS - change inputs of fastnlp from string to list of strings - adopt flexible outputs according to diff tasks --- fastNLP/fastnlp.py | 93 ++++++++++++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 36 deletions(-) diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py index 6339c11a..5be5cad3 100644 --- a/fastNLP/fastnlp.py +++ b/fastNLP/fastnlp.py @@ -7,12 +7,9 @@ mapping from model name to [URL, file_name.class_name, model_pickle_name] Notice that the class of the model should be in "models" directory. Example: - "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"] -""" -FastNLP_MODEL_COLLECTION = { "seq_label_model": { "url": "www.fudan.edu.cn", - "class": "sequence_modeling.SeqLabeling", + "class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/ "pickle": "seq_label_model.pkl", "type": "seq_label" }, @@ -22,6 +19,14 @@ FastNLP_MODEL_COLLECTION = { "pickle": "text_class_model.pkl", "type": "text_class" } +""" +FastNLP_MODEL_COLLECTION = { + "cws_basic_model": { + "url": "", + "class": "sequence_modeling.AdvSeqLabel", + "pickle": "cws_basic_model_v_0.pkl", + "type": "seq_label" + } } CONFIG_FILE_NAME = "config" @@ -82,19 +87,19 @@ class FastNLP(object): def run(self, raw_input): """ Perform inference over given input using the loaded model. - :param raw_input: str, raw text + :param raw_input: list of string. Each string is an input query. 
:return results: """ infer = self._create_inference(self.model_dir) - # string ---> 2-D list of string - infer_input = self.string_to_list(raw_input) + # tokenize: list of string ---> 2-D list of string + infer_input = self.tokenize(raw_input, language="zh") - # 2-D list of string ---> list of strings + # 2-D list of string ---> 2-D list of tags results = infer.predict(self.model, infer_input) - # list of strings ---> final answers + # 2-D list of tags ---> list of final answers outputs = self._make_output(results, infer_input) return outputs @@ -142,55 +147,71 @@ class FastNLP(object): """ return True - def string_to_list(self, text, delimiter="\n"): - """ - This function is used to transform raw input to lists, which is done by DatasetLoader in training. - Split text string into three-level lists. - [ - [word_11, word_12, ...], - [word_21, word_22, ...], - ... - ] - :param text: string - :param delimiter: str, character used to split text into sentences. - :return data: two-level lists + def tokenize(self, text, language): + """Extract tokens from strings. + For English, extract words separated by space. + For Chinese, extract characters. + TODO: more complex tokenization methods + + :param text: list of string + :param language: str, one of ('zh', 'en'), Chinese or English. + :return data: list of list of string, each string is a token. """ data = [] - sents = text.strip().split(delimiter) - for sent in sents: - characters = [] - for ch in sent: - characters.append(ch) - data.append(characters) + delimiter = " " if language is "en" else "" + for sent in text: + tokens = sent.strip().split(delimiter) + data.append(tokens) return data def _make_output(self, results, infer_input): + """Transform the infer output into user-friendly output. + + :param results: 1 or 2-D list of strings. 
+ If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length] + If self.infer_type == "text_class", it is of shape [num_examples] + :param infer_input: 2-D list of string, the input query before inference. + :return outputs: list. Each entry is a prediction. + """ if self.infer_type == "seq_label": outputs = make_seq_label_output(results, infer_input) elif self.infer_type == "text_class": outputs = make_class_output(results, infer_input) else: - raise ValueError("fail to make outputs with infer type {}".format(self.infer_type)) + raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) return outputs def make_seq_label_output(result, infer_input): - """ - Transform model output into user-friendly contents. - :param result: 1-D list of strings. (model output) + """Transform model output into user-friendly contents. + + :param result: 2-D list of strings. (model output) :param infer_input: 2-D list of string (model input) - :return outputs: + :return ret: list of list of tuples + [ + [(word_11, label_11), (word_12, label_12), ...], + [(word_21, label_21), (word_22, label_22), ...], + ... + ] """ - return result - + ret = [] + for example_x, example_y in zip(infer_input, result): + ret.append([tuple([x, y]) for x, y in zip(example_x, example_y)]) + return ret def make_class_output(result, infer_input): + """Transform model output into user-friendly contents. + + :param result: 1-D list of strings. (model output) + :param infer_input: 2-D list of string (model input) + :return ret: the same as result, [label_1, label_2, ...] + """ return result def interpret_word_seg_results(infer_input, results): - """ - Transform model output into user-friendly contents. + """Transform model output into user-friendly contents. + Example: In CWS, convert labeling into segmented text. :param results: list of strings. (model output) :param infer_input: 2-D list of string (model input)