Browse Source

finish POS tagging API

tags/v0.3.0^2
FengZiYjun 5 years ago
parent
commit
7ecd8c9c14
4 changed files with 69 additions and 33 deletions
  1. +39
    -18
      fastNLP/api/api.py
  2. +3
    -3
      fastNLP/core/metrics.py
  3. +9
    -5
      reproduction/chinese_word_segment/process/cws_processor.py
  4. +18
    -7
      reproduction/pos_tag_model/train_pos_tag.py

+ 39
- 18
fastNLP/api/api.py View File

@@ -17,8 +17,7 @@ from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.batch import Batch
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SeqLabelEvaluator2
from fastNLP.core.tester import Tester


# TODO add pretrain urls
model_urls = {
@@ -29,6 +28,7 @@ model_urls = {
class API:
def __init__(self):
self.pipeline = None
self._dict = None

def predict(self, *args, **kwargs):
raise NotImplementedError
@@ -38,8 +38,8 @@ class API:
_dict = torch.load(path, map_location='cpu')
else:
_dict = load_url(path, map_location='cpu')
self.pipeline = _dict['pipeline']
self._dict = _dict
self.pipeline = _dict['pipeline']
for processor in self.pipeline.pipeline:
if isinstance(processor, ModelProcessor):
processor.set_model_device(device)
@@ -48,8 +48,10 @@ class API:
class POS(API):
"""FastNLP API for Part-Of-Speech tagging.

"""
:param str model_path: the path to the model.
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch.

"""
def __init__(self, model_path=None, device='cpu'):
super(POS, self).__init__()
if model_path is None:
@@ -75,12 +77,28 @@ class POS(API):

# 2. 组建dataset
dataset = DataSet()
dataset.add_field('words', sentence_list)
dataset.add_field("words", sentence_list)

# 3. 使用pipeline
self.pipeline(dataset)

output = dataset['word_pos_output'].content
def decode_tags(ins):
pred_tags = ins["tag"]
chars = ins["words"]
words = []
start_idx = 0
for idx, tag in enumerate(pred_tags):
if tag[0] == "S":
words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
start_idx = idx + 1
elif tag[0] == "E":
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
start_idx = idx + 1
return words

dataset.apply(decode_tags, new_field_name="tag_output")

output = dataset.field_arrays["tag_output"].content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
@@ -98,6 +116,7 @@ class POS(API):
reader = ConllPOSReader()
te_dataset = reader.load(filepath)

"""
evaluator = SeqLabelEvaluator2('word_seq_origin_len')
end_tagidx_set = set()
tag_proc.vocab.build_vocab()
@@ -108,15 +127,16 @@ class POS(API):
end_tagidx_set.add(value)
evaluator.end_tagidx_set = end_tagidx_set

default_valid_args = {"batch_size": 64,
"use_cuda": True, "evaluator": evaluator}

pp(te_dataset)
te_dataset.set_target(truth=True)

default_valid_args = {"batch_size": 64,
"use_cuda": True, "evaluator": evaluator,
"model": model, "data": te_dataset}

tester = Tester(**default_valid_args)

test_result = tester.test(model, te_dataset)
test_result = tester.test()

f1 = round(test_result['F'] * 100, 2)
pre = round(test_result['P'] * 100, 2)
@@ -124,6 +144,7 @@ class POS(API):
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

return f1, pre, rec
"""


class CWS(API):
@@ -290,13 +311,13 @@ class Analyzer:


if __name__ == "__main__":
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
# pos = POS(device='cpu')
# s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' ,
# '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
# '那么这款无人机到底有多厉害?']
pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl'
pos = POS(pos_model_path, device='cpu')
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll'))
# print(pos.predict(s))
print(pos.predict(s))

# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
# cws = CWS(device='cpu')
@@ -306,9 +327,9 @@ if __name__ == "__main__":
# print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll'))
# print(cws.predict(s))

parser = Parser(device='cpu')
# parser = Parser(device='cpu')
# print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll'))
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
print(parser.predict(s))
# print(parser.predict(s))

+ 3
- 3
fastNLP/core/metrics.py View File

@@ -503,9 +503,9 @@ class SpanFPreRecMetric(MetricBase):
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
sum(self._false_negatives.values()),
sum(self._false_positives.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec
evaluate_result['f'] = round(f, 6)
evaluate_result['pre'] = round(pre, 6)
evaluate_result['rec'] = round(rec, 6)

if reset:
self._true_positives = defaultdict(int)


+ 9
- 5
reproduction/chinese_word_segment/process/cws_processor.py View File

@@ -1,10 +1,9 @@

import re


from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.api.processor import Processor
from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from reproduction.chinese_word_segment.process.span_converter import SpanConverter

_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'
@@ -239,7 +238,7 @@ class VocabIndexerProcessor(Processor):

"""
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None,
verbose=1):
verbose=1, is_input=True):
"""

:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作
@@ -247,12 +246,14 @@ class VocabIndexerProcessor(Processor):
:param min_freq: 创建的Vocabulary允许的单词最少出现次数.
:param max_size: 创建的Vocabulary允许的最大的单词数量
:param verbose: 0, 不输出任何信息;1,输出信息
:param bool is_input:
"""
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name)
self.min_freq = min_freq
self.max_size = max_size

self.verbose =verbose
self.is_input = is_input

def construct_vocab(self, *datasets):
"""
@@ -304,7 +305,10 @@ class VocabIndexerProcessor(Processor):
for dataset in to_index_datasets:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
new_field_name=self.new_added_field_name)
new_field_name=self.new_added_field_name, is_input=self.is_input)
# 只返回一个,infer时为了跟其他processor保持一致
if len(to_index_datasets) == 1:
return to_index_datasets[0]

def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab))


+ 18
- 7
reproduction/pos_tag_model/train_pos_tag.py View File

@@ -1,5 +1,12 @@
import os
import sys

import torch

# in order to run fastNLP without installation
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))


from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import SeqLenProcessor
from fastNLP.core.metrics import SpanFPreRecMetric
@@ -8,6 +15,7 @@ from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.models.sequence_modeling import AdvSeqLabel
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor

cfgfile = './pos_tag.cfg'
pickle_path = "save"
@@ -25,16 +33,16 @@ def train():
print(dataset)
print("dataset transformed")

vocab_proc = VocabIndexerProcessor("words")
tag_proc = VocabIndexerProcessor("tag")
seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len")
dataset.rename_field("tag", "truth")

vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq")
tag_proc = VocabIndexerProcessor("truth")
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True)

vocab_proc(dataset)
tag_proc(dataset)
seq_len_proc(dataset)

dataset.rename_field("words", "word_seq")
dataset.rename_field("tag", "truth")
dataset.set_input("word_seq", "word_seq_origin_len", "truth")
dataset.set_target("truth", "word_seq_origin_len")

@@ -53,11 +61,14 @@ def train():
target="truth",
seq_lens="word_seq_origin_len"),
dev_data=dataset, metric_key="f",
use_tqdm=False, use_cuda=True, print_every=20)
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save")
trainer.train()

# save model & pipeline
pp = Pipeline([vocab_proc, seq_len_proc])
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len")
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag")

pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag])
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
torch.save(save_dict, "model_pp.pkl")
print("pipeline saved")


Loading…
Cancel
Save