
Move the CWS functions to a new location

tags/v0.2.0
yh_cc 6 years ago
parent commit de3feeaf5a
3 changed files with 72 additions and 86 deletions
  1. +1  -0   fastNLP/api/cws.py
  2. +9  -65  reproduction/chinese_word_segment/train_context.py
  3. +62 -21  reproduction/chinese_word_segment/utils.py

+1 -0   fastNLP/api/cws.py

@@ -30,3 +30,4 @@ class CWS(API):
         # 4. TODO: prediction here should be driven by something iterator-like
 
         # 5. TODO: once the result is obtained, consider whether it needs to be reversed back, and the post_process step
+
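The two TODOs above outline the intended predict flow: an iterator drives batched prediction, then a post-processing step restores the output. A minimal sketch of that shape, assuming the model returns a 'pred_probs' tensor as in decode_iterator below; predict_with_iterator, batch_iter and post_process are hypothetical names, not part of the current API:

import torch

def predict_with_iterator(model, batch_iter, post_process=None):
    """Hypothetical flow for TODOs 4 and 5: an iterator yields input
    batches, and an optional post_process hook restores the output."""
    model.eval()
    results = []
    with torch.no_grad():
        for batch_x in batch_iter:               # 4. iterator yields input batches
            pred_dict = model(batch_x)
            _, pred_y = pred_dict['pred_probs'].max(dim=-1)
            results.extend(pred_y.cpu().tolist())
    if post_process is not None:                 # 5. e.g. un-reverse, restore special tags
        results = post_process(results)
    return results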

+9 -65  reproduction/chinese_word_segment/train_context.py

@@ -17,6 +17,8 @@ from reproduction.chinese_word_segment.process.span_converter import EmailConver
 from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
 from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
 
+from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
+
 ds_name = 'pku'
 tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
 dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)
@@ -31,11 +33,11 @@ dev_dataset = reader.load(dev_filename)
 fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')
 
 sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
-sp_proc.add_span_converter(EmailConverter())
-sp_proc.add_span_converter(MixNumAlphaConverter())
+# sp_proc.add_span_converter(EmailConverter())
+# sp_proc.add_span_converter(MixNumAlphaConverter())
 sp_proc.add_span_converter(AlphaSpanConverter())
 sp_proc.add_span_converter(DigitSpanConverter())
-sp_proc.add_span_converter(TimeConverter())
+# sp_proc.add_span_converter(TimeConverter())
 
 
 char_proc = CWSCharSegProcessor('sentence', 'chars_list')
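For context, each span converter rewrites one class of spans (emails, mixed alphanumerics, alphabetic runs, digits, times) in the raw sentence into a reserved tag before character segmentation; the TODO at the end of this file about replacing special tags back refers to undoing this substitution. A toy sketch of the pattern, with hypothetical names and tag strings (the real converters live in process/span_converter.py and may differ):

import re

class ToySpanConverter:
    """Replace every span matched by `pattern` with a single reserved
    tag, so the segmenter sees one pseudo-token per span."""
    def __init__(self, pattern, tag):
        self.pattern = re.compile(pattern)
        self.tag = tag

    def convert(self, sentence):
        return self.pattern.sub(self.tag, sentence)

# e.g. collapse alphabetic runs, roughly what AlphaSpanConverter does
alpha = ToySpanConverter(r'[A-Za-z]+', '<ALPHA>')
print(alpha.convert('用Python分词'))  # -> 用<ALPHA>分词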
@@ -86,68 +88,6 @@ print("Vocab size:{}, bigram size:{}.".format(char_vocab_proc.get_vocab_size(),
 
 
 # 3. The dataset is now ready for training
-from itertools import chain
-
-def refine_ys_on_seq_len(ys, seq_lens):
-    refined_ys = []
-    for b_idx, length in enumerate(seq_lens):
-        refined_ys.append(list(ys[b_idx][:length]))
-
-    return refined_ys
-
-def flat_nested_list(nested_list):
-    return list(chain(*nested_list))
-
-def calculate_pre_rec_f1(model, batcher):
-    true_ys, pred_ys = decode_iterator(model, batcher)
-
-    true_ys = flat_nested_list(true_ys)
-    pred_ys = flat_nested_list(pred_ys)
-
-    cor_num = 0
-    yp_wordnum = pred_ys.count(1)
-    yt_wordnum = true_ys.count(1)
-    start = 0
-    for i in range(len(true_ys)):
-        if true_ys[i] == 1:
-            flag = True
-            for j in range(start, i + 1):
-                if true_ys[j] != pred_ys[j]:
-                    flag = False
-                    break
-            if flag:
-                cor_num += 1
-            start = i + 1
-    P = cor_num / (float(yp_wordnum) + 1e-6)
-    R = cor_num / (float(yt_wordnum) + 1e-6)
-    F = 2 * P * R / (P + R + 1e-6)
-    return P, R, F
-
-def decode_iterator(model, batcher):
-    true_ys = []
-    pred_ys = []
-    seq_lens = []
-    with torch.no_grad():
-        model.eval()
-        for batch_x, batch_y in batcher:
-            pred_dict = model(batch_x)
-            seq_len = pred_dict['seq_lens'].cpu().numpy()
-            probs = pred_dict['pred_probs']
-            _, pred_y = probs.max(dim=-1)
-            true_y = batch_y['tags']
-            pred_y = pred_y.cpu().numpy()
-            true_y = true_y.cpu().numpy()
-
-            true_ys.extend(list(true_y))
-            pred_ys.extend(list(pred_y))
-            seq_lens.extend(list(seq_len))
-        model.train()
-
-    true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
-    pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
-
-    return true_ys, pred_ys
-
 # TODO: how is the pretrained embedding handled?
 
 from reproduction.chinese_word_segment.utils import FocalLoss
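FocalLoss is imported from the same utils.py. As a reminder of what it computes, the focal loss of Lin et al. scales cross entropy by a factor that shrinks for easy examples, FL(p_t) = -(1 - p_t)^gamma * log(p_t); a generic PyTorch sketch of that formula, not necessarily the exact class in this repo:

import torch
import torch.nn.functional as F
from torch import nn

class ToyFocalLoss(nn.Module):
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # cross_entropy with reduction='none' gives -log p_t per example
        logp_t = -F.cross_entropy(logits, targets, reduction='none')
        p_t = logp_t.exp()
        # down-weight easy examples by (1 - p_t)^gamma
        loss = -((1 - p_t) ** self.gamma) * logp_t
        return loss.mean()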
@@ -255,4 +195,8 @@ print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
                                                  pre * 100,
                                                  rec * 100))
 
+# TODO: it seems the test pipeline needs to be distinguished from the dev pipeline here
+
+# TODO: also consider how to map the output back to the original text:
+#   1. cases where the special tags need not be replaced
+#   2. cases where the special tags must be replaced back
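With the helpers moved, evaluation in this script shrinks to the single import added at the top plus one call per evaluation. A minimal sketch of the usage, assuming a trained cws_model and a dev-set batcher have been built as in the surrounding script (dev_batcher is an illustrative name):

from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1

# calculate_pre_rec_f1 returns (precision, recall, f1) in that order
pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
                                                 pre * 100,
                                                 rec * 100))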

+62 -21  reproduction/chinese_word_segment/utils.py

@@ -12,27 +12,68 @@ def seq_lens_to_mask(seq_lens):
     return masks




-def cut_long_training_sentences(sentences, max_sample_length=200):
-    cutted_sentence = []
-    for sent in sentences:
-        sent_no_space = sent.replace(' ', '')
-        if len(sent_no_space) > max_sample_length:
-            parts = sent.strip().split()
-            new_line = ''
-            length = 0
-            for part in parts:
-                length += len(part)
-                new_line += part + ' '
-                if length > max_sample_length:
-                    new_line = new_line[:-1]
-                    cutted_sentence.append(new_line)
-                    length = 0
-                    new_line = ''
-            if new_line != '':
-                cutted_sentence.append(new_line[:-1])
-        else:
-            cutted_sentence.append(sent)
-    return cutted_sentence
+from itertools import chain
+
+def refine_ys_on_seq_len(ys, seq_lens):
+    refined_ys = []
+    for b_idx, length in enumerate(seq_lens):
+        refined_ys.append(list(ys[b_idx][:length]))
+
+    return refined_ys
+
+def flat_nested_list(nested_list):
+    return list(chain(*nested_list))
+
+def calculate_pre_rec_f1(model, batcher):
+    true_ys, pred_ys = decode_iterator(model, batcher)
+
+    true_ys = flat_nested_list(true_ys)
+    pred_ys = flat_nested_list(pred_ys)
+
+    cor_num = 0
+    yp_wordnum = pred_ys.count(1)
+    yt_wordnum = true_ys.count(1)
+    start = 0
+    for i in range(len(true_ys)):
+        if true_ys[i] == 1:
+            flag = True
+            for j in range(start, i + 1):
+                if true_ys[j] != pred_ys[j]:
+                    flag = False
+                    break
+            if flag:
+                cor_num += 1
+            start = i + 1
+    P = cor_num / (float(yp_wordnum) + 1e-6)
+    R = cor_num / (float(yt_wordnum) + 1e-6)
+    F = 2 * P * R / (P + R + 1e-6)
+    return P, R, F
+
+
+def decode_iterator(model, batcher):
+    true_ys = []
+    pred_ys = []
+    seq_lens = []
+    with torch.no_grad():
+        model.eval()
+        for batch_x, batch_y in batcher:
+            pred_dict = model(batch_x)
+            seq_len = pred_dict['seq_lens'].cpu().numpy()
+            probs = pred_dict['pred_probs']
+            _, pred_y = probs.max(dim=-1)
+            true_y = batch_y['tags']
+            pred_y = pred_y.cpu().numpy()
+            true_y = true_y.cpu().numpy()
+
+            true_ys.extend(list(true_y))
+            pred_ys.extend(list(pred_y))
+            seq_lens.extend(list(seq_len))
+        model.train()
+
+    true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
+    pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
+
+    return true_ys, pred_ys
 
 
 from torch import nn
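calculate_pre_rec_f1 appears to treat tag 1 as the word-final position (the two-tag scheme used with CWSBiLSTMSegApp) and counts a gold word as correct only if every tag inside its span matches. A self-contained toy check of that counting logic on already-flattened tag sequences; no model or batcher is needed, and word_prf below merely mirrors the loop above:

def word_prf(true_ys, pred_ys):
    # mirrors calculate_pre_rec_f1's inner loop on flattened tags
    cor_num, start = 0, 0
    yp_wordnum = pred_ys.count(1)   # tag 1 marks a word end
    yt_wordnum = true_ys.count(1)
    for i, t in enumerate(true_ys):
        if t == 1:                  # gold word ends at position i
            if true_ys[start:i + 1] == pred_ys[start:i + 1]:
                cor_num += 1        # predicted tags match the whole span
            start = i + 1
    p = cor_num / (yp_wordnum + 1e-6)
    r = cor_num / (yt_wordnum + 1e-6)
    return p, r, 2 * p * r / (p + r + 1e-6)

# gold: two 2-char words; prediction splits the second word in two
true_ys = [0, 1, 0, 1]
pred_ys = [0, 1, 1, 1]
print(word_prf(true_ys, pred_ys))  # -> approximately (0.333, 0.5, 0.4)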

