- see test_fastNLP.py - update interpret_word_seg_results in fastnlp.py - delete useless data to increase git clone speedtags/v0.1.0
@@ -216,7 +216,7 @@ def make_seq_label_output(result, infer_input): | |||||
""" | """ | ||||
ret = [] | ret = [] | ||||
for example_x, example_y in zip(infer_input, result): | for example_x, example_y in zip(infer_input, result): | ||||
ret.append([tuple([x, y]) for x, y in zip(example_x, example_y)]) | |||||
ret.append([(x, y) for x, y in zip(example_x, example_y)]) | |||||
return ret | return ret | ||||
def make_class_output(result, infer_input): | def make_class_output(result, infer_input): | ||||
@@ -229,35 +229,33 @@ def make_class_output(result, infer_input): | |||||
return result | return result | ||||
def interpret_word_seg_results(infer_input, results): | |||||
def interpret_word_seg_results(char_seq, label_seq): | |||||
"""Transform model output into user-friendly contents. | """Transform model output into user-friendly contents. | ||||
Example: In CWS, convert <BMES> labeling into segmented text. | Example: In CWS, convert <BMES> labeling into segmented text. | ||||
:param results: list of strings. (model output) | |||||
:param infer_input: 2-D list of string (model input) | |||||
:return output: list of strings | |||||
:param char_seq: list of string, | |||||
:param label_seq: list of string, the same length as char_seq | |||||
Each entry is one of ('B', 'M', 'E', 'S'). | |||||
:return output: list of words | |||||
""" | """ | ||||
outputs = [] | |||||
for sent_char, sent_label in zip(infer_input, results): | |||||
words = [] | |||||
word = "" | |||||
for char, label in zip(sent_char, sent_label): | |||||
if label[0] == "B": | |||||
if word != "": | |||||
words.append(word) | |||||
word = char | |||||
elif label[0] == "M": | |||||
word += char | |||||
elif label[0] == "E": | |||||
word += char | |||||
words = [] | |||||
word = "" | |||||
for char, label in zip(char_seq, label_seq): | |||||
if label[0] == "B": | |||||
if word != "": | |||||
words.append(word) | words.append(word) | ||||
word = "" | |||||
elif label[0] == "S": | |||||
if word != "": | |||||
words.append(word) | |||||
word = "" | |||||
words.append(char) | |||||
else: | |||||
raise ValueError("invalid label") | |||||
outputs.append(" ".join(words)) | |||||
return outputs | |||||
word = char | |||||
elif label[0] == "M": | |||||
word += char | |||||
elif label[0] == "E": | |||||
word += char | |||||
words.append(word) | |||||
word = "" | |||||
elif label[0] == "S": | |||||
if word != "": | |||||
words.append(word) | |||||
word = "" | |||||
words.append(char) | |||||
else: | |||||
raise ValueError("invalid label {}".format(label[0])) | |||||
return words |
@@ -1,17 +1,24 @@ | |||||
import sys | import sys | ||||
sys.path.append("..") | sys.path.append("..") | ||||
from fastNLP.fastnlp import FastNLP | from fastNLP.fastnlp import FastNLP | ||||
from fastNLP.fastnlp import interpret_word_seg_results | |||||
PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/" | PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/" | ||||
def word_seg(): | def word_seg(): | ||||
nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES) | nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES) | ||||
nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test") | nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test") | ||||
text = "这是最好的基于深度学习的中文分词系统。" | |||||
result = nlp.run(text) | |||||
print(result) | |||||
print("FastNLP finished!") | |||||
text = ["这是最好的基于深度学习的中文分词系统。", | |||||
"大王叫我来巡山。", | |||||
"我党多年来致力于改善人民生活水平。"] | |||||
results = nlp.run(text) | |||||
print(results) | |||||
for example in results: | |||||
words, labels = [], [] | |||||
for res in example: | |||||
words.append(res[0]) | |||||
labels.append(res[1]) | |||||
print(interpret_word_seg_results(words, labels)) | |||||
def text_class(): | def text_class(): | ||||
@@ -23,5 +30,14 @@ def text_class(): | |||||
print("FastNLP finished!") | print("FastNLP finished!") | ||||
def test_word_seg_interpret(): | |||||
foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'), | |||||
('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'), | |||||
('。', 'S')]] | |||||
chars = [x[0] for x in foo[0]] | |||||
labels = [x[1] for x in foo[0]] | |||||
print(interpret_word_seg_results(chars, labels)) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
word_seg() | word_seg() |