
Support the high-level API for CWS

tags/v0.3.0^2
yh, 5 years ago
parent commit 145125feb4
4 changed files with 23 additions and 18 deletions:
  1. fastNLP/api/api.py (+8 / -14)
  2. fastNLP/api/processor.py (+2 / -2)
  3. reproduction/chinese_word_segment/process/cws_processor.py (+10 / -1)
  4. reproduction/chinese_word_segment/train_context.py (+3 / -1)

fastNLP/api/api.py (+8 / -14)

@@ -13,9 +13,6 @@ from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
 from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
 from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
 from fastNLP.core.instance import Instance
-from fastNLP.core.sampler import SequentialSampler
-from fastNLP.core.batch import Batch
-from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.core.metrics import SpanFPreRecMetric
 from fastNLP.api.processor import IndexerProcessor
@@ -23,10 +20,9 @@ from fastNLP.api.processor import IndexerProcessor

 # TODO add pretrain urls
 model_urls = {
-    'cws': "http://123.206.98.91:8888/download/cws_crf-69e357c9.pkl"
+    'cws': "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl"
 }
 
-
 class API:
     def __init__(self):
         self.pipeline = None
@@ -174,12 +170,9 @@ class CWS(API):
         dataset.add_field('raw_sentence', sentence_list)
 
         # 3. run the pipeline
-        pipeline = self.pipeline.pipeline[:-3] + self.pipeline.pipeline[-2:]
-        pp = Pipeline(pipeline)
-        pp(dataset)
-        # self.pipeline(dataset)
+        self.pipeline(dataset)
 
-        output = dataset['output'].content
+        output = dataset.get_field('output').content
         if isinstance(content, str):
             return output[0]
         elif isinstance(content, list):
@@ -324,7 +317,7 @@ class Analyzer:

     def test(self, filepath):
         output_dict = {}
-        if self.seg:
+        if self.cws:
             seg_output = self.cws.test(filepath)
             output_dict['seg'] = seg_output
         if self.pos:
@@ -346,13 +339,14 @@ if __name__ == "__main__":
     # print(pos.test("/home/zyfeng/data/sample.conllx"))
     # print(pos.predict(s))
 
-    cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
-    cws = CWS(model_path=cws_model_path, device='cuda:0')
+    # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf_1_11.pkl'
+    cws = CWS(device='cpu')
     s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' ,
          '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
          '那么这款无人机到底有多厉害?']
-    # print(cws.test('/home/hyan/ctb3/test.conllx'))
+    print(cws.test('/home/hyan/ctb3/test.conllx'))
     print(cws.predict(s))
+    print(cws.predict('本品是一个抗酸抗胆汁的胃黏膜保护剂'))
 
     # parser = Parser(device='cpu')
     # print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll'))
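
With these changes the high-level CWS API no longer needs a local model path: the pretrained model is resolved through model_urls and the full self.pipeline is applied to the DataSet. A minimal usage sketch (not part of the commit), assuming fastNLP and its reproduction/ directory are importable and the download URL above is still reachable:

    from fastNLP.api.api import CWS

    # Without model_path, the pretrained CWS model is fetched via model_urls.
    cws = CWS(device='cpu')

    # predict() accepts a single string or a list of strings and returns the
    # segmented result(s) read from the dataset's 'output' field.
    print(cws.predict('本品是一个抗酸抗胆汁的胃黏膜保护剂'))
    print(cws.predict(['那么这款无人机到底有多厉害?']))

    # test() evaluates segmentation quality on a CoNLL-style file, e.g.:
    # print(cws.test('/path/to/test.conllx'))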


fastNLP/api/processor.py (+2 / -2)

@@ -270,8 +270,8 @@ class ModelProcessor(Processor):
                         for idx, seq_len in enumerate(seq_lens):
                             tmp_batch.append(value[idx, :seq_len])
                         batch_output[key].extend(tmp_batch)
-                batch_output[self.seq_len_field_name].extend(seq_lens)
+                if not self.seq_len_field_name in prediction:
+                    batch_output[self.seq_len_field_name].extend(seq_lens)
 
         # TODO: with the current implementation, downstream processors need to know the keys of the model's output
         for field_name, fields in batch_output.items():
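
The added guard keeps the sequence-length column from being written twice when the model's prediction dict already contains a field named like seq_len_field_name. A toy illustration of the pattern (standalone, with made-up field names, not fastNLP code):

    from collections import defaultdict

    seq_len_field_name = 'seq_lens'      # hypothetical name of the length field
    batch_output = defaultdict(list)

    # Pretend model output for one batch: it already reports its own lengths.
    prediction = {'pred': [[1, 2, 3], [4, 5]], 'seq_lens': [3, 2]}
    seq_lens = [3, 2]                    # lengths taken from the input batch

    for key, value in prediction.items():
        batch_output[key].extend(value)

    # Only fall back to the input-side lengths when the model did not emit them
    # itself; otherwise 'seq_lens' would be extended twice per batch.
    if seq_len_field_name not in prediction:
        batch_output[seq_len_field_name].extend(seq_lens)

    print(dict(batch_output))            # {'pred': [...], 'seq_lens': [3, 2]}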


reproduction/chinese_word_segment/process/cws_processor.py (+10 / -1)

@@ -238,7 +238,7 @@ class VocabIndexerProcessor(Processor):

"""
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None,
verbose=1, is_input=True):
verbose=0, is_input=True):
"""

:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作
@@ -320,6 +320,15 @@ class VocabIndexerProcessor(Processor):
     def get_vocab_size(self):
         return len(self.vocab)
 
+    def set_verbose(self, verbose):
+        """
+        Set the verbosity of this processor.
+
+        :param verbose: int; 0: print nothing, 1: print vocab information.
+        :return:
+        """
+        self.verbose = verbose
+
 
 class VocabProcessor(Processor):
     def __init__(self, field_name, min_freq=1, max_size=None):
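
Because verbose now defaults to 0 and can be changed afterwards, a VocabIndexerProcessor can be made chatty while its vocabulary is built and silenced again before being reused for inference. A short sketch (the field name 'chars_lst' is taken from train_context.py below; the process() call is commented out because it needs a prepared DataSet):

    from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor

    # verbose defaults to 0, so vocabulary construction stays quiet by default.
    char_vocab_proc = VocabIndexerProcessor('chars_lst')

    # Turn vocab reporting on while the vocabulary is built from training data ...
    char_vocab_proc.set_verbose(1)
    # char_vocab_proc.process(train_data)   # train_data: a DataSet with a 'chars_lst' field

    # ... and switch it off again before wiring the processor into an inference Pipeline.
    char_vocab_proc.set_verbose(0)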



reproduction/chinese_word_segment/train_context.py (+3 / -1)

@@ -139,14 +139,16 @@ from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor
 
 model_proc = ModelProcessor(cws_model)
-output_proc = BMES2OutputProcessor(tag_field_name='pred')
+output_proc = BMES2OutputProcessor(chars_field_name='chars_lst', tag_field_name='pred')
 
 pp = Pipeline()
 pp.add_processor(fs2hs_proc)
 # pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
 pp.add_processor(bigram_proc)
+char_vocab_proc.set_verbose(0)
 pp.add_processor(char_vocab_proc)
+bigram_vocab_proc.set_verbose(0)
 pp.add_processor(bigram_vocab_proc)
 pp.add_processor(seq_len_proc)
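
For context, this assembled Pipeline is what CWS.predict in fastNLP/api/api.py replays at inference time: it is run over a DataSet built from raw sentences, and the segmentation is read back from the 'output' field written by BMES2OutputProcessor. A rough sketch of that downstream call, assuming a pipeline object pp like the one above has been loaded together with its fitted processors:

    from fastNLP.core.dataset import DataSet

    sentences = ['那么这款无人机到底有多厉害?']

    dataset = DataSet()
    dataset.add_field('raw_sentence', sentences)   # same entry field CWS.predict uses

    pp(dataset)                                    # run every processor in order
    print(dataset.get_field('output').content)     # segmented sentences from the output processor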


