@@ -10,13 +10,15 @@ from fastNLP.core.dataset import DataSet | |||
from fastNLP.api.model_zoo import load_url | |||
from fastNLP.api.processor import ModelProcessor | |||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | |||
from reproduction.pos_tag_model.pos_reader import ConllPOSReader | |||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.core.batch import Batch | |||
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1 | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
from fastNLP.api.processor import IndexerProcessor | |||
# TODO add pretrain urls | |||
@@ -65,7 +67,7 @@ class POS(API): | |||
:param content: list of list of str. Each string is a token(word). | |||
:return answer: list of list of str. Each string is a tag. | |||
""" | |||
if not hasattr(self, 'pipeline'): | |||
if not hasattr(self, "pipeline"): | |||
raise ValueError("You have to load model first.") | |||
sentence_list = [] | |||
@@ -104,47 +106,35 @@ class POS(API): | |||
elif isinstance(content, list): | |||
return output | |||
def test(self, filepath): | |||
tag_proc = self._dict['tag_indexer'] | |||
model = self.pipeline.pipeline[2].model | |||
pipeline = self.pipeline.pipeline[0:2] | |||
pipeline.append(tag_proc) | |||
pp = Pipeline(pipeline) | |||
reader = ConllPOSReader() | |||
te_dataset = reader.load(filepath) | |||
""" | |||
evaluator = SeqLabelEvaluator2('word_seq_origin_len') | |||
end_tagidx_set = set() | |||
tag_proc.vocab.build_vocab() | |||
for key, value in tag_proc.vocab.word2idx.items(): | |||
if key.startswith('E-'): | |||
end_tagidx_set.add(value) | |||
if key.startswith('S-'): | |||
end_tagidx_set.add(value) | |||
evaluator.end_tagidx_set = end_tagidx_set | |||
pp(te_dataset) | |||
te_dataset.set_target(truth=True) | |||
default_valid_args = {"batch_size": 64, | |||
"use_cuda": True, "evaluator": evaluator, | |||
"model": model, "data": te_dataset} | |||
tester = Tester(**default_valid_args) | |||
test_result = tester.test() | |||
f1 = round(test_result['F'] * 100, 2) | |||
pre = round(test_result['P'] * 100, 2) | |||
rec = round(test_result['R'] * 100, 2) | |||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||
return f1, pre, rec | |||
""" | |||
def test(self, file_path): | |||
test_data = ZhConllPOSReader().load(file_path) | |||
tag_vocab = self._dict["tag_vocab"] | |||
pipeline = self._dict["pipeline"] | |||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | |||
pipeline.pipeline = [index_tag] + pipeline.pipeline | |||
pipeline(test_data) | |||
test_data.set_target("truth") | |||
prediction = test_data.field_arrays["predict"].content | |||
truth = test_data.field_arrays["truth"].content | |||
seq_len = test_data.field_arrays["word_seq_origin_len"].content | |||
# padding by hand | |||
max_length = max([len(seq) for seq in prediction]) | |||
for idx in range(len(prediction)): | |||
prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) | |||
truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) | |||
evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", | |||
seq_lens="word_seq_origin_len") | |||
evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, | |||
{"truth": torch.Tensor(truth)}) | |||
test_result = evaluator.get_metric() | |||
f1 = round(test_result['f'] * 100, 2) | |||
pre = round(test_result['pre'] * 100, 2) | |||
rec = round(test_result['rec'] * 100, 2) | |||
return {"F1": f1, "precision": pre, "recall": rec} | |||
class CWS(API): | |||
@@ -316,8 +306,8 @@ if __name__ == "__main__": | |||
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||
'那么这款无人机到底有多厉害?'] | |||
# print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll')) | |||
print(pos.predict(s)) | |||
print(pos.test("/home/zyfeng/data/sample.conllx")) | |||
# print(pos.predict(s)) | |||
# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl' | |||
# cws = CWS(device='cpu') | |||
@@ -1,6 +1,8 @@ | |||
import numpy as np | |||
import torch | |||
from fastNLP.core.sampler import RandomSampler | |||
class Batch(object): | |||
"""Batch is an iterable object which iterates over mini-batches. | |||
@@ -17,7 +19,7 @@ class Batch(object): | |||
""" | |||
def __init__(self, dataset, batch_size, sampler, as_numpy=False): | |||
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False): | |||
self.dataset = dataset | |||
self.batch_size = batch_size | |||
self.sampler = sampler | |||
@@ -451,8 +451,8 @@ class SpanFPreRecMetric(MetricBase): | |||
batch_size = pred.size(0) | |||
for i in range(batch_size): | |||
pred_tags = pred[i, :seq_lens[i]].tolist() | |||
gold_tags = target[i, :seq_lens[i]].tolist() | |||
pred_tags = pred[i, :int(seq_lens[i])].tolist() | |||
gold_tags = target[i, :int(seq_lens[i])].tolist() | |||
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | |||
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | |||
@@ -10,7 +10,7 @@ eval_sort_key = 'accuracy' | |||
[model] | |||
rnn_hidden_units = 300 | |||
word_emb_dim = 300 | |||
word_emb_dim = 100 | |||
dropout = 0.5 | |||
use_crf = true | |||
print_every_step = 10 | |||
@@ -1,4 +1,6 @@ | |||
import argparse | |||
import os | |||
import pickle | |||
import sys | |||
import torch | |||
@@ -21,7 +23,20 @@ cfgfile = './pos_tag.cfg' | |||
pickle_path = "save" | |||
def train(): | |||
def load_tencent_embed(embed_path, word2id): | |||
hit = 0 | |||
with open(embed_path, "rb") as f: | |||
embed_dict = pickle.load(f) | |||
embedding_tensor = torch.randn(len(word2id), 200) | |||
for key in word2id: | |||
if key in embed_dict: | |||
embedding_tensor[word2id[key]] = torch.Tensor(embed_dict[key]) | |||
hit += 1 | |||
print("vocab_size={} hit={} hit/vocab_size={}".format(len(word2id), hit, hit / len(word2id))) | |||
return embedding_tensor | |||
def train(checkpoint=None): | |||
# load config | |||
train_param = ConfigSection() | |||
model_param = ConfigSection() | |||
@@ -54,15 +69,21 @@ def train(): | |||
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"])) | |||
# define a model | |||
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word) | |||
if checkpoint is None: | |||
# pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) | |||
pre_trained = None | |||
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained) | |||
print(model) | |||
else: | |||
model = torch.load(checkpoint) | |||
# call trainer to train | |||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||
target="truth", | |||
seq_lens="word_seq_origin_len"), | |||
dev_data=dataset, metric_key="f", | |||
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save") | |||
trainer.train() | |||
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save") | |||
trainer.train(load_best_model=True) | |||
# save model & pipeline | |||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||
@@ -73,10 +94,20 @@ def train(): | |||
torch.save(save_dict, "model_pp.pkl") | |||
print("pipeline saved") | |||
def infer(): | |||
pass | |||
torch.save(model, "./save/best_model.pkl") | |||
if __name__ == "__main__": | |||
train() | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") | |||
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") | |||
args = parser.parse_args() | |||
if args.restart is True: | |||
# 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl | |||
if args.checkpoint is None: | |||
raise RuntimeError("Please provide the checkpoint. -cp ") | |||
train(args.checkpoint) | |||
else: | |||
# 一次训练 python train_pos_tag.py | |||
train() |
@@ -0,0 +1,25 @@ | |||
import pickle | |||
def load_embed(embed_path): | |||
embed_dict = {} | |||
with open(embed_path, "r", encoding="utf-8") as f: | |||
for line in f: | |||
tokens = line.split(" ") | |||
if len(tokens) <= 5: | |||
continue | |||
key = tokens[0] | |||
if len(key) == 1: | |||
value = [float(x) for x in tokens[1:]] | |||
embed_dict[key] = value | |||
return embed_dict | |||
if __name__ == "__main__": | |||
embed_dict = load_embed("/home/zyfeng/data/small.txt") | |||
print(embed_dict.keys()) | |||
with open("./char_tencent_embedding.pkl", "wb") as f: | |||
pickle.dump(embed_dict, f) | |||
print("finished") |