@@ -17,8 +17,7 @@ from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.core.batch import Batch | |||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.core.metrics import SeqLabelEvaluator2 | |||
from fastNLP.core.tester import Tester | |||
# TODO add pretrain urls | |||
model_urls = { | |||
@@ -29,6 +28,7 @@ model_urls = { | |||
class API: | |||
def __init__(self): | |||
self.pipeline = None | |||
self._dict = None | |||
def predict(self, *args, **kwargs): | |||
raise NotImplementedError | |||
@@ -38,8 +38,8 @@ class API: | |||
_dict = torch.load(path, map_location='cpu') | |||
else: | |||
_dict = load_url(path, map_location='cpu') | |||
self.pipeline = _dict['pipeline'] | |||
self._dict = _dict | |||
self.pipeline = _dict['pipeline'] | |||
for processor in self.pipeline.pipeline: | |||
if isinstance(processor, ModelProcessor): | |||
processor.set_model_device(device) | |||
@@ -48,8 +48,10 @@ class API: | |||
class POS(API): | |||
"""FastNLP API for Part-Of-Speech tagging. | |||
""" | |||
:param str model_path: the path to the model. | |||
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. | |||
""" | |||
def __init__(self, model_path=None, device='cpu'): | |||
super(POS, self).__init__() | |||
if model_path is None: | |||
@@ -75,12 +77,28 @@ class POS(API): | |||
# 2. 组建dataset | |||
dataset = DataSet() | |||
dataset.add_field('words', sentence_list) | |||
dataset.add_field("words", sentence_list) | |||
# 3. 使用pipeline | |||
self.pipeline(dataset) | |||
output = dataset['word_pos_output'].content | |||
def decode_tags(ins): | |||
pred_tags = ins["tag"] | |||
chars = ins["words"] | |||
words = [] | |||
start_idx = 0 | |||
for idx, tag in enumerate(pred_tags): | |||
if tag[0] == "S": | |||
words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||
start_idx = idx + 1 | |||
elif tag[0] == "E": | |||
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||
start_idx = idx + 1 | |||
return words | |||
dataset.apply(decode_tags, new_field_name="tag_output") | |||
output = dataset.field_arrays["tag_output"].content | |||
if isinstance(content, str): | |||
return output[0] | |||
elif isinstance(content, list): | |||
@@ -98,6 +116,7 @@ class POS(API): | |||
reader = ConllPOSReader() | |||
te_dataset = reader.load(filepath) | |||
""" | |||
evaluator = SeqLabelEvaluator2('word_seq_origin_len') | |||
end_tagidx_set = set() | |||
tag_proc.vocab.build_vocab() | |||
@@ -108,15 +127,16 @@ class POS(API): | |||
end_tagidx_set.add(value) | |||
evaluator.end_tagidx_set = end_tagidx_set | |||
default_valid_args = {"batch_size": 64, | |||
"use_cuda": True, "evaluator": evaluator} | |||
pp(te_dataset) | |||
te_dataset.set_target(truth=True) | |||
default_valid_args = {"batch_size": 64, | |||
"use_cuda": True, "evaluator": evaluator, | |||
"model": model, "data": te_dataset} | |||
tester = Tester(**default_valid_args) | |||
test_result = tester.test(model, te_dataset) | |||
test_result = tester.test() | |||
f1 = round(test_result['F'] * 100, 2) | |||
pre = round(test_result['P'] * 100, 2) | |||
@@ -124,6 +144,7 @@ class POS(API): | |||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||
return f1, pre, rec | |||
""" | |||
class CWS(API): | |||
@@ -290,13 +311,13 @@ class Analyzer: | |||
if __name__ == "__main__": | |||
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl' | |||
# pos = POS(device='cpu') | |||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , | |||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
# '那么这款无人机到底有多厉害?'] | |||
pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl' | |||
pos = POS(pos_model_path, device='cpu') | |||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | |||
# print(pos.predict(s)) | |||
print(pos.predict(s)) | |||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | |||
# cws = CWS(device='cpu') | |||
@@ -306,9 +327,9 @@ if __name__ == "__main__": | |||
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | |||
# print(cws.predict(s)) | |||
parser = Parser(device='cpu') | |||
# parser = Parser(device='cpu') | |||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | |||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
print(parser.predict(s)) | |||
# print(parser.predict(s)) |
@@ -503,9 +503,9 @@ class SpanFPreRecMetric(MetricBase): | |||
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | |||
sum(self._false_negatives.values()), | |||
sum(self._false_positives.values())) | |||
evaluate_result['f'] = f | |||
evaluate_result['pre'] = pre | |||
evaluate_result['rec'] = rec | |||
evaluate_result['f'] = round(f, 6) | |||
evaluate_result['pre'] = round(pre, 6) | |||
evaluate_result['rec'] = round(rec, 6) | |||
if reset: | |||
self._true_positives = defaultdict(int) | |||
@@ -1,10 +1,9 @@ | |||
import re | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.api.processor import Processor | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from reproduction.chinese_word_segment.process.span_converter import SpanConverter | |||
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | |||
@@ -239,7 +238,7 @@ class VocabIndexerProcessor(Processor): | |||
""" | |||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||
verbose=1): | |||
verbose=1, is_input=True): | |||
""" | |||
:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | |||
@@ -247,12 +246,14 @@ class VocabIndexerProcessor(Processor): | |||
:param min_freq: 创建的Vocabulary允许的单词最少出现次数. | |||
:param max_size: 创建的Vocabulary允许的最大的单词数量 | |||
:param verbose: 0, 不输出任何信息;1,输出信息 | |||
:param bool is_input: | |||
""" | |||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||
self.min_freq = min_freq | |||
self.max_size = max_size | |||
self.verbose =verbose | |||
self.is_input = is_input | |||
def construct_vocab(self, *datasets): | |||
""" | |||
@@ -304,7 +305,10 @@ class VocabIndexerProcessor(Processor): | |||
for dataset in to_index_datasets: | |||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||
new_field_name=self.new_added_field_name) | |||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||
# 只返回一个,infer时为了跟其他processor保持一致 | |||
if len(to_index_datasets) == 1: | |||
return to_index_datasets[0] | |||
def set_vocab(self, vocab): | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||
@@ -1,5 +1,12 @@ | |||
import os | |||
import sys | |||
import torch | |||
# in order to run fastNLP without installation | |||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.api.processor import SeqLenProcessor | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
@@ -8,6 +15,7 @@ from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | |||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | |||
cfgfile = './pos_tag.cfg' | |||
pickle_path = "save" | |||
@@ -25,16 +33,16 @@ def train(): | |||
print(dataset) | |||
print("dataset transformed") | |||
vocab_proc = VocabIndexerProcessor("words") | |||
tag_proc = VocabIndexerProcessor("tag") | |||
seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len") | |||
dataset.rename_field("tag", "truth") | |||
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") | |||
tag_proc = VocabIndexerProcessor("truth") | |||
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) | |||
vocab_proc(dataset) | |||
tag_proc(dataset) | |||
seq_len_proc(dataset) | |||
dataset.rename_field("words", "word_seq") | |||
dataset.rename_field("tag", "truth") | |||
dataset.set_input("word_seq", "word_seq_origin_len", "truth") | |||
dataset.set_target("truth", "word_seq_origin_len") | |||
@@ -53,11 +61,14 @@ def train(): | |||
target="truth", | |||
seq_lens="word_seq_origin_len"), | |||
dev_data=dataset, metric_key="f", | |||
use_tqdm=False, use_cuda=True, print_every=20) | |||
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save") | |||
trainer.train() | |||
# save model & pipeline | |||
pp = Pipeline([vocab_proc, seq_len_proc]) | |||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") | |||
pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) | |||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | |||
torch.save(save_dict, "model_pp.pkl") | |||
print("pipeline saved") | |||