Browse Source

Update POS API

tags/v0.3.1^2
FengZiYjun 5 years ago
parent
commit
b14dd58828
3 changed files with 17 additions and 13 deletions:
  1. fastNLP/api/api.py (+1, -1)
  2. fastNLP/api/examples.py (+5, -1)
  3. reproduction/POS_tagging/train_pos_tag.py (+11, -11)

fastNLP/api/api.py (+1, -1) — View File

@@ -18,7 +18,7 @@ from fastNLP.api.processor import IndexerProcessor
# TODO add pretrain urls
model_urls = {
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl",
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
}



fastNLP/api/examples.py (+5, -1) — View File

@@ -16,6 +16,10 @@ def chinese_word_segmentation():


def pos_tagging():
# 输入已分词序列
text = ['编者 按: 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一款 高科技 隐形 无人机 雷电之神 。']
text = [text[0].split()]
print(text)
pos = POS(device='cpu')
print(pos.predict(text))

@@ -26,4 +30,4 @@ def syntactic_parsing():


if __name__ == "__main__":
syntactic_parsing()
pos_tagging()

reproduction/POS_tagging/train_pos_tag.py (+11, -11) — View File

@@ -14,7 +14,7 @@ from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.trainer import Trainer
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader
from fastNLP.io.dataset_loader import ConllxDataLoader
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor


@@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id):
return embedding_tensor


def train(train_data_path, dev_data_path, checkpoint=None):
def train(train_data_path, dev_data_path, checkpoint=None, save=None):
# load config
train_param = ConfigSection()
model_param = ConfigSection()
@@ -44,9 +44,9 @@ def train(train_data_path, dev_data_path, checkpoint=None):

# Data Loader
print("loading training set...")
dataset = ConllxDataLoader().load(train_data_path)
dataset = ConllxDataLoader().load(train_data_path, return_dataset=True)
print("loading dev set...")
dev_data = ConllxDataLoader().load(dev_data_path)
dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True)
print(dataset)
print("================= dataset ready =====================")

@@ -54,9 +54,9 @@ def train(train_data_path, dev_data_path, checkpoint=None):
dev_data.rename_field("tag", "truth")

vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq")
tag_proc = VocabIndexerProcessor("truth")
tag_proc = VocabIndexerProcessor("truth", is_input=True)
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True)
set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth")
set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len")

vocab_proc(dataset)
tag_proc(dataset)
@@ -93,7 +93,7 @@ def train(train_data_path, dev_data_path, checkpoint=None):
target="truth",
seq_lens="word_seq_origin_len"),
dev_data=dev_data, metric_key="f",
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117")
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save)
trainer.train(load_best_model=True)

# save model & pipeline
@@ -102,12 +102,12 @@ def train(train_data_path, dev_data_path, checkpoint=None):

pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag])
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
torch.save(save_dict, "model_pp_0117.pkl")
torch.save(save_dict, os.path.join(save, "model_pp.pkl"))
print("pipeline saved")


def run_test(test_path):
test_data = ZhConllPOSReader().load(test_path)
test_data = ConllxDataLoader().load(test_path, return_dataset=True)

with open("model_pp_0117.pkl", "rb") as f:
save_dict = torch.load(f)
@@ -157,7 +157,7 @@ if __name__ == "__main__":
# 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl
if args.checkpoint is None:
raise RuntimeError("Please provide the checkpoint. -cp ")
train(args.train, args.dev, args.checkpoint)
train(args.train, args.dev, args.checkpoint, save=args.save)
else:
# 一次训练 python train_pos_tag.py
train(args.train, args.dev)
train(args.train, args.dev, save=args.save)

Loading…
Cancel
Save