From de3feeaf5aca2529585b7572cd1d16d4dfcf4865 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sat, 10 Nov 2018 20:10:13 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4CWS=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E7=9A=84=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/api/cws.py | 1 + .../chinese_word_segment/train_context.py | 74 ++--------------- reproduction/chinese_word_segment/utils.py | 83 ++++++++++++++----- 3 files changed, 72 insertions(+), 86 deletions(-) diff --git a/fastNLP/api/cws.py b/fastNLP/api/cws.py index ea6f96e6..1f3c08d2 100644 --- a/fastNLP/api/cws.py +++ b/fastNLP/api/cws.py @@ -30,3 +30,4 @@ class CWS(API): # 4. TODO 这里应该要交给一个iterator一样的东西预测这个结果 # 5. TODO 得到结果,需要考虑是否需要反转回去, 及post_process的操作 + \ No newline at end of file diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py index 21b7ab89..f0b2e3f1 100644 --- a/reproduction/chinese_word_segment/train_context.py +++ b/reproduction/chinese_word_segment/train_context.py @@ -17,6 +17,8 @@ from reproduction.chinese_word_segment.process.span_converter import EmailConver from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp +from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 + ds_name = 'pku' tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name) dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name) @@ -31,11 +33,11 @@ dev_dataset = reader.load(dev_filename) fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence') sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence') -sp_proc.add_span_converter(EmailConverter()) -sp_proc.add_span_converter(MixNumAlphaConverter()) +# sp_proc.add_span_converter(EmailConverter()) +# sp_proc.add_span_converter(MixNumAlphaConverter()) sp_proc.add_span_converter(AlphaSpanConverter()) sp_proc.add_span_converter(DigitSpanConverter()) -sp_proc.add_span_converter(TimeConverter()) +# sp_proc.add_span_converter(TimeConverter()) char_proc = CWSCharSegProcessor('sentence', 'chars_list') @@ -86,68 +88,6 @@ print("Vocab size:{}, bigram size:{}.".format(char_vocab_proc.get_vocab_size(), # 3. 得到数据集可以用于训练了 -from itertools import chain - -def refine_ys_on_seq_len(ys, seq_lens): - refined_ys = [] - for b_idx, length in enumerate(seq_lens): - refined_ys.append(list(ys[b_idx][:length])) - - return refined_ys - -def flat_nested_list(nested_list): - return list(chain(*nested_list)) - -def calculate_pre_rec_f1(model, batcher): - true_ys, pred_ys = decode_iterator(model, batcher) - - true_ys = flat_nested_list(true_ys) - pred_ys = flat_nested_list(pred_ys) - - cor_num = 0 - yp_wordnum = pred_ys.count(1) - yt_wordnum = true_ys.count(1) - start = 0 - for i in range(len(true_ys)): - if true_ys[i] == 1: - flag = True - for j in range(start, i + 1): - if true_ys[j] != pred_ys[j]: - flag = False - break - if flag: - cor_num += 1 - start = i + 1 - P = cor_num / (float(yp_wordnum) + 1e-6) - R = cor_num / (float(yt_wordnum) + 1e-6) - F = 2 * P * R / (P + R + 1e-6) - return P, R, F - -def decode_iterator(model, batcher): - true_ys = [] - pred_ys = [] - seq_lens = [] - with torch.no_grad(): - model.eval() - for batch_x, batch_y in batcher: - pred_dict = model(batch_x) - seq_len = pred_dict['seq_lens'].cpu().numpy() - probs = pred_dict['pred_probs'] - _, pred_y = probs.max(dim=-1) - true_y = batch_y['tags'] - pred_y = pred_y.cpu().numpy() - true_y = true_y.cpu().numpy() - - true_ys.extend(list(true_y)) - pred_ys.extend(list(pred_y)) - seq_lens.extend(list(seq_len)) - model.train() - - true_ys = refine_ys_on_seq_len(true_ys, seq_lens) - pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens) - - return true_ys, pred_ys - # TODO pretrain的embedding是怎么解决的? from reproduction.chinese_word_segment.utils import FocalLoss @@ -255,4 +195,8 @@ print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100, pre * 100, rec * 100)) +# TODO 这里貌似需要区分test pipeline与dev pipeline +# TODO 还需要考虑如何替换回原文的问题? +# 1. 不需要将特殊tag替换 +# 2. 需要将特殊tag替换回去 \ No newline at end of file diff --git a/reproduction/chinese_word_segment/utils.py b/reproduction/chinese_word_segment/utils.py index 92cd19d1..9411c9f2 100644 --- a/reproduction/chinese_word_segment/utils.py +++ b/reproduction/chinese_word_segment/utils.py @@ -12,27 +12,68 @@ def seq_lens_to_mask(seq_lens): return masks -def cut_long_training_sentences(sentences, max_sample_length=200): - cutted_sentence = [] - for sent in sentences: - sent_no_space = sent.replace(' ', '') - if len(sent_no_space) > max_sample_length: - parts = sent.strip().split() - new_line = '' - length = 0 - for part in parts: - length += len(part) - new_line += part + ' ' - if length > max_sample_length: - new_line = new_line[:-1] - cutted_sentence.append(new_line) - length = 0 - new_line = '' - if new_line != '': - cutted_sentence.append(new_line[:-1]) - else: - cutted_sentence.append(sent) - return cutted_sentence +from itertools import chain + +def refine_ys_on_seq_len(ys, seq_lens): + refined_ys = [] + for b_idx, length in enumerate(seq_lens): + refined_ys.append(list(ys[b_idx][:length])) + + return refined_ys + +def flat_nested_list(nested_list): + return list(chain(*nested_list)) + +def calculate_pre_rec_f1(model, batcher): + true_ys, pred_ys = decode_iterator(model, batcher) + + true_ys = flat_nested_list(true_ys) + pred_ys = flat_nested_list(pred_ys) + + cor_num = 0 + yp_wordnum = pred_ys.count(1) + yt_wordnum = true_ys.count(1) + start = 0 + for i in range(len(true_ys)): + if true_ys[i] == 1: + flag = True + for j in range(start, i + 1): + if true_ys[j] != pred_ys[j]: + flag = False + break + if flag: + cor_num += 1 + start = i + 1 + P = cor_num / (float(yp_wordnum) + 1e-6) + R = cor_num / (float(yt_wordnum) + 1e-6) + F = 2 * P * R / (P + R + 1e-6) + return P, R, F + + +def decode_iterator(model, batcher): + true_ys = [] + pred_ys = [] + seq_lens = [] + with torch.no_grad(): + model.eval() + for batch_x, batch_y in batcher: + pred_dict = model(batch_x) + seq_len = pred_dict['seq_lens'].cpu().numpy() + probs = pred_dict['pred_probs'] + _, pred_y = probs.max(dim=-1) + true_y = batch_y['tags'] + pred_y = pred_y.cpu().numpy() + true_y = true_y.cpu().numpy() + + true_ys.extend(list(true_y)) + pred_ys.extend(list(pred_y)) + seq_lens.extend(list(seq_len)) + model.train() + + true_ys = refine_ys_on_seq_len(true_ys, seq_lens) + pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens) + + return true_ys, pred_ys from torch import nn