|
@@ -17,6 +17,8 @@ from reproduction.chinese_word_segment.process.span_converter import EmailConver |
|
|
from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader |
|
|
from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader |
|
|
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp |
|
|
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp |
|
|
|
|
|
|
|
|
|
|
|
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 |
|
|
|
|
|
|
|
|
ds_name = 'pku' |
|
|
ds_name = 'pku' |
|
|
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name) |
|
|
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name) |
|
|
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name) |
|
|
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name) |
|
@@ -31,11 +33,11 @@ dev_dataset = reader.load(dev_filename) |
|
|
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence') |
|
|
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence') |
|
|
|
|
|
|
|
|
sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence') |
|
|
sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence') |
|
|
sp_proc.add_span_converter(EmailConverter()) |
|
|
|
|
|
sp_proc.add_span_converter(MixNumAlphaConverter()) |
|
|
|
|
|
|
|
|
# sp_proc.add_span_converter(EmailConverter()) |
|
|
|
|
|
# sp_proc.add_span_converter(MixNumAlphaConverter()) |
|
|
sp_proc.add_span_converter(AlphaSpanConverter()) |
|
|
sp_proc.add_span_converter(AlphaSpanConverter()) |
|
|
sp_proc.add_span_converter(DigitSpanConverter()) |
|
|
sp_proc.add_span_converter(DigitSpanConverter()) |
|
|
sp_proc.add_span_converter(TimeConverter()) |
|
|
|
|
|
|
|
|
# sp_proc.add_span_converter(TimeConverter()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
char_proc = CWSCharSegProcessor('sentence', 'chars_list') |
|
|
char_proc = CWSCharSegProcessor('sentence', 'chars_list') |
|
@@ -86,68 +88,6 @@ print("Vocab size:{}, bigram size:{}.".format(char_vocab_proc.get_vocab_size(), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 3. 得到数据集可以用于训练了 |
|
|
# 3. 得到数据集可以用于训练了 |
|
|
from itertools import chain |
|
|
|
|
|
|
|
|
|
|
|
def refine_ys_on_seq_len(ys, seq_lens): |
|
|
|
|
|
refined_ys = [] |
|
|
|
|
|
for b_idx, length in enumerate(seq_lens): |
|
|
|
|
|
refined_ys.append(list(ys[b_idx][:length])) |
|
|
|
|
|
|
|
|
|
|
|
return refined_ys |
|
|
|
|
|
|
|
|
|
|
|
def flat_nested_list(nested_list): |
|
|
|
|
|
return list(chain(*nested_list)) |
|
|
|
|
|
|
|
|
|
|
|
def calculate_pre_rec_f1(model, batcher): |
|
|
|
|
|
true_ys, pred_ys = decode_iterator(model, batcher) |
|
|
|
|
|
|
|
|
|
|
|
true_ys = flat_nested_list(true_ys) |
|
|
|
|
|
pred_ys = flat_nested_list(pred_ys) |
|
|
|
|
|
|
|
|
|
|
|
cor_num = 0 |
|
|
|
|
|
yp_wordnum = pred_ys.count(1) |
|
|
|
|
|
yt_wordnum = true_ys.count(1) |
|
|
|
|
|
start = 0 |
|
|
|
|
|
for i in range(len(true_ys)): |
|
|
|
|
|
if true_ys[i] == 1: |
|
|
|
|
|
flag = True |
|
|
|
|
|
for j in range(start, i + 1): |
|
|
|
|
|
if true_ys[j] != pred_ys[j]: |
|
|
|
|
|
flag = False |
|
|
|
|
|
break |
|
|
|
|
|
if flag: |
|
|
|
|
|
cor_num += 1 |
|
|
|
|
|
start = i + 1 |
|
|
|
|
|
P = cor_num / (float(yp_wordnum) + 1e-6) |
|
|
|
|
|
R = cor_num / (float(yt_wordnum) + 1e-6) |
|
|
|
|
|
F = 2 * P * R / (P + R + 1e-6) |
|
|
|
|
|
return P, R, F |
|
|
|
|
|
|
|
|
|
|
|
def decode_iterator(model, batcher): |
|
|
|
|
|
true_ys = [] |
|
|
|
|
|
pred_ys = [] |
|
|
|
|
|
seq_lens = [] |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
model.eval() |
|
|
|
|
|
for batch_x, batch_y in batcher: |
|
|
|
|
|
pred_dict = model(batch_x) |
|
|
|
|
|
seq_len = pred_dict['seq_lens'].cpu().numpy() |
|
|
|
|
|
probs = pred_dict['pred_probs'] |
|
|
|
|
|
_, pred_y = probs.max(dim=-1) |
|
|
|
|
|
true_y = batch_y['tags'] |
|
|
|
|
|
pred_y = pred_y.cpu().numpy() |
|
|
|
|
|
true_y = true_y.cpu().numpy() |
|
|
|
|
|
|
|
|
|
|
|
true_ys.extend(list(true_y)) |
|
|
|
|
|
pred_ys.extend(list(pred_y)) |
|
|
|
|
|
seq_lens.extend(list(seq_len)) |
|
|
|
|
|
model.train() |
|
|
|
|
|
|
|
|
|
|
|
true_ys = refine_ys_on_seq_len(true_ys, seq_lens) |
|
|
|
|
|
pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens) |
|
|
|
|
|
|
|
|
|
|
|
return true_ys, pred_ys |
|
|
|
|
|
|
|
|
|
|
|
# TODO pretrain的embedding是怎么解决的? |
|
|
# TODO pretrain的embedding是怎么解决的? |
|
|
|
|
|
|
|
|
from reproduction.chinese_word_segment.utils import FocalLoss |
|
|
from reproduction.chinese_word_segment.utils import FocalLoss |
|
@@ -255,4 +195,8 @@ print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100, |
|
|
pre * 100, |
|
|
pre * 100, |
|
|
rec * 100)) |
|
|
rec * 100)) |
|
|
|
|
|
|
|
|
|
|
|
# TODO 这里貌似需要区分test pipeline与dev pipeline |
|
|
|
|
|
|
|
|
|
|
|
# TODO 还需要考虑如何替换回原文的问题? |
|
|
|
|
|
# 1. 不需要将特殊tag替换 |
|
|
|
|
|
# 2. 需要将特殊tag替换回去 |