Browse Source

add model saver and loader

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
ec165ce4ac
4 changed files with 51 additions and 14 deletions
  1. +19
    -0
      fastNLP/loader/model_loader.py
  2. +0
    -9
      fastNLP/saver/base_saver.py
  3. +11
    -1
      fastNLP/saver/model_saver.py
  4. +21
    -4
      test/test_POS_pipeline.py

+ 19
- 0
fastNLP/loader/model_loader.py View File

@@ -0,0 +1,19 @@
import torch

from fastNLP.loader.base_loader import BaseLoader


class ModelLoader(BaseLoader):
"""
Loader for models.
"""

def __init__(self, data_name, data_path):
super(ModelLoader, self).__init__(data_name, data_path)

def load_pytorch(self, empty_model):
"""
Load model parameters from .pkl files into the empty PyTorch model.
:param empty_model: a PyTorch model with initialized parameters.
"""
empty_model.load_state_dict(torch.load(self.data_path))

+ 0
- 9
fastNLP/saver/base_saver.py View File

@@ -3,12 +3,3 @@ class BaseSaver(object):

def __init__(self, save_path):
self.save_path = save_path

def save_bytes(self):
raise NotImplementedError

def save_str(self):
raise NotImplementedError

def compress(self):
raise NotImplementedError

+ 11
- 1
fastNLP/saver/model_saver.py View File

@@ -1,4 +1,6 @@
from saver.base_saver import BaseSaver
import torch

from fastNLP.saver.base_saver import BaseSaver


class ModelSaver(BaseSaver):
@@ -6,3 +8,11 @@ class ModelSaver(BaseSaver):

def __init__(self, save_path):
super(ModelSaver, self).__init__(save_path)

def save_pytorch(self, model):
"""
Save a pytorch model into .pkl file.
:param model: a PyTorch model
:return:
"""
torch.save(model.state_dict(), self.save_path)

+ 21
- 4
test/test_POS_pipeline.py View File

@@ -1,10 +1,11 @@
import sys

sys.path.append("..")

from fastNLP.action.trainer import POSTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.action.tester import POSTester
from fastNLP.models.sequence_modeling import SeqLabeling

data_name = "people.txt"
@@ -13,8 +14,8 @@ pickle_path = "data_for_tests"

if __name__ == "__main__":
# Data Loader
pos = POSDatasetLoader(data_name, data_path)
train_data = pos.load_lines()
pos_loader = POSDatasetLoader(data_name, data_path)
train_data = pos_loader.load_lines()

# Preprocessor
p = POSPreprocess(train_data, pickle_path)
@@ -33,3 +34,19 @@ if __name__ == "__main__":
trainer.train(model)

print("Training finished!")

saver = ModelSaver("./saved_model.pkl")
saver.save_pytorch(model)
print("Model saved!")

del model, trainer, pos_loader

model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)
ModelLoader("xxx", "./saved_model.pkl").load_pytorch(model)
print("model loaded!")

test_args = {"save_output": True, "validate_in_training": False, "save_dev_input": False,
"save_loss": True, "batch_size": 1, "pickle_path": pickle_path}
tester = POSTester(test_args)
tester.test(model)
print("model tested!")

Loading…
Cancel
Save