diff --git a/reproduction/pos_tag_model/pos_tag.cfg b/reproduction/pos_tag_model/pos_tag.cfg index f8224234..c9ee8320 100644 --- a/reproduction/pos_tag_model/pos_tag.cfg +++ b/reproduction/pos_tag_model/pos_tag.cfg @@ -10,7 +10,7 @@ eval_sort_key = 'accuracy' [model] rnn_hidden_units = 300 -word_emb_dim = 300 +word_emb_dim = 100 dropout = 0.5 use_crf = true print_every_step = 10 diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py index c01d50f3..adc9359c 100644 --- a/reproduction/pos_tag_model/train_pos_tag.py +++ b/reproduction/pos_tag_model/train_pos_tag.py @@ -1,4 +1,6 @@ +import argparse import os +import pickle import sys import torch @@ -21,7 +23,20 @@ cfgfile = './pos_tag.cfg' pickle_path = "save" -def train(): +def load_tencent_embed(embed_path, word2id): + hit = 0 + with open(embed_path, "rb") as f: + embed_dict = pickle.load(f) + embedding_tensor = torch.randn(len(word2id), 200) + for key in word2id: + if key in embed_dict: + embedding_tensor[word2id[key]] = torch.Tensor(embed_dict[key]) + hit += 1 + print("vocab_size={} hit={} hit/vocab_size={}".format(len(word2id), hit, hit / len(word2id))) + return embedding_tensor + + +def train(checkpoint=None): # load config train_param = ConfigSection() model_param = ConfigSection() @@ -54,15 +69,21 @@ def train(): print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"])) # define a model - model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word) + if checkpoint is None: + # pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) + pre_trained = None + model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained) + print(model) + else: + model = torch.load(checkpoint) # call trainer to train trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", target="truth", seq_lens="word_seq_origin_len"), dev_data=dataset, metric_key="f", - use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save") - trainer.train() + use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save") + trainer.train(load_best_model=True) # save model & pipeline model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") @@ -73,10 +94,20 @@ def train(): torch.save(save_dict, "model_pp.pkl") print("pipeline saved") - -def infer(): - pass + torch.save(model, "./save/best_model.pkl") if __name__ == "__main__": - train() + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") + parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") + args = parser.parse_args() + + if args.restart is True: + # 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl + if args.checkpoint is None: + raise RuntimeError("Please provide the checkpoint. -cp ") + train(args.checkpoint) + else: + # 一次训练 python train_pos_tag.py + train() diff --git a/reproduction/pos_tag_model/utils.py b/reproduction/pos_tag_model/utils.py new file mode 100644 index 00000000..bf10bf47 --- /dev/null +++ b/reproduction/pos_tag_model/utils.py @@ -0,0 +1,25 @@ +import pickle + + +def load_embed(embed_path): + embed_dict = {} + with open(embed_path, "r", encoding="utf-8") as f: + for line in f: + tokens = line.split(" ") + if len(tokens) <= 5: + continue + key = tokens[0] + if len(key) == 1: + value = [float(x) for x in tokens[1:]] + embed_dict[key] = value + return embed_dict + + +if __name__ == "__main__": + embed_dict = load_embed("/home/zyfeng/data/small.txt") + + print(embed_dict.keys()) + + with open("./char_tencent_embedding.pkl", "wb") as f: + pickle.dump(embed_dict, f) + print("finished")