diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 4a0bbbae..6389686a 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -13,9 +13,6 @@ from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag from fastNLP.core.instance import Instance -from fastNLP.core.sampler import SequentialSampler -from fastNLP.core.batch import Batch -from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 from fastNLP.api.pipeline import Pipeline from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.api.processor import IndexerProcessor @@ -23,10 +20,9 @@ from fastNLP.api.processor import IndexerProcessor # TODO add pretrain urls model_urls = { - 'cws': "http://123.206.98.91:8888/download/cws_crf-69e357c9.pkl" + 'cws': "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl" } - class API: def __init__(self): self.pipeline = None @@ -174,12 +170,9 @@ class CWS(API): dataset.add_field('raw_sentence', sentence_list) # 3. 使用pipeline - pipeline = self.pipeline.pipeline[:-3] + self.pipeline.pipeline[-2:] - pp = Pipeline(pipeline) - pp(dataset) - # self.pipeline(dataset) + self.pipeline(dataset) - output = dataset['output'].content + output = dataset.get_field('output').content if isinstance(content, str): return output[0] elif isinstance(content, list): @@ -324,7 +317,7 @@ class Analyzer: def test(self, filepath): output_dict = {} - if self.seg: + if self.cws: seg_output = self.cws.test(filepath) output_dict['seg'] = seg_output if self.pos: @@ -346,13 +339,14 @@ if __name__ == "__main__": # print(pos.test("/home/zyfeng/data/sample.conllx")) # print(pos.predict(s)) - cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' - cws = CWS(model_path=cws_model_path, device='cuda:0') + # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf_1_11.pkl' + cws = CWS(device='cpu') s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', '那么这款无人机到底有多厉害?'] - # print(cws.test('/home/hyan/ctb3/test.conllx')) + print(cws.test('/home/hyan/ctb3/test.conllx')) print(cws.predict(s)) + print(cws.predict('本品是一个抗酸抗胆汁的胃黏膜保护剂')) # parser = Parser(device='cpu') # print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index afa8775b..48513699 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -270,8 +270,8 @@ class ModelProcessor(Processor): for idx, seq_len in enumerate(seq_lens): tmp_batch.append(value[idx, :seq_len]) batch_output[key].extend(tmp_batch) - - batch_output[self.seq_len_field_name].extend(seq_lens) + if not self.seq_len_field_name in prediction: + batch_output[self.seq_len_field_name].extend(seq_lens) # TODO 当前的实现会导致之后的processor需要知道model输出的output的key是什么 for field_name, fields in batch_output.items(): diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py index 3f3bdf18..9e57d35a 100644 --- a/reproduction/chinese_word_segment/process/cws_processor.py +++ b/reproduction/chinese_word_segment/process/cws_processor.py @@ -238,7 +238,7 @@ class VocabIndexerProcessor(Processor): """ def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, - verbose=1, is_input=True): + verbose=0, is_input=True): """ :param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 @@ -320,6 +320,15 @@ class VocabIndexerProcessor(Processor): def get_vocab_size(self): return len(self.vocab) + def set_verbose(self, verbose): + """ + 设置processor verbose状态。 + + :param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 + :return: + """ + self.verbose = verbose + class VocabProcessor(Processor): def __init__(self, field_name, min_freq=1, max_size=None): diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py index 68d431d3..b0d238b3 100644 --- a/reproduction/chinese_word_segment/train_context.py +++ b/reproduction/chinese_word_segment/train_context.py @@ -139,14 +139,16 @@ from fastNLP.api.processor import ModelProcessor from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor model_proc = ModelProcessor(cws_model) -output_proc = BMES2OutputProcessor(tag_field_name='pred') +output_proc = BMES2OutputProcessor(chars_field_name='chars_lst', tag_field_name='pred') pp = Pipeline() pp.add_processor(fs2hs_proc) # pp.add_processor(sp_proc) pp.add_processor(char_proc) pp.add_processor(bigram_proc) +char_vocab_proc.set_verbose(0) pp.add_processor(char_vocab_proc) +bigram_vocab_proc.set_verbose(0) pp.add_processor(bigram_vocab_proc) pp.add_processor(seq_len_proc)