From 79105381f54bf518a4be25ab30a6a1c7b340c255 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Fri, 9 Nov 2018 19:52:31 +0800
Subject: [PATCH] - add interfaces for pos_tagging API - update predictor.py to
 remove unused methods - update model_loader.py & model_saver.py to support
 entire model saving & loading - update pos tagging training script

---
 fastNLP/api/pos_tagger.py                   | 44 ++++++++++++++++++++
 fastNLP/core/predictor.py                   | 41 ++-----------------
 fastNLP/loader/model_loader.py              | 11 ++++-
 fastNLP/models/sequence_modeling.py         |  3 +-
 fastNLP/saver/model_saver.py                |  8 +++-
 reproduction/pos_tag_model/train_pos_tag.py | 45 +++++++++------------
 6 files changed, 85 insertions(+), 67 deletions(-)
 create mode 100644 fastNLP/api/pos_tagger.py

diff --git a/fastNLP/api/pos_tagger.py b/fastNLP/api/pos_tagger.py
new file mode 100644
index 00000000..fbd689c1
--- /dev/null
+++ b/fastNLP/api/pos_tagger.py
@@ -0,0 +1,44 @@
+import pickle
+
+import numpy as np
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.loader.model_loader import ModelLoader
+from fastNLP.core.predictor import Predictor
+
+
+class POS_tagger:
+    def __init__(self):
+        pass
+
+    def predict(self, query):
+        """
+        :param query: List[str]
+        :return answer: List[str]
+
+        """
+        # TODO: build a DataSet from the query
+        pos_dataset = DataSet()
+        pos_dataset["text_field"] = np.array(query)
+
+        # load the pipeline
+        pipeline = self.load_pipeline("./xxxx")
+
+        # run the pipeline on the DataSet
+        pos_dataset = pipeline(pos_dataset)
+
+        # load the model
+        model = ModelLoader().load_pytorch_model("./xxx")
+
+        # call the predictor
+        predictor = Predictor()
+        output = predictor.predict(model, pos_dataset)
+
+        # TODO: convert the output into the final answer
+        return None
+
+    @staticmethod
+    def load_pipeline(path):
+        with open(path, "rb") as fp:
+            pipeline = pickle.load(fp)
+        return pipeline
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index c5d22df4..63e5b7ca 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -2,9 +2,7 @@ import numpy as np
 import torch
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.preprocess import load_pickle
 from fastNLP.core.sampler import SequentialSampler
-from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq2tag_dataset, convert_seq_dataset
 
 
 class Predictor(object):
@@ -16,19 +14,9 @@ class Predictor(object):
     Currently, Predictor does not support GPU.
     """
 
-    def __init__(self, pickle_path, post_processor):
-        """
-
-        :param pickle_path: str, the path to the pickle files.
-        :param post_processor: a function or callable object, that takes list of batch outputs as input
-
-        """
+    def __init__(self):
         self.batch_size = 1
         self.batch_output = []
-        self.pickle_path = pickle_path
-        self._post_processor = post_processor
-        self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl")
-        self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")
 
     def predict(self, network, data):
         """Perform inference using the trained model.
@@ -37,9 +25,6 @@ class Predictor(object):
         :param data: a DataSet object.
         :return: list of list of strings, [num_examples, tag_seq_length]
         """
-        # transform strings into DataSet object
-        # data = self.prepare_input(data)
-
         # turn on the testing mode; clean up the history
         self.mode(network, test=True)
         batch_output = []
@@ -51,7 +36,7 @@ class Predictor(object):
             prediction = self.data_forward(network, batch_x)
             batch_output.append(prediction)
 
-        return self._post_processor(batch_output, self.label_vocab)
+        return batch_output
 
     def mode(self, network, test=True):
         if test:
@@ -64,37 +49,19 @@ class Predictor(object):
         y = network(**x)
         return y
 
-    def prepare_input(self, data):
-        """Transform two-level list of strings into an DataSet object.
-        In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor.
-
-        :param data: list of list of strings.
-        ::
-            [
-                [word_11, word_12, ...],
-                [word_21, word_22, ...],
-                ...
-            ]
-
-        :return data_set: a DataSet instance.
-        """
-        assert isinstance(data, list)
-        data = convert_seq_dataset(data)
-        data.index_field("word_seq", self.word_vocab)
-
 
 class SeqLabelInfer(Predictor):
     def __init__(self, pickle_path):
         print(
             "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
-        super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor)
+        super(SeqLabelInfer, self).__init__()
 
 
 class ClassificationInfer(Predictor):
     def __init__(self, pickle_path):
         print(
             "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
-        super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor)
+        super(ClassificationInfer, self).__init__()
 
 
 def seq_label_post_processor(batch_outputs, label_vocab):
diff --git a/fastNLP/loader/model_loader.py b/fastNLP/loader/model_loader.py
index c07576b8..5c8a1371 100644
--- a/fastNLP/loader/model_loader.py
+++ b/fastNLP/loader/model_loader.py
@@ -8,8 +8,8 @@ class ModelLoader(BaseLoader):
     Loader for models.
     """
 
-    def __init__(self, data_path):
-        super(ModelLoader, self).__init__(data_path)
+    def __init__(self):
+        super(ModelLoader, self).__init__()
 
     @staticmethod
     def load_pytorch(empty_model, model_path):
@@ -19,3 +19,10 @@ class ModelLoader(BaseLoader):
         :param model_path: str, the path to the saved model.
         """
         empty_model.load_state_dict(torch.load(model_path))
+
+    @staticmethod
+    def load_pytorch_model(model_path):
+        """Load an entire model saved with param_only=False.
+        :param model_path: str, the path to the saved model.
+        """
+        return torch.load(model_path)
\ No newline at end of file
diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py
index 464f99be..11e49ee1 100644
--- a/fastNLP/models/sequence_modeling.py
+++ b/fastNLP/models/sequence_modeling.py
@@ -127,7 +127,8 @@ class AdvSeqLabel(SeqLabeling):
         :param word_seq: LongTensor, [batch_size, mex_len]
         :param word_seq_origin_len: list of int.
         :param truth: LongTensor, [batch_size, max_len]
-        :return y:
+        :return y: If truth is None, return a list of decode paths (each a list); used in testing and prediction.
+                   If truth is not None, return the loss, a scalar; used in training.
 
         """
         self.mask = self.make_mask(word_seq, word_seq_origin_len)
diff --git a/fastNLP/saver/model_saver.py b/fastNLP/saver/model_saver.py
index 74518a44..fd391f69 100644
--- a/fastNLP/saver/model_saver.py
+++ b/fastNLP/saver/model_saver.py
@@ -15,10 +15,14 @@ class ModelSaver(object):
         """
         self.save_path = save_path
 
-    def save_pytorch(self, model):
+    def save_pytorch(self, model, param_only=True):
         """Save a pytorch model into .pkl file.
         :param model: a PyTorch model
+        :param param_only: bool, whether to save only the model parameters (True) or the entire model (False).
 
         """
-        torch.save(model.state_dict(), self.save_path)
+        if param_only:
+            torch.save(model.state_dict(), self.save_path)
+        else:
+            torch.save(model, self.save_path)
 
diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py
index 45cfbbc0..fb077fe3 100644
--- a/reproduction/pos_tag_model/train_pos_tag.py
+++ b/reproduction/pos_tag_model/train_pos_tag.py
@@ -59,42 +59,37 @@ def infer():
     print("Inference finished!")
 
 
-def train():
-    # Config Loader
-    train_args = ConfigSection()
-    test_args = ConfigSection()
-    ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args})
+def train():
+    # load config
+    train_args = ConfigSection()
+    test_args = ConfigSection()
+    ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args})
 
     # Data Loader
     loader = PeopleDailyCorpusLoader()
     train_data, _ = loader.load()
 
-    # Preprocessor
-    preprocessor = SeqLabelPreprocess()
-    data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3)
-    train_args["vocab_size"] = preprocessor.vocab_size
-    train_args["num_classes"] = preprocessor.num_classes
+    # TODO: define processors
+
+    # define the pipeline
+    pp = Pipeline()
+    # TODO: pp.add_processor()
 
-    # Trainer
-    trainer = SeqLabelTrainer(**train_args.data)
+    # run the pipeline to get the processed data set
+    train_data = pp(train_data)
 
-    # Model
+    # define a model
     model = AdvSeqLabel(train_args)
-    try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-        print('model parameter loaded!')
-    except Exception as e:
-        print("No saved model. Continue.")
-        pass
 
-    # Start training
+    # call the trainer to train
+    trainer = SeqLabelTrainer(train_args)
     trainer.train(model, data_train, data_dev)
-    print("Training finished!")
 
-    # Saver
-    saver = ModelSaver("./save/saved_model.pkl")
-    saver.save_pytorch(model)
-    print("Model saved!")
+    # save the entire model
+    ModelSaver("./saved_model.pkl").save_pytorch(model, param_only=False)
+
+    # TODO: save the pipeline
+
 
 def test():
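
Below is a minimal usage sketch, not part of the patch itself, showing how the whole-model saving and loading added above fit together. The file path and the stand-in module are hypothetical placeholders.

    import torch.nn as nn

    from fastNLP.loader.model_loader import ModelLoader
    from fastNLP.saver.model_saver import ModelSaver

    model = nn.Linear(10, 2)  # stand-in for a trained fastNLP model

    # param_only=False pickles the full nn.Module instead of just its state_dict
    ModelSaver("./whole_model.pkl").save_pytorch(model, param_only=False)

    # the whole-model loader returns a ready-to-use module; no empty model instance is required
    restored = ModelLoader.load_pytorch_model("./whole_model.pkl")
    restored.eval()

Saving the entire module keeps the pos-tagging API simple, at the cost of tying the saved file to the exact class definitions available at load time.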