
Merge branch 'dev' of github.com:choosewhatulike/fastNLP-private into dev

tags/v0.3.0^2
yh 6 years ago
commit 4fd6e12fa9
6 changed files with 105 additions and 57 deletions
  1. fastNLP/api/api.py  +35 -45
  2. fastNLP/core/batch.py  +3 -1
  3. fastNLP/core/metrics.py  +2 -2
  4. reproduction/pos_tag_model/pos_tag.cfg  +1 -1
  5. reproduction/pos_tag_model/train_pos_tag.py  +39 -8
  6. reproduction/pos_tag_model/utils.py  +25 -0

fastNLP/api/api.py  +35 -45

@@ -10,13 +10,15 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.api.model_zoo import load_url
 from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
-from reproduction.pos_tag_model.pos_reader import ConllPOSReader
+from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
 from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.batch import Batch
 from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
 from fastNLP.api.pipeline import Pipeline
+from fastNLP.core.metrics import SpanFPreRecMetric
+from fastNLP.api.processor import IndexerProcessor


# TODO add pretrain urls
@@ -65,7 +67,7 @@ class POS(API):
         :param content: list of list of str. Each string is a token(word).
         :return answer: list of list of str. Each string is a tag.
         """
-        if not hasattr(self, 'pipeline'):
+        if not hasattr(self, "pipeline"):
             raise ValueError("You have to load model first.")

         sentence_list = []
@@ -104,47 +106,35 @@ class POS(API):
         elif isinstance(content, list):
             return output

-    def test(self, filepath):
-
-        tag_proc = self._dict['tag_indexer']
-
-        model = self.pipeline.pipeline[2].model
-        pipeline = self.pipeline.pipeline[0:2]
-        pipeline.append(tag_proc)
-        pp = Pipeline(pipeline)
-
-        reader = ConllPOSReader()
-        te_dataset = reader.load(filepath)
-
-        """
-        evaluator = SeqLabelEvaluator2('word_seq_origin_len')
-        end_tagidx_set = set()
-        tag_proc.vocab.build_vocab()
-        for key, value in tag_proc.vocab.word2idx.items():
-            if key.startswith('E-'):
-                end_tagidx_set.add(value)
-            if key.startswith('S-'):
-                end_tagidx_set.add(value)
-        evaluator.end_tagidx_set = end_tagidx_set
-
-        pp(te_dataset)
-        te_dataset.set_target(truth=True)
-
-        default_valid_args = {"batch_size": 64,
-                              "use_cuda": True, "evaluator": evaluator,
-                              "model": model, "data": te_dataset}
-
-        tester = Tester(**default_valid_args)
-
-        test_result = tester.test()
-
-        f1 = round(test_result['F'] * 100, 2)
-        pre = round(test_result['P'] * 100, 2)
-        rec = round(test_result['R'] * 100, 2)
-        # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
-
-        return f1, pre, rec
-        """
+    def test(self, file_path):
+        test_data = ZhConllPOSReader().load(file_path)
+
+        tag_vocab = self._dict["tag_vocab"]
+        pipeline = self._dict["pipeline"]
+        index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
+        pipeline.pipeline = [index_tag] + pipeline.pipeline
+
+        pipeline(test_data)
+        test_data.set_target("truth")
+        prediction = test_data.field_arrays["predict"].content
+        truth = test_data.field_arrays["truth"].content
+        seq_len = test_data.field_arrays["word_seq_origin_len"].content
+
+        # padding by hand
+        max_length = max([len(seq) for seq in prediction])
+        for idx in range(len(prediction)):
+            prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx])))
+            truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx])))
+        evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth",
+                                      seq_lens="word_seq_origin_len")
+        evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)},
+                  {"truth": torch.Tensor(truth)})
+        test_result = evaluator.get_metric()
+        f1 = round(test_result['f'] * 100, 2)
+        pre = round(test_result['pre'] * 100, 2)
+        rec = round(test_result['rec'] * 100, 2)
+
+        return {"F1": f1, "precision": pre, "recall": rec}


class CWS(API):
@@ -316,8 +306,8 @@ if __name__ == "__main__":
     s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
          '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
          '那么这款无人机到底有多厉害?']
-    # print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll'))
-    print(pos.predict(s))
+    print(pos.test("/home/zyfeng/data/sample.conllx"))
+    # print(pos.predict(s))

     # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
     # cws = CWS(device='cpu')
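
Note on the rewritten test(): instead of the Tester/SeqLabelEvaluator2 path that had been commented out, it now indexes the gold tags into a "truth" field, runs the stored pipeline, pads the variable-length predictions by hand, and scores everything with SpanFPreRecMetric. A minimal standalone sketch of the pad-then-score step, with made-up tag ids and lengths standing in for the DataSet fields:

    import torch

    # made-up stand-ins for test_data.field_arrays["predict"] / ["truth"] /
    # ["word_seq_origin_len"]; real values come from the pipeline run
    prediction = [[1, 2], [1, 2, 3]]
    truth = [[1, 2], [1, 1, 3]]
    seq_len = [2, 3]

    # pad every sequence to the batch maximum with 0, as test() does "by hand"
    max_length = max(len(seq) for seq in prediction)
    for idx in range(len(prediction)):
        prediction[idx] = list(prediction[idx]) + [0] * (max_length - len(prediction[idx]))
        truth[idx] = list(truth[idx]) + [0] * (max_length - len(truth[idx]))

    # SpanFPreRecMetric only reads the first seq_len[i] positions of row i,
    # so the trailing zeros never enter the span counts
    pred_tensor = torch.Tensor(prediction)   # float tensor, shape (2, 3)
    truth_tensor = torch.Tensor(truth)
    len_tensor = torch.Tensor(seq_len)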


fastNLP/core/batch.py  +3 -1

@@ -1,6 +1,8 @@
 import numpy as np
 import torch

+from fastNLP.core.sampler import RandomSampler
+

 class Batch(object):
     """Batch is an iterable object which iterates over mini-batches.
@@ -17,7 +19,7 @@ class Batch(object):

"""

def __init__(self, dataset, batch_size, sampler, as_numpy=False):
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False):
self.dataset = dataset
self.batch_size = batch_size
self.sampler = sampler
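
With the new default, Batch can be constructed without an explicit sampler. One Python subtlety: a default argument is evaluated once, at definition time, so every Batch built without a sampler shares a single RandomSampler instance; that is harmless only as long as RandomSampler stays stateless. A usage sketch, assuming dataset is an already-built fastNLP DataSet:

    from fastNLP.core.batch import Batch
    from fastNLP.core.sampler import SequentialSampler

    train_iter = Batch(dataset, batch_size=32)  # random order via the new default
    eval_iter = Batch(dataset, batch_size=32, sampler=SequentialSampler())  # fixed order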


fastNLP/core/metrics.py  +2 -2

@@ -451,8 +451,8 @@ class SpanFPreRecMetric(MetricBase):

         batch_size = pred.size(0)
         for i in range(batch_size):
-            pred_tags = pred[i, :seq_lens[i]].tolist()
-            gold_tags = target[i, :seq_lens[i]].tolist()
+            pred_tags = pred[i, :int(seq_lens[i])].tolist()
+            gold_tags = target[i, :int(seq_lens[i])].tolist()

             pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
             gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
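
The int() cast matters because callers may pass seq_lens as a float tensor; the rewritten POS.test() above builds it with torch.Tensor(seq_len), and an element of a float tensor is not a valid slice bound. A small illustration of the failure mode being fixed:

    import torch

    seq_lens = torch.Tensor([2, 3])   # float32, as torch.Tensor(list_of_ints) produces
    pred = torch.zeros(2, 5)

    # pred[0, :seq_lens[0]] raises an error, since a 0-dim float tensor
    # cannot be converted to a slice index; the cast sidesteps that
    row = pred[0, :int(seq_lens[0])]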


reproduction/pos_tag_model/pos_tag.cfg  +1 -1

@@ -10,7 +10,7 @@ eval_sort_key = 'accuracy'

 [model]
 rnn_hidden_units = 300
-word_emb_dim = 300
+word_emb_dim = 100
 dropout = 0.5
 use_crf = true
 print_every_step = 10
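
The [model] section is plain INI, so the new value is easy to sanity-check outside fastNLP's own ConfigLoader. A standard-library sketch (the relative path assumes it is run from the repository root):

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("reproduction/pos_tag_model/pos_tag.cfg")

    # the value this diff changes from 300 to 100
    assert cfg.getint("model", "word_emb_dim") == 100
    assert cfg.getboolean("model", "use_crf") is True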


reproduction/pos_tag_model/train_pos_tag.py  +39 -8

@@ -1,4 +1,6 @@
 import argparse
+import os
+import pickle
 import sys

 import torch
@@ -21,7 +23,20 @@ cfgfile = './pos_tag.cfg'
pickle_path = "save"


def train():
def load_tencent_embed(embed_path, word2id):
hit = 0
with open(embed_path, "rb") as f:
embed_dict = pickle.load(f)
embedding_tensor = torch.randn(len(word2id), 200)
for key in word2id:
if key in embed_dict:
embedding_tensor[word2id[key]] = torch.Tensor(embed_dict[key])
hit += 1
print("vocab_size={} hit={} hit/vocab_size={}".format(len(word2id), hit, hit / len(word2id)))
return embedding_tensor


def train(checkpoint=None):
# load config
train_param = ConfigSection()
model_param = ConfigSection()
@@ -54,15 +69,21 @@ def train():
print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))

# define a model
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
if checkpoint is None:
# pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx)
pre_trained = None
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained)
print(model)
else:
model = torch.load(checkpoint)

# call trainer to train
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
target="truth",
seq_lens="word_seq_origin_len"),
dev_data=dataset, metric_key="f",
use_tqdm=False, use_cuda=True, print_every=20, n_epochs=1, save_path="./save")
trainer.train()
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save")
trainer.train(load_best_model=True)

# save model & pipeline
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len")
@@ -73,10 +94,20 @@ def train():
     torch.save(save_dict, "model_pp.pkl")
     print("pipeline saved")

-
-def infer():
-    pass
+    torch.save(model, "./save/best_model.pkl")


 if __name__ == "__main__":
-    train()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training")
+    parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model")
+    args = parser.parse_args()
+
+    if args.restart is True:
+        # resume training: python train_pos_tag.py -c -cp ./save/best_model.pkl
+        if args.checkpoint is None:
+            raise RuntimeError("Please provide the checkpoint. -cp ")
+        train(args.checkpoint)
+    else:
+        # train from scratch: python train_pos_tag.py
+        train()
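
load_tencent_embed expects a pickle holding a dict that maps each token to its embedding list; rows whose token is missing keep their torch.randn initialization, and the printed hit ratio reports the coverage. A toy round-trip with fabricated vectors and paths (only the 200 dimension comes from the function above; the import line is hypothetical and assumes the script is importable as a module):

    import pickle
    import torch

    from train_pos_tag import load_tencent_embed  # hypothetical import

    toy = {"中": [0.1] * 200, "国": [0.2] * 200}   # fabricated "Tencent" vectors
    with open("/tmp/toy_embed.pkl", "wb") as f:
        pickle.dump(toy, f)

    word2id = {"中": 0, "国": 1, "<unk>": 2}
    emb = load_tencent_embed("/tmp/toy_embed.pkl", word2id)  # prints hit=2 vocab_size=3
    assert emb.shape == (3, 200)  # row 2 ("<unk>") keeps its random initialization

The __main__ block then picks between the two modes: python train_pos_tag.py for a fresh run, and python train_pos_tag.py -c -cp ./save/best_model.pkl to resume from a saved checkpoint.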

reproduction/pos_tag_model/utils.py  +25 -0

@@ -0,0 +1,25 @@
+import pickle
+
+
+def load_embed(embed_path):
+    embed_dict = {}
+    with open(embed_path, "r", encoding="utf-8") as f:
+        for line in f:
+            tokens = line.split(" ")
+            if len(tokens) <= 5:
+                continue
+            key = tokens[0]
+            if len(key) == 1:
+                value = [float(x) for x in tokens[1:]]
+                embed_dict[key] = value
+    return embed_dict
+
+
+if __name__ == "__main__":
+    embed_dict = load_embed("/home/zyfeng/data/small.txt")
+
+    print(embed_dict.keys())
+
+    with open("./char_tencent_embedding.pkl", "wb") as f:
+        pickle.dump(embed_dict, f)
+    print("finished")
