|
|
@@ -11,11 +11,15 @@ from reproduction.chinese_word_segment.process.cws_processor import SeqLenProces |
|
|
|
|
|
|
|
from reproduction.chinese_word_segment.process.span_converter import AlphaSpanConverter |
|
|
|
from reproduction.chinese_word_segment.process.span_converter import DigitSpanConverter |
|
|
|
from reproduction.chinese_word_segment.process.span_converter import TimeConverter |
|
|
|
from reproduction.chinese_word_segment.process.span_converter import MixNumAlphaConverter |
|
|
|
from reproduction.chinese_word_segment.process.span_converter import EmailConverter |
|
|
|
from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader |
|
|
|
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp |
|
|
|
|
|
|
|
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_train.txt' |
|
|
|
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_dev.txt' |
|
|
|
ds_name = 'pku' |
|
|
|
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name) |
|
|
|
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name) |
|
|
|
|
|
|
|
reader = NaiveCWSReader() |
|
|
|
|
|
|
@@ -27,8 +31,12 @@ dev_dataset = reader.load(dev_filename) |
|
|
|
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence') |
|
|
|
|
|
|
|
sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence') |
|
|
|
sp_proc.add_span_converter(EmailConverter()) |
|
|
|
sp_proc.add_span_converter(MixNumAlphaConverter()) |
|
|
|
sp_proc.add_span_converter(AlphaSpanConverter()) |
|
|
|
sp_proc.add_span_converter(DigitSpanConverter()) |
|
|
|
sp_proc.add_span_converter(TimeConverter()) |
|
|
|
|
|
|
|
|
|
|
|
char_proc = CWSCharSegProcessor('sentence', 'chars_list') |
|
|
|
|
|
|
@@ -37,7 +45,7 @@ tag_proc = CWSSegAppTagProcessor('sentence', 'tags') |
|
|
|
bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list') |
|
|
|
|
|
|
|
char_vocab_proc = VocabProcessor('chars_list') |
|
|
|
bigram_vocab_proc = VocabProcessor('bigrams_list') |
|
|
|
bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4) |
|
|
|
|
|
|
|
# 2. 使用processor |
|
|
|
fs2hs_proc(tr_dataset) |
|
|
@@ -74,6 +82,8 @@ bigram_index_proc(dev_dataset) |
|
|
|
seq_len_proc(dev_dataset) |
|
|
|
|
|
|
|
print("Finish preparing data.") |
|
|
|
print("Vocab size:{}, bigram size:{}.".format(char_vocab_proc.get_vocab_size(), bigram_vocab_proc.get_vocab_size())) |
|
|
|
|
|
|
|
|
|
|
|
# 3. 得到数据集可以用于训练了 |
|
|
|
from itertools import chain |
|
|
@@ -89,11 +99,10 @@ def flat_nested_list(nested_list): |
|
|
|
return list(chain(*nested_list)) |
|
|
|
|
|
|
|
def calculate_pre_rec_f1(model, batcher): |
|
|
|
true_ys, pred_ys, seq_lens = decode_iterator(model, batcher) |
|
|
|
refined_true_ys = refine_ys_on_seq_len(true_ys, seq_lens) |
|
|
|
refined_pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens) |
|
|
|
true_ys = flat_nested_list(refined_true_ys) |
|
|
|
pred_ys = flat_nested_list(refined_pred_ys) |
|
|
|
true_ys, pred_ys = decode_iterator(model, batcher) |
|
|
|
|
|
|
|
true_ys = flat_nested_list(true_ys) |
|
|
|
pred_ys = flat_nested_list(pred_ys) |
|
|
|
|
|
|
|
cor_num = 0 |
|
|
|
yp_wordnum = pred_ys.count(1) |
|
|
@@ -134,7 +143,10 @@ def decode_iterator(model, batcher): |
|
|
|
seq_lens.extend(list(seq_len)) |
|
|
|
model.train() |
|
|
|
|
|
|
|
return true_ys, pred_ys, seq_lens |
|
|
|
true_ys = refine_ys_on_seq_len(true_ys, seq_lens) |
|
|
|
pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens) |
|
|
|
|
|
|
|
return true_ys, pred_ys |
|
|
|
|
|
|
|
# TODO pretrain的embedding是怎么解决的? |
|
|
|
|
|
|
@@ -161,7 +173,7 @@ cws_model.cuda() |
|
|
|
|
|
|
|
num_epochs = 3 |
|
|
|
loss_fn = FocalLoss(class_num=tag_size) |
|
|
|
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01) |
|
|
|
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02) |
|
|
|
|
|
|
|
|
|
|
|
print_every = 50 |
|
|
@@ -179,6 +191,8 @@ for num_epoch in range(num_epochs): |
|
|
|
pbar.set_description_str('Epoch:%d' % (num_epoch + 1)) |
|
|
|
cws_model.train() |
|
|
|
for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1): |
|
|
|
optimizer.zero_grad() |
|
|
|
|
|
|
|
pred_dict = cws_model(batch_x) # B x L x tag_size |
|
|
|
|
|
|
|
seq_lens = pred_dict['seq_lens'] |
|
|
@@ -217,6 +231,7 @@ for num_epoch in range(num_epochs): |
|
|
|
} |
|
|
|
best_epoch = num_epoch |
|
|
|
|
|
|
|
cws_model.load_state_dict(best_state_dict) |
|
|
|
|
|
|
|
# 4. 组装需要存下的内容 |
|
|
|
pp = Pipeline() |
|
|
@@ -229,7 +244,7 @@ pp.add_processor(char_index_proc) |
|
|
|
pp.add_processor(bigram_index_proc) |
|
|
|
pp.add_processor(seq_len_proc) |
|
|
|
|
|
|
|
te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_test.txt' |
|
|
|
te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name) |
|
|
|
te_dataset = reader.load(te_filename) |
|
|
|
pp(te_dataset) |
|
|
|
|
|
|
|