
bug fix for pipeline: Pipeline.process now iterates over processors directly; also adds a CWS API skeleton and a test-set evaluation block

tags/v0.2.0
yh_cc committed 5 years ago
commit 73ba3b5eec
3 changed files with 46 additions and 1 deletion
  1. fastNLP/api/cws.py (+32, -0)
  2. fastNLP/api/pipeline.py (+1, -1)
  3. reproduction/chinese_word_segment/train_context.py (+13, -0)

fastNLP/api/cws.py (new file, +32, -0)

@@ -0,0 +1,32 @@

from fastNLP.api.api import API
from fastNLP.core.dataset import DataSet


class CWS(API):
    def __init__(self, model_path='xxx'):
        super(CWS, self).__init__()
        self.load(model_path)

    def predict(self, sentence, pretrain=False):
        if not (hasattr(self, 'model') and hasattr(self, 'pipeline')):
            raise ValueError("You have to load model first. Or specify pretrain=True.")

        sentence_list = []
        # 1. check the type of sentence
        if isinstance(sentence, str):
            sentence_list.append(sentence)
        elif isinstance(sentence, list):
            sentence_list = sentence

        # 2. build the DataSet
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. run the pipeline
        self.pipeline(dataset)

        # 4. TODO: prediction should be handed to something iterator-like

        # 5. TODO: collect the results; decide whether indices need to be mapped back,
        #    and whether a post_process step is needed
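
For context, a minimal usage sketch of this new API, under stated assumptions: the checkpoint path is a placeholder, and since predict() stops at the TODOs above, the calls are shown for illustration only.

# Hypothetical usage of the CWS API above. The checkpoint path is an
# assumption, and predict() does not yet return results (see the TODOs).
from fastNLP.api.cws import CWS

cws = CWS(model_path='/path/to/cws_model.pkl')  # assumed checkpoint path
cws.predict('这是一个句子')                       # a single sentence
cws.predict(['第一句。', '第二句。'])              # or a list of sentences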

fastNLP/api/pipeline.py (+1, -1)

@@ -13,7 +13,7 @@ class Pipeline:
     def process(self, dataset):
         assert len(self.pipeline)!=0, "You need to add some processor first."
 
-        for proc_name, proc in self.pipeline:
+        for proc in self.pipeline:
             dataset = proc(dataset)
 
         return dataset
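
The one-line fix implies that add_processor stores each processor object directly rather than a (name, processor) pair, so the old loop's two-value unpacking raised a TypeError on the first iteration. A minimal sketch of the assumed Pipeline internals (not the full fastNLP class) that shows why:

# Sketch of the assumed Pipeline internals; Pipeline being callable is
# confirmed by the pp(te_dataset) call in train_context.py below.
class Pipeline:
    def __init__(self):
        self.pipeline = []

    def add_processor(self, processor):
        # Processors are appended directly, not as (name, processor) pairs,
        # which is why unpacking two values per element used to fail.
        self.pipeline.append(processor)

    def process(self, dataset):
        assert len(self.pipeline) != 0, "You need to add some processor first."
        for proc in self.pipeline:
            dataset = proc(dataset)
        return dataset

    def __call__(self, dataset):
        return self.process(dataset)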


reproduction/chinese_word_segment/train_context.py (+13, -0)

@@ -223,8 +223,21 @@ pp = Pipeline()
 pp.add_processor(fs2hs_proc)
 pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
+pp.add_processor(tag_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)
 pp.add_processor(bigram_index_proc)
 pp.add_processor(seq_len_proc)
+
+te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_test.txt'
+te_dataset = reader.load(te_filename)
+pp(te_dataset)
+
+batch_size = 64
+te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
+pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher)
+print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
+                                                 pre * 100,
+                                                 rec * 100))
+
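
calculate_pre_rec_f1 is not part of this commit; as an assumption about what it computes, segmentation precision/recall/F1 are conventionally taken over word spans decoded from the BMES tag sequences. A self-contained sketch of that metric (function names hypothetical):

# Hedged sketch of span-level CWS evaluation; this is an assumption about
# calculate_pre_rec_f1, whose implementation is not shown in this commit.
def extract_spans(tags):
    """Turn a BMES tag sequence into a set of (start, end) word spans."""
    spans, start = set(), 0
    for i, tag in enumerate(tags):
        if tag in ('B', 'S'):
            start = i
        if tag in ('E', 'S'):
            spans.add((start, i))
    return spans


def pre_rec_f1(pred_tags, gold_tags):
    """Span-level precision/recall/F1 over predicted vs. gold word spans."""
    pred, gold = extract_spans(pred_tags), extract_spans(gold_tags)
    correct = len(pred & gold)
    pre = correct / len(pred) if pred else 0.0
    rec = correct / len(gold) if gold else 0.0
    f1 = 2 * pre * rec / (pre + rec) if pre + rec else 0.0
    return pre, rec, f1


# e.g. pre_rec_f1(list('BESS'), list('BEBE')) -> (0.333..., 0.5, 0.4)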


