@@ -1,4 +1,6 @@
+import argparse
 import os
+import pickle
 import sys
 
 import torch
@@ -21,7 +23,20 @@ cfgfile = './pos_tag.cfg'
 pickle_path = "save"
 
 
-def train():
+def load_tencent_embed(embed_path, word2id):
+    hit = 0
+    with open(embed_path, "rb") as f:
+        embed_dict = pickle.load(f)
+    embedding_tensor = torch.randn(len(word2id), 200)
+    for key in word2id:
+        if key in embed_dict:
+            embedding_tensor[word2id[key]] = torch.Tensor(embed_dict[key])
+            hit += 1
+    print("vocab_size={} hit={} hit/vocab_size={}".format(len(word2id), hit, hit / len(word2id)))
+    return embedding_tensor
+
+
+def train(checkpoint=None):
     # load config
     train_param = ConfigSection()
     model_param = ConfigSection()
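A note on the new helper above: it draws the full embedding matrix from `torch.randn`, then overwrites only the rows whose keys appear in the pickled Tencent embedding dict, reporting coverage as `hit / vocab_size`. Below is a minimal sketch of how the returned tensor can back a PyTorch embedding layer; the toy vocabulary and pickle file are hypothetical, and it assumes `load_tencent_embed` from the hunk above is in scope.

```python
import pickle

import torch
import torch.nn as nn

# Hypothetical stand-ins: a tiny vocab and a {char: 200-dim vector} dict
# mimicking the real Tencent embedding pickle. Only "我" is covered, so
# only its row is replaced; the rest keep their random initialization.
word2id = {"<pad>": 0, "我": 1, "你": 2}
with open("toy_embed.pkl", "wb") as f:
    pickle.dump({"我": [0.1] * 200}, f)

embedding_tensor = load_tencent_embed("toy_embed.pkl", word2id)  # prints hit=1
embed = nn.Embedding.from_pretrained(embedding_tensor, freeze=False)
print(embed(torch.tensor([1])).shape)  # torch.Size([1, 200])
```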
@@ -54,15 +69,21 @@ def train():
     print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))
 
     # define a model
-    model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
+    if checkpoint is None:
+        # pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx)
+        pre_trained = None
+        model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained)
+        print(model)
+    else:
+        model = torch.load(checkpoint)
 
     # call trainer to train
     trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
                                                                            target="truth",
                                                                            seq_lens="word_seq_origin_len"),
                       dev_data=dataset, metric_key="f",
-                      use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save")
-    trainer.train()
+                      use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save")
+    trainer.train(load_best_model=True)
 
     # save model & pipeline
     model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len")
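Two things change in this hunk: the trainer now runs 6 epochs with tqdm enabled and restores the best checkpoint after training (`load_best_model=True`), and model construction is skipped entirely when resuming, because `torch.load(checkpoint)` unpickles the whole `AdvSeqLabel` object rather than just its weights. A hedged sketch of that trade-off, reusing the `model`, `model_param`, and `tag_proc` names from the hunk above; the `state_dict` variant shown second is an alternative, not what this script does.

```python
import torch

# Whole-object checkpoint (this script's approach): pickles the class
# reference together with the weights, so resuming needs no rebuild.
torch.save(model, "./save/best_model.pkl")
model = torch.load("./save/best_model.pkl")

# state_dict alternative: weights only; portable across code refactors,
# but the model must be reconstructed before loading.
torch.save(model.state_dict(), "./save/best_model.state")
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
model.load_state_dict(torch.load("./save/best_model.state"))
```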
@@ -73,10 +94,20 @@ def train():
     torch.save(save_dict, "model_pp.pkl")
     print("pipeline saved")
 
-
-def infer():
-    pass
+    torch.save(model, "./save/best_model.pkl")
 
 
 if __name__ == "__main__":
-    train()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training")
+    parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model")
+    args = parser.parse_args()
+
+    if args.restart is True:
+        # resume training: python train_pos_tag.py -c -cp ./save/best_model.pkl
+        if args.checkpoint is None:
+            raise RuntimeError("Please provide the checkpoint. -cp ")
+        train(args.checkpoint)
+    else:
+        # train from scratch: python train_pos_tag.py
+        train()
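The rewritten entry point dispatches on `--restart`: resuming demands an explicit `-cp` path (a bare `-c` raises `RuntimeError`), while a plain invocation trains from scratch. A small self-contained check of that CLI contract; the argument values here are illustrative.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training")
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model")

# `-c -cp <path>` resumes; `-c` alone would hit the RuntimeError branch.
args = parser.parse_args(["-c", "-cp", "./save/best_model.pkl"])
assert args.restart and args.checkpoint == "./save/best_model.pkl"

# A bare `python train_pos_tag.py` trains from scratch.
args = parser.parse_args([])
assert not args.restart and args.checkpoint is None
```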