@@ -17,8 +17,7 @@ from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | ||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.core.metrics import SeqLabelEvaluator2 | |||||
from fastNLP.core.tester import Tester | |||||
# TODO add pretrain urls | # TODO add pretrain urls | ||||
model_urls = { | model_urls = { | ||||
@@ -29,6 +28,7 @@ model_urls = { | |||||
class API: | class API: | ||||
def __init__(self): | def __init__(self): | ||||
self.pipeline = None | self.pipeline = None | ||||
self._dict = None | |||||
def predict(self, *args, **kwargs): | def predict(self, *args, **kwargs): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -38,8 +38,8 @@ class API: | |||||
_dict = torch.load(path, map_location='cpu') | _dict = torch.load(path, map_location='cpu') | ||||
else: | else: | ||||
_dict = load_url(path, map_location='cpu') | _dict = load_url(path, map_location='cpu') | ||||
self.pipeline = _dict['pipeline'] | |||||
self._dict = _dict | self._dict = _dict | ||||
self.pipeline = _dict['pipeline'] | |||||
for processor in self.pipeline.pipeline: | for processor in self.pipeline.pipeline: | ||||
if isinstance(processor, ModelProcessor): | if isinstance(processor, ModelProcessor): | ||||
processor.set_model_device(device) | processor.set_model_device(device) | ||||
@@ -48,8 +48,10 @@ class API: | |||||
class POS(API): | class POS(API): | ||||
"""FastNLP API for Part-Of-Speech tagging. | """FastNLP API for Part-Of-Speech tagging. | ||||
""" | |||||
:param str model_path: the path to the model. | |||||
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. | |||||
""" | |||||
def __init__(self, model_path=None, device='cpu'): | def __init__(self, model_path=None, device='cpu'): | ||||
super(POS, self).__init__() | super(POS, self).__init__() | ||||
if model_path is None: | if model_path is None: | ||||
@@ -75,12 +77,28 @@ class POS(API): | |||||
# 2. 组建dataset | # 2. 组建dataset | ||||
dataset = DataSet() | dataset = DataSet() | ||||
dataset.add_field('words', sentence_list) | |||||
dataset.add_field("words", sentence_list) | |||||
# 3. 使用pipeline | # 3. 使用pipeline | ||||
self.pipeline(dataset) | self.pipeline(dataset) | ||||
output = dataset['word_pos_output'].content | |||||
def decode_tags(ins): | |||||
pred_tags = ins["tag"] | |||||
chars = ins["words"] | |||||
words = [] | |||||
start_idx = 0 | |||||
for idx, tag in enumerate(pred_tags): | |||||
if tag[0] == "S": | |||||
words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||||
start_idx = idx + 1 | |||||
elif tag[0] == "E": | |||||
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||||
start_idx = idx + 1 | |||||
return words | |||||
dataset.apply(decode_tags, new_field_name="tag_output") | |||||
output = dataset.field_arrays["tag_output"].content | |||||
if isinstance(content, str): | if isinstance(content, str): | ||||
return output[0] | return output[0] | ||||
elif isinstance(content, list): | elif isinstance(content, list): | ||||
@@ -98,6 +116,7 @@ class POS(API): | |||||
reader = ConllPOSReader() | reader = ConllPOSReader() | ||||
te_dataset = reader.load(filepath) | te_dataset = reader.load(filepath) | ||||
""" | |||||
evaluator = SeqLabelEvaluator2('word_seq_origin_len') | evaluator = SeqLabelEvaluator2('word_seq_origin_len') | ||||
end_tagidx_set = set() | end_tagidx_set = set() | ||||
tag_proc.vocab.build_vocab() | tag_proc.vocab.build_vocab() | ||||
@@ -108,15 +127,16 @@ class POS(API): | |||||
end_tagidx_set.add(value) | end_tagidx_set.add(value) | ||||
evaluator.end_tagidx_set = end_tagidx_set | evaluator.end_tagidx_set = end_tagidx_set | ||||
default_valid_args = {"batch_size": 64, | |||||
"use_cuda": True, "evaluator": evaluator} | |||||
pp(te_dataset) | pp(te_dataset) | ||||
te_dataset.set_target(truth=True) | te_dataset.set_target(truth=True) | ||||
default_valid_args = {"batch_size": 64, | |||||
"use_cuda": True, "evaluator": evaluator, | |||||
"model": model, "data": te_dataset} | |||||
tester = Tester(**default_valid_args) | tester = Tester(**default_valid_args) | ||||
test_result = tester.test(model, te_dataset) | |||||
test_result = tester.test() | |||||
f1 = round(test_result['F'] * 100, 2) | f1 = round(test_result['F'] * 100, 2) | ||||
pre = round(test_result['P'] * 100, 2) | pre = round(test_result['P'] * 100, 2) | ||||
@@ -124,6 +144,7 @@ class POS(API): | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | ||||
return f1, pre, rec | return f1, pre, rec | ||||
""" | |||||
class CWS(API): | class CWS(API): | ||||
@@ -290,13 +311,13 @@ class Analyzer: | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl' | |||||
# pos = POS(device='cpu') | |||||
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , | |||||
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
# '那么这款无人机到底有多厉害?'] | |||||
pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl' | |||||
pos = POS(pos_model_path, device='cpu') | |||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | # print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | ||||
# print(pos.predict(s)) | |||||
print(pos.predict(s)) | |||||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | ||||
# cws = CWS(device='cpu') | # cws = CWS(device='cpu') | ||||
@@ -306,9 +327,9 @@ if __name__ == "__main__": | |||||
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | # print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) | ||||
# print(cws.predict(s)) | # print(cws.predict(s)) | ||||
parser = Parser(device='cpu') | |||||
# parser = Parser(device='cpu') | |||||
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | # print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) | ||||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | ||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | ||||
'那么这款无人机到底有多厉害?'] | '那么这款无人机到底有多厉害?'] | ||||
print(parser.predict(s)) | |||||
# print(parser.predict(s)) |
@@ -503,9 +503,9 @@ class SpanFPreRecMetric(MetricBase): | |||||
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | ||||
sum(self._false_negatives.values()), | sum(self._false_negatives.values()), | ||||
sum(self._false_positives.values())) | sum(self._false_positives.values())) | ||||
evaluate_result['f'] = f | |||||
evaluate_result['pre'] = pre | |||||
evaluate_result['rec'] = rec | |||||
evaluate_result['f'] = round(f, 6) | |||||
evaluate_result['pre'] = round(pre, 6) | |||||
evaluate_result['rec'] = round(rec, 6) | |||||
if reset: | if reset: | ||||
self._true_positives = defaultdict(int) | self._true_positives = defaultdict(int) | ||||
@@ -1,10 +1,9 @@ | |||||
import re | import re | ||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.processor import Processor | from fastNLP.api.processor import Processor | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from reproduction.chinese_word_segment.process.span_converter import SpanConverter | from reproduction.chinese_word_segment.process.span_converter import SpanConverter | ||||
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' | ||||
@@ -239,7 +238,7 @@ class VocabIndexerProcessor(Processor): | |||||
""" | """ | ||||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | ||||
verbose=1): | |||||
verbose=1, is_input=True): | |||||
""" | """ | ||||
:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | :param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | ||||
@@ -247,12 +246,14 @@ class VocabIndexerProcessor(Processor): | |||||
:param min_freq: 创建的Vocabulary允许的单词最少出现次数. | :param min_freq: 创建的Vocabulary允许的单词最少出现次数. | ||||
:param max_size: 创建的Vocabulary允许的最大的单词数量 | :param max_size: 创建的Vocabulary允许的最大的单词数量 | ||||
:param verbose: 0, 不输出任何信息;1,输出信息 | :param verbose: 0, 不输出任何信息;1,输出信息 | ||||
:param bool is_input: | |||||
""" | """ | ||||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
self.max_size = max_size | self.max_size = max_size | ||||
self.verbose =verbose | self.verbose =verbose | ||||
self.is_input = is_input | |||||
def construct_vocab(self, *datasets): | def construct_vocab(self, *datasets): | ||||
""" | """ | ||||
@@ -304,7 +305,10 @@ class VocabIndexerProcessor(Processor): | |||||
for dataset in to_index_datasets: | for dataset in to_index_datasets: | ||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | ||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | ||||
new_field_name=self.new_added_field_name) | |||||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||||
# 只返回一个,infer时为了跟其他processor保持一致 | |||||
if len(to_index_datasets) == 1: | |||||
return to_index_datasets[0] | |||||
def set_vocab(self, vocab): | def set_vocab(self, vocab): | ||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | ||||
@@ -1,5 +1,12 @@ | |||||
import os | |||||
import sys | |||||
import torch | import torch | ||||
# in order to run fastNLP without installation | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.api.processor import SeqLenProcessor | from fastNLP.api.processor import SeqLenProcessor | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
@@ -8,6 +15,7 @@ from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | ||||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | ||||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | |||||
cfgfile = './pos_tag.cfg' | cfgfile = './pos_tag.cfg' | ||||
pickle_path = "save" | pickle_path = "save" | ||||
@@ -25,16 +33,16 @@ def train(): | |||||
print(dataset) | print(dataset) | ||||
print("dataset transformed") | print("dataset transformed") | ||||
vocab_proc = VocabIndexerProcessor("words") | |||||
tag_proc = VocabIndexerProcessor("tag") | |||||
seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len") | |||||
dataset.rename_field("tag", "truth") | |||||
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") | |||||
tag_proc = VocabIndexerProcessor("truth") | |||||
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) | |||||
vocab_proc(dataset) | vocab_proc(dataset) | ||||
tag_proc(dataset) | tag_proc(dataset) | ||||
seq_len_proc(dataset) | seq_len_proc(dataset) | ||||
dataset.rename_field("words", "word_seq") | |||||
dataset.rename_field("tag", "truth") | |||||
dataset.set_input("word_seq", "word_seq_origin_len", "truth") | dataset.set_input("word_seq", "word_seq_origin_len", "truth") | ||||
dataset.set_target("truth", "word_seq_origin_len") | dataset.set_target("truth", "word_seq_origin_len") | ||||
@@ -53,11 +61,14 @@ def train(): | |||||
target="truth", | target="truth", | ||||
seq_lens="word_seq_origin_len"), | seq_lens="word_seq_origin_len"), | ||||
dev_data=dataset, metric_key="f", | dev_data=dataset, metric_key="f", | ||||
use_tqdm=False, use_cuda=True, print_every=20) | |||||
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save") | |||||
trainer.train() | trainer.train() | ||||
# save model & pipeline | # save model & pipeline | ||||
pp = Pipeline([vocab_proc, seq_len_proc]) | |||||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||||
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") | |||||
pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) | |||||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | ||||
torch.save(save_dict, "model_pp.pkl") | torch.save(save_dict, "model_pp.pkl") | ||||
print("pipeline saved") | print("pipeline saved") | ||||