- update predictor.py to remove unused methods - update model_loader.py & model_saver.py to support entire model saving & loading - update pos tagging training script tags/v0.2.0
@@ -0,0 +1,44 @@ | |||
import pickle | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.loader.model_loader import ModelLoader | |||
from fastNLP.core.predictor import Predictor | |||
class POS_tagger:
    """Skeleton part-of-speech tagger driven by a pickled pipeline and a saved model.

    The prediction flow is still work in progress: ``predict`` runs the
    pipeline and the model, but the conversion of the raw model output into
    the final answer is not implemented yet.
    """

    def __init__(self):
        pass

    def predict(self, query):
        """Run the preprocessing pipeline and the model on ``query``.

        :param query: List[str]
        :return answer: List[str]. Currently always ``None``, because turning
            the predictor output into the final answer is still a TODO.
        """
        # TODO: build the DataSet from query
        pos_dataset = DataSet()
        pos_dataset["text_field"] = np.array(query)

        # load the pipeline
        pipeline = self.load_pipeline("./xxxx")
        # run the pipeline with the DataSet as its argument
        pos_dataset = pipeline(pos_dataset)

        # load the model
        model = ModelLoader().load_pytorch("./xxx")

        # call the predictor
        predictor = Predictor()
        output = predictor.predict(model, pos_dataset)

        # TODO: convert ``output`` into the final answer
        return None

    @staticmethod
    def load_pipeline(path):
        """Load a pickled pipeline object from ``path``.

        :param path: str, path to the pickle file.
        :return pipeline: the object stored in the pickle file.
        """
        # pickle data is binary, so the file must be opened in "rb" mode;
        # text mode makes pickle.load fail on Python 3.
        # NOTE(review): unpickling can execute arbitrary code -- only load
        # pipeline files that come from a trusted source.
        with open(path, "rb") as fp:
            pipeline = pickle.load(fp)
        return pipeline
