From b14dd588285d0452722b6529991e181fa3e65219 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sat, 19 Jan 2019 18:48:57 +0800
Subject: [PATCH] Update POS API

---
 fastNLP/api/api.py                        |  2 +-
 fastNLP/api/examples.py                   |  6 +++++-
 reproduction/POS_tagging/train_pos_tag.py | 22 +++++++++++-----------
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py
index 38af57b3..0c5f17bc 100644
--- a/fastNLP/api/api.py
+++ b/fastNLP/api/api.py
@@ -18,7 +18,7 @@ from fastNLP.api.processor import IndexerProcessor
 # TODO add pretrain urls
 model_urls = {
     "cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
-    "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl",
+    "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
     "parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
 }
 
diff --git a/fastNLP/api/examples.py b/fastNLP/api/examples.py
index 10cc6edc..447d127a 100644
--- a/fastNLP/api/examples.py
+++ b/fastNLP/api/examples.py
@@ -16,6 +16,10 @@ def chinese_word_segmentation():
 
 
 def pos_tagging():
+    # the input is a pre-tokenized (already word-segmented) sequence
+    text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
+    text = [text[0].split()]
+    print(text)
     pos = POS(device='cpu')
     print(pos.predict(text))
 
@@ -26,4 +30,4 @@ def syntactic_parsing():
 
 
 if __name__ == "__main__":
-    syntactic_parsing()
+    pos_tagging()
diff --git a/reproduction/POS_tagging/train_pos_tag.py b/reproduction/POS_tagging/train_pos_tag.py
index 6448c32b..06547701 100644
--- a/reproduction/POS_tagging/train_pos_tag.py
+++ b/reproduction/POS_tagging/train_pos_tag.py
@@ -14,7 +14,7 @@ from fastNLP.core.metrics import SpanFPreRecMetric
 from fastNLP.core.trainer import Trainer
 from fastNLP.io.config_io import ConfigLoader, ConfigSection
 from fastNLP.models.sequence_modeling import AdvSeqLabel
-from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader
+from fastNLP.io.dataset_loader import ConllxDataLoader
 from fastNLP.api.processor import ModelProcessor, Index2WordProcessor
 
 
@@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id):
     return embedding_tensor
 
 
-def train(train_data_path, dev_data_path, checkpoint=None):
+def train(train_data_path, dev_data_path, checkpoint=None, save=None):
     # load config
     train_param = ConfigSection()
     model_param = ConfigSection()
@@ -44,9 +44,9 @@
 
     # Data Loader
     print("loading training set...")
-    dataset = ConllxDataLoader().load(train_data_path)
+    dataset = ConllxDataLoader().load(train_data_path, return_dataset=True)
     print("loading dev set...")
-    dev_data = ConllxDataLoader().load(dev_data_path)
+    dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True)
     print(dataset)
     print("================= dataset ready =====================")
 
@@ -54,9 +54,9 @@
     dev_data.rename_field("tag", "truth")
 
     vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq")
-    tag_proc = VocabIndexerProcessor("truth")
+    tag_proc = VocabIndexerProcessor("truth", is_input=True)
     seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True)
-    set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth")
+    set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len")
 
     vocab_proc(dataset)
     tag_proc(dataset)
@@ -93,7 +93,7 @@
                                                    target="truth",
                                                    seq_lens="word_seq_origin_len"),
                       dev_data=dev_data, metric_key="f",
-                      use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117")
+                      use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save)
     trainer.train(load_best_model=True)
 
     # save model & pipeline
@@ -102,12 +102,12 @@
 
     pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag])
     save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
-    torch.save(save_dict, "model_pp_0117.pkl")
+    torch.save(save_dict, os.path.join(save, "model_pp.pkl"))
     print("pipeline saved")
 
 
 def run_test(test_path):
-    test_data = ZhConllPOSReader().load(test_path)
+    test_data = ConllxDataLoader().load(test_path, return_dataset=True)
 
     with open("model_pp_0117.pkl", "rb") as f:
         save_dict = torch.load(f)
@@ -157,7 +157,7 @@ if __name__ == "__main__":
         # resume training: python train_pos_tag.py -c -cp ./save/best_model.pkl
         if args.checkpoint is None:
            raise RuntimeError("Please provide the checkpoint. -cp ")
-        train(args.train, args.dev, args.checkpoint)
+        train(args.train, args.dev, args.checkpoint, save=args.save)
     else:
         # train in a single run: python train_pos_tag.py
-        train(args.train, args.dev)
+        train(args.train, args.dev, save=args.save)
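
Usage note: the sketch below shows how the updated POS API is meant to be
called after this patch. It is a minimal example, not a verified run: it
assumes `POS` is importable from `fastNLP.api` (as examples.py does) and that
the pretrained model behind the new "pos" URL is downloaded on first use; the
comment on the output is an expectation, not confirmed output.

    from fastNLP.api import POS

    # The API now expects pre-tokenized input: a list of sentences, each a
    # list of tokens. Here a whitespace-segmented string is split into
    # tokens, mirroring the updated pos_tagging() example.
    text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
    text = [text[0].split()]

    pos = POS(device='cpu')   # fetches pos_tag_model_20190119-43f8b435.pkl on first use
    print(pos.predict(text))  # expected: one tagged token sequence per input sentence

On the training side, the hard-coded save_path "./save_0117" becomes a `save`
argument, so a run would presumably look like
`python train_pos_tag.py --train train.conllx --dev dev.conllx --save ./save`
(the patch passes `args.save` but does not show the corresponding argparse
option, so the flag name is an assumption). Note also that the new
`os.path.join(save, "model_pp.pkl")` call relies on `os` being imported in
train_pos_tag.py.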