From 73ba3b5eec62583475baaf85fa6c461a3aa03e5c Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sat, 10 Nov 2018 15:17:58 +0800
Subject: [PATCH] bug fix for pipeline

---
 fastNLP/api/cws.py                            | 32 +++++++++++++++++++
 fastNLP/api/pipeline.py                       |  2 +-
 .../chinese_word_segment/train_context.py     | 13 ++++++++
 3 files changed, 46 insertions(+), 1 deletion(-)
 create mode 100644 fastNLP/api/cws.py

diff --git a/fastNLP/api/cws.py b/fastNLP/api/cws.py
new file mode 100644
index 00000000..ea6f96e6
--- /dev/null
+++ b/fastNLP/api/cws.py
@@ -0,0 +1,32 @@
+
+
+from fastNLP.api.api import API
+from fastNLP.core.dataset import DataSet
+
+class CWS(API):
+    def __init__(self, model_path='xxx'):
+        super(CWS, self).__init__()
+        self.load(model_path)
+
+    def predict(self, sentence, pretrain=False):
+
+        if not hasattr(self, 'model') or not hasattr(self, 'pipeline'):
+            raise ValueError("You have to load a model first, or specify pretrain=True.")
+
+        sentence_list = []
+        # 1. check the type of sentence
+        if isinstance(sentence, str):
+            sentence_list.append(sentence)
+        elif isinstance(sentence, list):
+            sentence_list = sentence
+
+        # 2. build the DataSet
+        dataset = DataSet()
+        dataset.add_field('raw_sentence', sentence_list)
+
+        # 3. run the pipeline
+        self.pipeline(dataset)
+
+        # 4. TODO: hand the dataset to an iterator-like object that runs the model and predicts the result
+
+        # 5. TODO: collect the result; decide whether it needs to be mapped back, and what post_process should do
diff --git a/fastNLP/api/pipeline.py b/fastNLP/api/pipeline.py
index 745c8874..0edceb19 100644
--- a/fastNLP/api/pipeline.py
+++ b/fastNLP/api/pipeline.py
@@ -13,7 +13,7 @@ class Pipeline:
 
     def process(self, dataset):
         assert len(self.pipeline)!=0, "You need to add some processor first."
-        for proc_name, proc in self.pipeline:
+        for proc in self.pipeline:
             dataset = proc(dataset)
         return dataset
 
diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
index e43f8a24..184380e0 100644
--- a/reproduction/chinese_word_segment/train_context.py
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -223,8 +223,21 @@ pp = Pipeline()
 pp.add_processor(fs2hs_proc)
 pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
+pp.add_processor(tag_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)
 pp.add_processor(bigram_index_proc)
 pp.add_processor(seq_len_proc)
+te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_test.txt'
+te_dataset = reader.load(te_filename)
+pp(te_dataset)
+
+batch_size = 64
+te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
+pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher)
+print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
+                                                 pre * 100,
+                                                 rec * 100))
+
+
 
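
Note on the pipeline.py fix: add_processor appends bare processor objects to
self.pipeline, so the old loop `for proc_name, proc in self.pipeline` tried to
tuple-unpack every element and crashed; `for proc in self.pipeline` matches
what the list actually holds. Below is a minimal, self-contained sketch of
that contract. UppercaseProc and the toy dataset (a plain list of strings) are
hypothetical stand-ins, not fastNLP classes.

class Pipeline:
    def __init__(self):
        self.pipeline = []  # flat list of processors, not (name, processor) pairs

    def add_processor(self, processor):
        self.pipeline.append(processor)

    def __call__(self, dataset):
        return self.process(dataset)

    def process(self, dataset):
        assert len(self.pipeline) != 0, "You need to add some processor first."
        for proc in self.pipeline:  # each element is a processor; nothing to unpack
            dataset = proc(dataset)
        return dataset

class UppercaseProc:
    # A processor is just a callable that maps a dataset to a dataset.
    def __call__(self, dataset):
        return [sentence.upper() for sentence in dataset]

pp = Pipeline()
pp.add_processor(UppercaseProc())
print(pp(['raw sentence']))  # -> ['RAW SENTENCE']

Because each processor is dataset-in/dataset-out, the chain composes left to
right in the order the processors were added.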
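
Note on the new CWS API: the guard in predict() was inverted as originally
written (it raised exactly when both model and pipeline were present), so it
is flipped here to raise only when they are missing. Steps 4 and 5 are still
TODOs, so predict() does not return segmented output yet. A hedged usage
sketch, assuming a trained model has been saved somewhere; the path below is
a placeholder, not a real checkpoint:

from fastNLP.api.cws import CWS

cws = CWS(model_path='/path/to/cws_model.pkl')  # placeholder path; __init__ calls self.load()

cws.predict('这是一个测试句子')        # a single raw sentence
cws.predict(['第一句。', '第二句。'])  # or a list of raw sentences
# Until TODOs 4 and 5 land, predict() only builds the DataSet and runs the
# preprocessing pipeline; it does not yet return the segmented words.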
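
Note on the evaluation block in train_context.py: calculate_pre_rec_f1 takes
the model and a Batch iterator and returns (pre, rec, f1). For reference, the
printed numbers follow the standard word-level definitions; the helper below
is a hypothetical stand-in that computes them from gold and predicted word
spans, not the actual fastNLP implementation:

def pre_rec_f1_from_spans(pred_spans, gold_spans):
    # Compare predicted word spans against gold spans.
    pred, gold = set(pred_spans), set(gold_spans)
    tp = len(pred & gold)                      # correctly segmented words
    pre = tp / len(pred) if pred else 0.0      # TP / (TP + FP)
    rec = tp / len(gold) if gold else 0.0      # TP / (TP + FN)
    f1 = 2 * pre * rec / (pre + rec) if pre + rec else 0.0
    return pre, rec, f1

# Gold "他 来自 上海" vs. prediction "他 来 自 上海", as (start, end) character spans:
print(pre_rec_f1_from_spans([(0, 1), (1, 2), (2, 3), (3, 5)],
                            [(0, 1), (1, 3), (3, 5)]))  # (0.5, 0.667, 0.571)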