diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 2e6cc247..47c29214 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -10,13 +10,15 @@ from fastNLP.core.dataset import DataSet from fastNLP.api.model_zoo import load_url from fastNLP.api.processor import ModelProcessor from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader -from reproduction.pos_tag_model.pos_reader import ConllPOSReader +from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag from fastNLP.core.instance import Instance from fastNLP.core.sampler import SequentialSampler from fastNLP.core.batch import Batch from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 from fastNLP.api.pipeline import Pipeline +from fastNLP.core.metrics import SpanFPreRecMetric +from fastNLP.api.processor import IndexerProcessor # TODO add pretrain urls @@ -65,7 +67,7 @@ class POS(API): :param content: list of list of str. Each string is a token(word). :return answer: list of list of str. Each string is a tag. """ - if not hasattr(self, 'pipeline'): + if not hasattr(self, "pipeline"): raise ValueError("You have to load model first.") sentence_list = [] @@ -104,47 +106,35 @@ class POS(API): elif isinstance(content, list): return output - def test(self, filepath): - - tag_proc = self._dict['tag_indexer'] - - model = self.pipeline.pipeline[2].model - pipeline = self.pipeline.pipeline[0:2] - pipeline.append(tag_proc) - pp = Pipeline(pipeline) - - reader = ConllPOSReader() - te_dataset = reader.load(filepath) - - """ - evaluator = SeqLabelEvaluator2('word_seq_origin_len') - end_tagidx_set = set() - tag_proc.vocab.build_vocab() - for key, value in tag_proc.vocab.word2idx.items(): - if key.startswith('E-'): - end_tagidx_set.add(value) - if key.startswith('S-'): - end_tagidx_set.add(value) - evaluator.end_tagidx_set = end_tagidx_set - - pp(te_dataset) - te_dataset.set_target(truth=True) - - default_valid_args = {"batch_size": 64, - "use_cuda": True, "evaluator": evaluator, - "model": model, "data": te_dataset} - - tester = Tester(**default_valid_args) - - test_result = tester.test() - - f1 = round(test_result['F'] * 100, 2) - pre = round(test_result['P'] * 100, 2) - rec = round(test_result['R'] * 100, 2) - # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) - - return f1, pre, rec - """ + def test(self, file_path): + test_data = ZhConllPOSReader().load(file_path) + + tag_vocab = self._dict["tag_vocab"] + pipeline = self._dict["pipeline"] + index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) + pipeline.pipeline = [index_tag] + pipeline.pipeline + + pipeline(test_data) + test_data.set_target("truth") + prediction = test_data.field_arrays["predict"].content + truth = test_data.field_arrays["truth"].content + seq_len = test_data.field_arrays["word_seq_origin_len"].content + + # padding by hand + max_length = max([len(seq) for seq in prediction]) + for idx in range(len(prediction)): + prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) + truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) + evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", + seq_lens="word_seq_origin_len") + evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, + {"truth": torch.Tensor(truth)}) + test_result = evaluator.get_metric() + f1 = round(test_result['f'] * 100, 2) + pre = round(test_result['pre'] * 100, 2) + rec = round(test_result['rec'] * 100, 2) + + return {"F1": f1, "precision": pre, "recall": rec} class CWS(API): @@ -316,8 +306,8 @@ if __name__ == "__main__": s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', '那么这款无人机到底有多厉害?'] - # print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) - print(pos.predict(s)) + print(pos.test("/home/zyfeng/data/sample.conllx")) + # print(pos.predict(s)) # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' # cws = CWS(device='cpu') diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 7b3b6d11..05160312 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -1,6 +1,8 @@ import numpy as np import torch +from fastNLP.core.sampler import RandomSampler + class Batch(object): """Batch is an iterable object which iterates over mini-batches. @@ -17,7 +19,7 @@ class Batch(object): """ - def __init__(self, dataset, batch_size, sampler, as_numpy=False): + def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False): self.dataset = dataset self.batch_size = batch_size self.sampler = sampler diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 39d5bcf3..dfb20480 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -451,8 +451,8 @@ class SpanFPreRecMetric(MetricBase): batch_size = pred.size(0) for i in range(batch_size): - pred_tags = pred[i, :seq_lens[i]].tolist() - gold_tags = target[i, :seq_lens[i]].tolist() + pred_tags = pred[i, :int(seq_lens[i])].tolist() + gold_tags = target[i, :int(seq_lens[i])].tolist() pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] diff --git a/reproduction/pos_tag_model/pos_tag.cfg b/reproduction/pos_tag_model/pos_tag.cfg index f8224234..c9ee8320 100644 --- a/reproduction/pos_tag_model/pos_tag.cfg +++ b/reproduction/pos_tag_model/pos_tag.cfg @@ -10,7 +10,7 @@ eval_sort_key = 'accuracy' [model] rnn_hidden_units = 300 -word_emb_dim = 300 +word_emb_dim = 100 dropout = 0.5 use_crf = true print_every_step = 10 diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py index c01d50f3..adc9359c 100644 --- a/reproduction/pos_tag_model/train_pos_tag.py +++ b/reproduction/pos_tag_model/train_pos_tag.py @@ -1,4 +1,6 @@ +import argparse import os +import pickle import sys import torch @@ -21,7 +23,20 @@ cfgfile = './pos_tag.cfg' pickle_path = "save" -def train(): +def load_tencent_embed(embed_path, word2id): + hit = 0 + with open(embed_path, "rb") as f: + embed_dict = pickle.load(f) + embedding_tensor = torch.randn(len(word2id), 200) + for key in word2id: + if key in embed_dict: + embedding_tensor[word2id[key]] = torch.Tensor(embed_dict[key]) + hit += 1 + print("vocab_size={} hit={} hit/vocab_size={}".format(len(word2id), hit, hit / len(word2id))) + return embedding_tensor + + +def train(checkpoint=None): # load config train_param = ConfigSection() model_param = ConfigSection() @@ -54,15 +69,21 @@ def train(): print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"])) # define a model - model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word) + if checkpoint is None: + # pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) + pre_trained = None + model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained) + print(model) + else: + model = torch.load(checkpoint) # call trainer to train trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", target="truth", seq_lens="word_seq_origin_len"), dev_data=dataset, metric_key="f", - use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save") - trainer.train() + use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save") + trainer.train(load_best_model=True) # save model & pipeline model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") @@ -73,10 +94,20 @@ def train(): torch.save(save_dict, "model_pp.pkl") print("pipeline saved") - -def infer(): - pass + torch.save(model, "./save/best_model.pkl") if __name__ == "__main__": - train() + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") + parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") + args = parser.parse_args() + + if args.restart is True: + # 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl + if args.checkpoint is None: + raise RuntimeError("Please provide the checkpoint. -cp ") + train(args.checkpoint) + else: + # 一次训练 python train_pos_tag.py + train() diff --git a/reproduction/pos_tag_model/utils.py b/reproduction/pos_tag_model/utils.py new file mode 100644 index 00000000..bf10bf47 --- /dev/null +++ b/reproduction/pos_tag_model/utils.py @@ -0,0 +1,25 @@ +import pickle + + +def load_embed(embed_path): + embed_dict = {} + with open(embed_path, "r", encoding="utf-8") as f: + for line in f: + tokens = line.split(" ") + if len(tokens) <= 5: + continue + key = tokens[0] + if len(key) == 1: + value = [float(x) for x in tokens[1:]] + embed_dict[key] = value + return embed_dict + + +if __name__ == "__main__": + embed_dict = load_embed("/home/zyfeng/data/small.txt") + + print(embed_dict.keys()) + + with open("./char_tencent_embedding.pkl", "wb") as f: + pickle.dump(embed_dict, f) + print("finished")