Browse Source

Optimize the CWS (Chinese word segmentation) example

- see test_fastNLP.py for the updated usage
- update interpret_word_seg_results in fastnlp.py
- delete unneeded data files to speed up git clone
tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
501ffb26c5
7 changed files with 47 additions and 10695 deletions
  1. +26
    -28
      fastNLP/fastnlp.py
  2. +0
    -5331
      reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg
  3. +0
    -5331
      reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos
  4. BIN
      reproduction/HAN-document_classification/data/test_samples.pkl
  5. BIN
      reproduction/HAN-document_classification/data/train_samples.pkl
  6. BIN
      reproduction/HAN-document_classification/data/yelp.word2vec
  7. +21
    -5
      test/test_fastNLP.py

+ 26
- 28
fastNLP/fastnlp.py View File

@@ -216,7 +216,7 @@ def make_seq_label_output(result, infer_input):
"""
ret = []
for example_x, example_y in zip(infer_input, result):
ret.append([tuple([x, y]) for x, y in zip(example_x, example_y)])
ret.append([(x, y) for x, y in zip(example_x, example_y)])
return ret

def make_class_output(result, infer_input):
@@ -229,35 +229,33 @@ def make_class_output(result, infer_input):
return result


def interpret_word_seg_results(char_seq, label_seq):
    """Convert one BMES-tagged character sequence into a list of words.

    Example: in CWS (Chinese word segmentation), the characters
    ['基', '于', '的'] tagged ['B', 'E', 'S'] become ['基于', '的'].

    Only the first character of each label is inspected. If the two
    sequences differ in length, the surplus entries of the longer one are
    ignored (``zip`` semantics).

    :param char_seq: list of string, the characters of one sentence.
    :param label_seq: list of string, the same length as char_seq.
            Each entry starts with one of ('B', 'M', 'E', 'S').
    :return words: list of string, the segmented words in input order.
    :raises ValueError: if a label does not start with B, M, E or S.
    """
    words = []
    word = ""
    for char, label in zip(char_seq, label_seq):
        tag = label[0]
        if tag == "B":
            # A new word begins; keep a previous word that never got its "E".
            if word:
                words.append(word)
            word = char
        elif tag == "M":
            word += char
        elif tag == "E":
            word += char
            words.append(word)
            word = ""
        elif tag == "S":
            # A single-character word also terminates any unfinished word.
            if word:
                words.append(word)
                word = ""
            words.append(char)
        else:
            raise ValueError("invalid label {}".format(tag))
    # The sequence may end on "B"/"M" without a closing "E"; emit those
    # characters instead of silently dropping them.
    if word:
        words.append(word)
    return words

+ 0
- 5331
reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg
File diff suppressed because it is too large
View File


+ 0
- 5331
reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos
File diff suppressed because it is too large
View File


BIN
reproduction/HAN-document_classification/data/test_samples.pkl View File


BIN
reproduction/HAN-document_classification/data/train_samples.pkl View File


BIN
reproduction/HAN-document_classification/data/yelp.word2vec View File


+ 21
- 5
test/test_fastNLP.py View File

@@ -1,17 +1,24 @@
import sys

sys.path.append("..")
from fastNLP.fastnlp import FastNLP
from fastNLP.fastnlp import interpret_word_seg_results

PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/"

def word_seg():
    """Run the saved CWS model on sample sentences and print the results.

    Loads "cws_basic_model" from PATH_TO_CWS_PICKLE_FILES, runs it on a
    small batch of sentences, prints the raw output, then prints the word
    list produced by interpret_word_seg_results for each sentence.
    """
    nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES)
    nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test")
    text = ["这是最好的基于深度学习的中文分词系统。",
            "大王叫我来巡山。",
            "我党多年来致力于改善人民生活水平。"]
    results = nlp.run(text)
    print(results)
    for example in results:
        # Each entry of `example` is indexed as (character, label).
        # NOTE(review): assumes `example` is a re-iterable sequence — confirm
        # against what nlp.run() returns.
        words = [res[0] for res in example]
        labels = [res[1] for res in example]
        print(interpret_word_seg_results(words, labels))


def text_class():
@@ -23,5 +30,14 @@ def text_class():
print("FastNLP finished!")


def test_word_seg_interpret():
    """Check that a BMES-tagged sentence is merged into the expected words."""
    foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'),
            ('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'),
            ('。', 'S')]]
    chars = [x[0] for x in foo[0]]
    labels = [x[1] for x in foo[0]]
    words = interpret_word_seg_results(chars, labels)
    print(words)
    # Pin the expected segmentation so this test can actually fail.
    assert words == ['这', '是', '最', '好', '的', '基于', '深度', '学习', '的', '中文', '分词', '系统', '。']


if __name__ == "__main__":
word_seg()

Loading…
Cancel
Save