From ec165ce4ac258dd3107ad57aee1142eeca668c1a Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sun, 15 Jul 2018 14:59:41 +0800
Subject: [PATCH] add model saver and loader

---
 fastNLP/loader/model_loader.py | 19 +++++++++++++++++++
 fastNLP/saver/base_saver.py    |  9 ---------
 fastNLP/saver/model_saver.py   | 12 +++++++++++-
 test/test_POS_pipeline.py      | 25 +++++++++++++++++++++----
 4 files changed, 51 insertions(+), 14 deletions(-)
 create mode 100644 fastNLP/loader/model_loader.py

diff --git a/fastNLP/loader/model_loader.py b/fastNLP/loader/model_loader.py
new file mode 100644
index 00000000..8224b3f2
--- /dev/null
+++ b/fastNLP/loader/model_loader.py
@@ -0,0 +1,19 @@
+import torch
+
+from fastNLP.loader.base_loader import BaseLoader
+
+
+class ModelLoader(BaseLoader):
+    """
+    Loader for models.
+    """
+
+    def __init__(self, data_name, data_path):
+        super(ModelLoader, self).__init__(data_name, data_path)
+
+    def load_pytorch(self, empty_model):
+        """
+        Load model parameters from .pkl files into the empty PyTorch model.
+        :param empty_model: a PyTorch model with initialized parameters.
+        """
+        empty_model.load_state_dict(torch.load(self.data_path))
diff --git a/fastNLP/saver/base_saver.py b/fastNLP/saver/base_saver.py
index d721da2c..3a350c0b 100644
--- a/fastNLP/saver/base_saver.py
+++ b/fastNLP/saver/base_saver.py
@@ -3,12 +3,3 @@ class BaseSaver(object):
 
     def __init__(self, save_path):
         self.save_path = save_path
-
-    def save_bytes(self):
-        raise NotImplementedError
-
-    def save_str(self):
-        raise NotImplementedError
-
-    def compress(self):
-        raise NotImplementedError
diff --git a/fastNLP/saver/model_saver.py b/fastNLP/saver/model_saver.py
index 3b3cbeca..97675142 100644
--- a/fastNLP/saver/model_saver.py
+++ b/fastNLP/saver/model_saver.py
@@ -1,4 +1,6 @@
-from saver.base_saver import BaseSaver
+import torch
+
+from fastNLP.saver.base_saver import BaseSaver
 
 
 class ModelSaver(BaseSaver):
@@ -6,3 +8,11 @@ class ModelSaver(BaseSaver):
 
     def __init__(self, save_path):
         super(ModelSaver, self).__init__(save_path)
+
+    def save_pytorch(self, model):
+        """
+        Save a pytorch model into .pkl file.
+        :param model: a PyTorch model
+        :return:
+        """
+        torch.save(model.state_dict(), self.save_path)
diff --git a/test/test_POS_pipeline.py b/test/test_POS_pipeline.py
index af22e3b9..6922a1e9 100644
--- a/test/test_POS_pipeline.py
+++ b/test/test_POS_pipeline.py
@@ -1,10 +1,11 @@
 import sys
-
 sys.path.append("..")
-
 from fastNLP.action.trainer import POSTrainer
 from fastNLP.loader.dataset_loader import POSDatasetLoader
 from fastNLP.loader.preprocess import POSPreprocess
+from fastNLP.saver.model_saver import ModelSaver
+from fastNLP.loader.model_loader import ModelLoader
+from fastNLP.action.tester import POSTester
 from fastNLP.models.sequence_modeling import SeqLabeling
 
 data_name = "people.txt"
@@ -13,8 +14,8 @@ pickle_path = "data_for_tests"
 
 if __name__ == "__main__":
     # Data Loader
-    pos = POSDatasetLoader(data_name, data_path)
-    train_data = pos.load_lines()
+    pos_loader = POSDatasetLoader(data_name, data_path)
+    train_data = pos_loader.load_lines()
 
     # Preprocessor
     p = POSPreprocess(train_data, pickle_path)
@@ -33,3 +34,19 @@ if __name__ == "__main__":
     trainer.train(model)
 
     print("Training finished!")
+
+    saver = ModelSaver("./saved_model.pkl")
+    saver.save_pytorch(model)
+    print("Model saved!")
+
+    del model, trainer, pos_loader
+
+    model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)
+    ModelLoader("xxx", "./saved_model.pkl").load_pytorch(model)
+    print("model loaded!")
+
+    test_args = {"save_output": True, "validate_in_training": False, "save_dev_input": False,
+                 "save_loss": True, "batch_size": 1, "pickle_path": pickle_path}
+    tester = POSTester(test_args)
+    tester.test(model)
+    print("model tested!")