diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
index ce055b0e..ac0b8471 100644
--- a/reproduction/chinese_word_segment/train_context.py
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -209,9 +209,10 @@ torch.save(test_context_dict, 'models/test_context.pkl')
 
 # 4. Assemble the contents that need to be saved
 from fastNLP.api.processor import ModelProcessor
+from reproduction.chinese_word_segment.process.cws_processor import SegApp2OutputProcessor
 
 model_proc = ModelProcessor(cws_model)
-index2word_proc =
+output_proc = SegApp2OutputProcessor()
 
 pp = Pipeline()
 pp.add_processor(fs2hs_proc)
@@ -222,27 +223,15 @@ pp.add_processor(char_index_proc)
 pp.add_processor(bigram_index_proc)
 pp.add_processor(seq_len_proc)
+pp.add_processor(model_proc)
+pp.add_processor(output_proc)
 
-pp.add_processor()
-
-
-
-te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
-te_dataset = reader.load(te_filename)
-pp(te_dataset)
-
-batch_size = 64
-te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
-pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher)
-print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
-                                                 pre * 100,
-                                                 rec * 100))
 
 
 # TODO it seems the test pipeline and the infer pipeline need to be kept separate here
-test_context_dict = {'pipeline': pp,
+infer_context_dict = {'pipeline': pp,
                      'model': cws_model}
-torch.save(test_context_dict, 'models/test_context.pkl')
+torch.save(infer_context_dict, 'models/infer_context.pkl')
 
 
 # TODO still need to work out how to map the segmented output back onto the original text?
 
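
Not part of the diff: a minimal sketch of how the saved infer context might later be loaded and applied. It assumes the same `reader` object constructed earlier in `train_context.py` and uses a placeholder input path; the exact field names consumed and produced by the pipeline's processors are not shown in this change.

import torch

# Load the pipeline and model saved above; the keys match the dict written by this script.
infer_context = torch.load('models/infer_context.pkl')
pp = infer_context['pipeline']

# 'reader' is assumed to be the dataset reader defined earlier in this script,
# and the path is a placeholder for raw text to be segmented.
infer_dataset = reader.load('path/to/raw_text.txt')

# Applying the pipeline runs every processor in order, including model_proc
# (the model forward pass) and output_proc (SegApp2OutputProcessor), which
# presumably writes the segmented output back into the dataset in place.
pp(infer_dataset)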