|
@@ -13,9 +13,6 @@ from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader |
|
|
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader |
|
|
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader |
|
|
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag |
|
|
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag |
|
|
from fastNLP.core.instance import Instance |
|
|
from fastNLP.core.instance import Instance |
|
|
from fastNLP.core.sampler import SequentialSampler |
|
|
|
|
|
from fastNLP.core.batch import Batch |
|
|
|
|
|
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 |
|
|
|
|
|
from fastNLP.api.pipeline import Pipeline |
|
|
from fastNLP.api.pipeline import Pipeline |
|
|
from fastNLP.core.metrics import SpanFPreRecMetric |
|
|
from fastNLP.core.metrics import SpanFPreRecMetric |
|
|
from fastNLP.api.processor import IndexerProcessor |
|
|
from fastNLP.api.processor import IndexerProcessor |
|
@@ -23,10 +20,9 @@ from fastNLP.api.processor import IndexerProcessor |
|
|
|
|
|
|
|
|
# TODO add pretrain urls |
|
|
# TODO add pretrain urls |
|
|
model_urls = { |
|
|
model_urls = { |
|
|
'cws': "http://123.206.98.91:8888/download/cws_crf-69e357c9.pkl" |
|
|
|
|
|
|
|
|
'cws': "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class API: |
|
|
class API: |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
self.pipeline = None |
|
|
self.pipeline = None |
|
@@ -174,12 +170,9 @@ class CWS(API): |
|
|
dataset.add_field('raw_sentence', sentence_list) |
|
|
dataset.add_field('raw_sentence', sentence_list) |
|
|
|
|
|
|
|
|
# 3. 使用pipeline |
|
|
# 3. 使用pipeline |
|
|
pipeline = self.pipeline.pipeline[:-3] + self.pipeline.pipeline[-2:] |
|
|
|
|
|
pp = Pipeline(pipeline) |
|
|
|
|
|
pp(dataset) |
|
|
|
|
|
# self.pipeline(dataset) |
|
|
|
|
|
|
|
|
self.pipeline(dataset) |
|
|
|
|
|
|
|
|
output = dataset['output'].content |
|
|
|
|
|
|
|
|
output = dataset.get_field('output').content |
|
|
if isinstance(content, str): |
|
|
if isinstance(content, str): |
|
|
return output[0] |
|
|
return output[0] |
|
|
elif isinstance(content, list): |
|
|
elif isinstance(content, list): |
|
@@ -324,7 +317,7 @@ class Analyzer: |
|
|
|
|
|
|
|
|
def test(self, filepath): |
|
|
def test(self, filepath): |
|
|
output_dict = {} |
|
|
output_dict = {} |
|
|
if self.seg: |
|
|
|
|
|
|
|
|
if self.cws: |
|
|
seg_output = self.cws.test(filepath) |
|
|
seg_output = self.cws.test(filepath) |
|
|
output_dict['seg'] = seg_output |
|
|
output_dict['seg'] = seg_output |
|
|
if self.pos: |
|
|
if self.pos: |
|
@@ -346,13 +339,14 @@ if __name__ == "__main__": |
|
|
# print(pos.test("/home/zyfeng/data/sample.conllx")) |
|
|
# print(pos.test("/home/zyfeng/data/sample.conllx")) |
|
|
# print(pos.predict(s)) |
|
|
# print(pos.predict(s)) |
|
|
|
|
|
|
|
|
cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' |
|
|
|
|
|
cws = CWS(model_path=cws_model_path, device='cuda:0') |
|
|
|
|
|
|
|
|
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf_1_11.pkl' |
|
|
|
|
|
cws = CWS(device='cpu') |
|
|
s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , |
|
|
s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' , |
|
|
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', |
|
|
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', |
|
|
'那么这款无人机到底有多厉害?'] |
|
|
'那么这款无人机到底有多厉害?'] |
|
|
# print(cws.test('/home/hyan/ctb3/test.conllx')) |
|
|
|
|
|
|
|
|
print(cws.test('/home/hyan/ctb3/test.conllx')) |
|
|
print(cws.predict(s)) |
|
|
print(cws.predict(s)) |
|
|
|
|
|
print(cws.predict('本品是一个抗酸抗胆汁的胃黏膜保护剂')) |
|
|
|
|
|
|
|
|
# parser = Parser(device='cpu') |
|
|
# parser = Parser(device='cpu') |
|
|
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) |
|
|
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) |
|
|