From 7ecd8c9c146a23867381b3e89c2502ecd7f4f3e7 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Mon, 7 Jan 2019 19:45:19 +0800 Subject: [PATCH] finish POS tagging API --- fastNLP/api/api.py | 57 +++++++++++++------ fastNLP/core/metrics.py | 6 +- .../process/cws_processor.py | 14 +++-- reproduction/pos_tag_model/train_pos_tag.py | 25 +++++--- 4 files changed, 69 insertions(+), 33 deletions(-) diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index bd14197a..2e6cc247 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -17,8 +17,7 @@ from fastNLP.core.sampler import SequentialSampler from fastNLP.core.batch import Batch from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 from fastNLP.api.pipeline import Pipeline -from fastNLP.core.metrics import SeqLabelEvaluator2 -from fastNLP.core.tester import Tester + # TODO add pretrain urls model_urls = { @@ -29,6 +28,7 @@ model_urls = { class API: def __init__(self): self.pipeline = None + self._dict = None def predict(self, *args, **kwargs): raise NotImplementedError @@ -38,8 +38,8 @@ class API: _dict = torch.load(path, map_location='cpu') else: _dict = load_url(path, map_location='cpu') - self.pipeline = _dict['pipeline'] self._dict = _dict + self.pipeline = _dict['pipeline'] for processor in self.pipeline.pipeline: if isinstance(processor, ModelProcessor): processor.set_model_device(device) @@ -48,8 +48,10 @@ class API: class POS(API): """FastNLP API for Part-Of-Speech tagging. - """ + :param str model_path: the path to the model. + :param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. + """ def __init__(self, model_path=None, device='cpu'): super(POS, self).__init__() if model_path is None: @@ -75,12 +77,28 @@ class POS(API): # 2. 组建dataset dataset = DataSet() - dataset.add_field('words', sentence_list) + dataset.add_field("words", sentence_list) # 3. 使用pipeline self.pipeline(dataset) - output = dataset['word_pos_output'].content + def decode_tags(ins): + pred_tags = ins["tag"] + chars = ins["words"] + words = [] + start_idx = 0 + for idx, tag in enumerate(pred_tags): + if tag[0] == "S": + words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) + start_idx = idx + 1 + elif tag[0] == "E": + words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) + start_idx = idx + 1 + return words + + dataset.apply(decode_tags, new_field_name="tag_output") + + output = dataset.field_arrays["tag_output"].content if isinstance(content, str): return output[0] elif isinstance(content, list): @@ -98,6 +116,7 @@ class POS(API): reader = ConllPOSReader() te_dataset = reader.load(filepath) + """ evaluator = SeqLabelEvaluator2('word_seq_origin_len') end_tagidx_set = set() tag_proc.vocab.build_vocab() @@ -108,15 +127,16 @@ class POS(API): end_tagidx_set.add(value) evaluator.end_tagidx_set = end_tagidx_set - default_valid_args = {"batch_size": 64, - "use_cuda": True, "evaluator": evaluator} - pp(te_dataset) te_dataset.set_target(truth=True) + default_valid_args = {"batch_size": 64, + "use_cuda": True, "evaluator": evaluator, + "model": model, "data": te_dataset} + tester = Tester(**default_valid_args) - test_result = tester.test(model, te_dataset) + test_result = tester.test() f1 = round(test_result['F'] * 100, 2) pre = round(test_result['P'] * 100, 2) @@ -124,6 +144,7 @@ class POS(API): # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) return f1, pre, rec + """ class CWS(API): @@ -290,13 +311,13 @@ class Analyzer: if __name__ == "__main__": - # pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl' - # pos = POS(device='cpu') - # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' , - # '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', - # '那么这款无人机到底有多厉害?'] + pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl' + pos = POS(pos_model_path, device='cpu') + s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', + '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', + '那么这款无人机到底有多厉害?'] # print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) - # print(pos.predict(s)) + print(pos.predict(s)) # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' # cws = CWS(device='cpu') @@ -306,9 +327,9 @@ if __name__ == "__main__": # print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll')) # print(cws.predict(s)) - parser = Parser(device='cpu') + # parser = Parser(device='cpu') # print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll')) s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', '那么这款无人机到底有多厉害?'] - print(parser.predict(s)) + # print(parser.predict(s)) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 36f9ab90..39d5bcf3 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -503,9 +503,9 @@ class SpanFPreRecMetric(MetricBase): f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), sum(self._false_negatives.values()), sum(self._false_positives.values())) - evaluate_result['f'] = f - evaluate_result['pre'] = pre - evaluate_result['rec'] = rec + evaluate_result['f'] = round(f, 6) + evaluate_result['pre'] = round(pre, 6) + evaluate_result['rec'] = round(rec, 6) if reset: self._true_positives = defaultdict(int) diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py index d2c5d1d5..3f7b6176 100644 --- a/reproduction/chinese_word_segment/process/cws_processor.py +++ b/reproduction/chinese_word_segment/process/cws_processor.py @@ -1,10 +1,9 @@ import re - -from fastNLP.core.vocabulary import Vocabulary -from fastNLP.core.dataset import DataSet from fastNLP.api.processor import Processor +from fastNLP.core.dataset import DataSet +from fastNLP.core.vocabulary import Vocabulary from reproduction.chinese_word_segment.process.span_converter import SpanConverter _SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>' @@ -239,7 +238,7 @@ class VocabIndexerProcessor(Processor): """ def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, - verbose=1): + verbose=1, is_input=True): """ :param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 @@ -247,12 +246,14 @@ class VocabIndexerProcessor(Processor): :param min_freq: 创建的Vocabulary允许的单词最少出现次数. :param max_size: 创建的Vocabulary允许的最大的单词数量 :param verbose: 0, 不输出任何信息;1,输出信息 + :param bool is_input: """ super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) self.min_freq = min_freq self.max_size = max_size self.verbose =verbose + self.is_input = is_input def construct_vocab(self, *datasets): """ @@ -304,7 +305,10 @@ class VocabIndexerProcessor(Processor): for dataset in to_index_datasets: assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], - new_field_name=self.new_added_field_name) + new_field_name=self.new_added_field_name, is_input=self.is_input) + # 只返回一个,infer时为了跟其他processor保持一致 + if len(to_index_datasets) == 1: + return to_index_datasets[0] def set_vocab(self, vocab): assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py index e440b542..c01d50f3 100644 --- a/reproduction/pos_tag_model/train_pos_tag.py +++ b/reproduction/pos_tag_model/train_pos_tag.py @@ -1,5 +1,12 @@ +import os +import sys + import torch +# in order to run fastNLP without installation +sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) + + from fastNLP.api.pipeline import Pipeline from fastNLP.api.processor import SeqLenProcessor from fastNLP.core.metrics import SpanFPreRecMetric @@ -8,6 +15,7 @@ from fastNLP.io.config_io import ConfigLoader, ConfigSection from fastNLP.models.sequence_modeling import AdvSeqLabel from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader +from fastNLP.api.processor import ModelProcessor, Index2WordProcessor cfgfile = './pos_tag.cfg' pickle_path = "save" @@ -25,16 +33,16 @@ def train(): print(dataset) print("dataset transformed") - vocab_proc = VocabIndexerProcessor("words") - tag_proc = VocabIndexerProcessor("tag") - seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len") + dataset.rename_field("tag", "truth") + + vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") + tag_proc = VocabIndexerProcessor("truth") + seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) vocab_proc(dataset) tag_proc(dataset) seq_len_proc(dataset) - dataset.rename_field("words", "word_seq") - dataset.rename_field("tag", "truth") dataset.set_input("word_seq", "word_seq_origin_len", "truth") dataset.set_target("truth", "word_seq_origin_len") @@ -53,11 +61,14 @@ def train(): target="truth", seq_lens="word_seq_origin_len"), dev_data=dataset, metric_key="f", - use_tqdm=False, use_cuda=True, print_every=20) + use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save") trainer.train() # save model & pipeline - pp = Pipeline([vocab_proc, seq_len_proc]) + model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") + id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") + + pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} torch.save(save_dict, "model_pp.pkl") print("pipeline saved")