|
|
@@ -14,7 +14,7 @@ from fastNLP.core.metrics import SpanFPreRecMetric |
|
|
|
from fastNLP.core.trainer import Trainer |
|
|
|
from fastNLP.io.config_io import ConfigLoader, ConfigSection |
|
|
|
from fastNLP.models.sequence_modeling import AdvSeqLabel |
|
|
|
from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader |
|
|
|
from fastNLP.io.dataset_loader import ConllxDataLoader |
|
|
|
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor |
|
|
|
|
|
|
|
|
|
|
@@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id): |
|
|
|
return embedding_tensor |
|
|
|
|
|
|
|
|
|
|
|
def train(train_data_path, dev_data_path, checkpoint=None): |
|
|
|
def train(train_data_path, dev_data_path, checkpoint=None, save=None): |
|
|
|
# load config |
|
|
|
train_param = ConfigSection() |
|
|
|
model_param = ConfigSection() |
|
|
@@ -44,9 +44,9 @@ def train(train_data_path, dev_data_path, checkpoint=None): |
|
|
|
|
|
|
|
# Data Loader |
|
|
|
print("loading training set...") |
|
|
|
dataset = ConllxDataLoader().load(train_data_path) |
|
|
|
dataset = ConllxDataLoader().load(train_data_path, return_dataset=True) |
|
|
|
print("loading dev set...") |
|
|
|
dev_data = ConllxDataLoader().load(dev_data_path) |
|
|
|
dev_data = ConllxDataLoader().load(dev_data_path, return_dataset=True) |
|
|
|
print(dataset) |
|
|
|
print("================= dataset ready =====================") |
|
|
|
|
|
|
@@ -54,9 +54,9 @@ def train(train_data_path, dev_data_path, checkpoint=None): |
|
|
|
dev_data.rename_field("tag", "truth") |
|
|
|
|
|
|
|
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") |
|
|
|
tag_proc = VocabIndexerProcessor("truth") |
|
|
|
tag_proc = VocabIndexerProcessor("truth", is_input=True) |
|
|
|
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) |
|
|
|
set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth") |
|
|
|
set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len") |
|
|
|
|
|
|
|
vocab_proc(dataset) |
|
|
|
tag_proc(dataset) |
|
|
@@ -93,7 +93,7 @@ def train(train_data_path, dev_data_path, checkpoint=None): |
|
|
|
target="truth", |
|
|
|
seq_lens="word_seq_origin_len"), |
|
|
|
dev_data=dev_data, metric_key="f", |
|
|
|
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117") |
|
|
|
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save) |
|
|
|
trainer.train(load_best_model=True) |
|
|
|
|
|
|
|
# save model & pipeline |
|
|
@@ -102,12 +102,12 @@ def train(train_data_path, dev_data_path, checkpoint=None): |
|
|
|
|
|
|
|
pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) |
|
|
|
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} |
|
|
|
torch.save(save_dict, "model_pp_0117.pkl") |
|
|
|
torch.save(save_dict, os.path.join(save, "model_pp.pkl")) |
|
|
|
print("pipeline saved") |
|
|
|
|
|
|
|
|
|
|
|
def run_test(test_path): |
|
|
|
test_data = ZhConllPOSReader().load(test_path) |
|
|
|
test_data = ConllxDataLoader().load(test_path, return_dataset=True) |
|
|
|
|
|
|
|
with open("model_pp_0117.pkl", "rb") as f: |
|
|
|
save_dict = torch.load(f) |
|
|
@@ -157,7 +157,7 @@ if __name__ == "__main__": |
|
|
|
# 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl |
|
|
|
if args.checkpoint is None: |
|
|
|
raise RuntimeError("Please provide the checkpoint. -cp ") |
|
|
|
train(args.train, args.dev, args.checkpoint) |
|
|
|
train(args.train, args.dev, args.checkpoint, save=args.save) |
|
|
|
else: |
|
|
|
# 一次训练 python train_pos_tag.py |
|
|
|
train(args.train, args.dev) |
|
|
|
train(args.train, args.dev, save=args.save) |