@@ -2,9 +2,7 @@ import numpy as np | |||
import torch | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.preprocess import load_pickle | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq2tag_dataset, convert_seq_dataset | |||
class Predictor(object): | |||
@@ -16,19 +14,9 @@ class Predictor(object): | |||
Currently, Predictor does not support GPU. | |||
""" | |||
def __init__(self, pickle_path, post_processor): | |||
""" | |||
:param pickle_path: str, the path to the pickle files. | |||
:param post_processor: a function or callable object, that takes list of batch outputs as input | |||
""" | |||
def __init__(self): | |||
self.batch_size = 1 | |||
self.batch_output = [] | |||
self.pickle_path = pickle_path | |||
self._post_processor = post_processor | |||
self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl") | |||
self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl") | |||
def predict(self, network, data): | |||
"""Perform inference using the trained model. | |||
@@ -37,9 +25,6 @@ class Predictor(object): | |||
:param data: a DataSet object. | |||
:return: list of list of strings, [num_examples, tag_seq_length] | |||
""" | |||
# transform strings into DataSet object | |||
# data = self.prepare_input(data) | |||
# turn on the testing mode; clean up the history | |||
self.mode(network, test=True) | |||
batch_output = [] | |||
@@ -51,7 +36,7 @@ class Predictor(object): | |||
prediction = self.data_forward(network, batch_x) | |||
batch_output.append(prediction) | |||
return self._post_processor(batch_output, self.label_vocab) | |||
return batch_output | |||
def mode(self, network, test=True): | |||
if test: | |||
@@ -64,37 +49,19 @@ class Predictor(object): | |||
y = network(**x) | |||
return y | |||
def prepare_input(self, data): | |||
"""Transform a two-level list of strings into a DataSet object. | |||
In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor. | |||
:param data: list of list of strings. | |||
:: | |||
[ | |||
[word_11, word_12, ...], | |||
[word_21, word_22, ...], | |||
... | |||
] | |||
:return data_set: a DataSet instance. | |||
""" | |||
assert isinstance(data, list) | |||
data = convert_seq_dataset(data) | |||
data.index_field("word_seq", self.word_vocab) | |||
class SeqLabelInfer(Predictor): | |||
def __init__(self, pickle_path): | |||
print( | |||
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.") | |||
super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor) | |||
super(SeqLabelInfer, self).__init__() | |||
class ClassificationInfer(Predictor): | |||
def __init__(self, pickle_path): | |||
print( | |||
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.") | |||
super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor) | |||
super(ClassificationInfer, self).__init__() | |||
def seq_label_post_processor(batch_outputs, label_vocab): | |||
@@ -8,8 +8,8 @@ class ModelLoader(BaseLoader): | |||
Loader for models. | |||
""" | |||
def __init__(self, data_path): | |||
super(ModelLoader, self).__init__(data_path) | |||
def __init__(self): | |||
super(ModelLoader, self).__init__() | |||
@staticmethod | |||
def load_pytorch(empty_model, model_path): | |||
@@ -19,3 +19,10 @@ class ModelLoader(BaseLoader): | |||
:param model_path: str, the path to the saved model. | |||
""" | |||
empty_model.load_state_dict(torch.load(model_path)) | |||
@staticmethod | |||
def load_pytorch(model_path): | |||
"""Load the entire model. | |||
""" | |||
return torch.load(model_path) |
@@ -127,7 +127,8 @@ class AdvSeqLabel(SeqLabeling): | |||
:param word_seq: LongTensor, [batch_size, max_len] | |||
:param word_seq_origin_len: list of int. | |||
:param truth: LongTensor, [batch_size, max_len] | |||
:return y: | |||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||
If truth is not None, return loss, a scalar. Used in training. | |||
""" | |||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||
@@ -15,10 +15,14 @@ class ModelSaver(object): | |||
""" | |||
self.save_path = save_path | |||
def save_pytorch(self, model): | |||
def save_pytorch(self, model, param_only=True): | |||
"""Save a pytorch model into .pkl file. | |||
:param model: a PyTorch model | |||
:param param_only: bool, whether only to save the model parameters or the entire model. | |||
""" | |||
torch.save(model.state_dict(), self.save_path) | |||
if param_only is True: | |||
torch.save(model.state_dict(), self.save_path) | |||
else: | |||
torch.save(model, self.save_path) |
@@ -59,42 +59,37 @@ def infer(): | |||
print("Inference finished!") | |||
def train(): | |||
# Config Loader | |||
train_args = ConfigSection() | |||
test_args = ConfigSection() | |||
ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args}) | |||
def train(): | |||
# load config | |||
trainer_args = ConfigSection() | |||
model_args = ConfigSection() | |||
ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args}) | |||
# Data Loader | |||
loader = PeopleDailyCorpusLoader() | |||
train_data, _ = loader.load() | |||
# Preprocessor | |||
preprocessor = SeqLabelPreprocess() | |||
data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3) | |||
train_args["vocab_size"] = preprocessor.vocab_size | |||
train_args["num_classes"] = preprocessor.num_classes | |||
# TODO: define processors | |||
# define pipeline | |||
pp = Pipeline() | |||
# TODO: pp.add_processor() | |||
# Trainer | |||
trainer = SeqLabelTrainer(**train_args.data) | |||
# run the pipeline, get data_set | |||
train_data = pp(train_data) | |||
# Model | |||
# define a model | |||
model = AdvSeqLabel(train_args) | |||
try: | |||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
print('model parameter loaded!') | |||
except Exception as e: | |||
print("No saved model. Continue.") | |||
pass | |||
# Start training | |||
# call trainer to train | |||
trainer = SeqLabelTrainer(train_args) | |||
trainer.train(model, data_train, data_dev) | |||
print("Training finished!") | |||
# Saver | |||
saver = ModelSaver("./save/saved_model.pkl") | |||
saver.save_pytorch(model) | |||
print("Model saved!") | |||
# save model | |||
ModelSaver("./saved_model.pkl").save_pytorch(model, param_only=False) | |||
# TODO:save pipeline | |||
def test(): | |||