Browse Source

增加infer的pipeline

tags/v0.2.0
yh 6 years ago
parent
commit
9fc20ac7b8
1 changed files with 6 additions and 17 deletions
  1. +6
    -17
      reproduction/chinese_word_segment/train_context.py

+ 6
- 17
reproduction/chinese_word_segment/train_context.py View File

@@ -209,9 +209,10 @@ torch.save(test_context_dict, 'models/test_context.pkl')
# 4. 组装需要存下的内容 # 4. 组装需要存下的内容


from fastNLP.api.processor import ModelProcessor from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.process.cws_processor import SegApp2OutputProcessor


model_proc = ModelProcessor(cws_model) model_proc = ModelProcessor(cws_model)
index2word_proc =
output_proc = SegApp2OutputProcessor()


pp = Pipeline() pp = Pipeline()
pp.add_processor(fs2hs_proc) pp.add_processor(fs2hs_proc)
@@ -222,27 +223,15 @@ pp.add_processor(char_index_proc)
pp.add_processor(bigram_index_proc) pp.add_processor(bigram_index_proc)
pp.add_processor(seq_len_proc) pp.add_processor(seq_len_proc)


pp.add_processor(model_proc)
pp.add_processor(output_proc)


pp.add_processor()



te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_dataset = reader.load(te_filename)
pp(te_dataset)

batch_size = 64
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
pre * 100,
rec * 100))


# TODO 这里貌似需要区分test pipeline与infer pipeline # TODO 这里貌似需要区分test pipeline与infer pipeline


test_context_dict = {'pipeline': pp,
infer_context_dict = {'pipeline': pp,
'model': cws_model} 'model': cws_model}
torch.save(test_context_dict, 'models/test_context.pkl')
torch.save(infer_context_dict, 'models/infer_context.pkl')




# TODO 还需要考虑如何替换回原文的问题? # TODO 还需要考虑如何替换回原文的问题?


Loading…
Cancel
Save