
- add interfaces for the pos_tagging API
- update predictor.py to remove unused methods
- update model_loader.py & model_saver.py to support saving & loading of entire models
- update the POS tagging training script
tags/v0.2.0
FengZiYjun, 5 years ago
commit 79105381f5
6 changed files with 85 additions and 67 deletions
  1. fastNLP/api/pos_tagger.py (+44 / -0)
  2. fastNLP/core/predictor.py (+4 / -37)
  3. fastNLP/loader/model_loader.py (+9 / -2)
  4. fastNLP/models/sequence_modeling.py (+2 / -1)
  5. fastNLP/saver/model_saver.py (+6 / -2)
  6. reproduction/pos_tag_model/train_pos_tag.py (+20 / -25)

fastNLP/api/pos_tagger.py (+44 / -0)

@@ -0,0 +1,44 @@
+import pickle
+
+import numpy as np
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.loader.model_loader import ModelLoader
+from fastNLP.core.predictor import Predictor
+
+
+class POS_tagger:
+    def __init__(self):
+        pass
+
+    def predict(self, query):
+        """
+        :param query: List[str]
+        :return answer: List[str]
+
+        """
+        # TODO: build a DataSet from the query
+        pos_dataset = DataSet()
+        pos_dataset["text_field"] = np.array(query)
+
+        # load the pipeline and the model
+        pipeline = self.load_pipeline("./xxxx")
+
+        # run the pipeline with the DataSet as its argument
+        pos_dataset = pipeline(pos_dataset)
+
+        # load the model
+        model = ModelLoader().load_pytorch("./xxx")
+
+        # call the predictor
+        predictor = Predictor()
+        output = predictor.predict(model, pos_dataset)
+
+        # TODO: convert to the final output
+        return None
+
+    @staticmethod
+    def load_pipeline(path):
+        with open(path, "rb") as fp:  # binary mode: pickle files are not text
+            pipeline = pickle.load(fp)
+        return pipeline
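The class is still a skeleton: both TODOs are open and the "./xxxx" / "./xxx" paths are placeholders. Still, the intended call shape is visible. A minimal usage sketch, assuming a pickled pipeline and a whole-model checkpoint actually exist behind those placeholder paths:

    from fastNLP.api.pos_tagger import POS_tagger

    tagger = POS_tagger()
    # query is a list of tokens; the answer is meant to be the tag sequence,
    # but this commit still returns None until the TODOs are resolved
    tags = tagger.predict(["我", "爱", "北京"])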

fastNLP/core/predictor.py (+4 / -37)

@@ -2,9 +2,7 @@ import numpy as np
 import torch
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.preprocess import load_pickle
 from fastNLP.core.sampler import SequentialSampler
-from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq2tag_dataset, convert_seq_dataset
 
 
 class Predictor(object):
@@ -16,19 +14,9 @@ class Predictor(object):
     Currently, Predictor does not support GPU.
     """
 
-    def __init__(self, pickle_path, post_processor):
-        """
-
-        :param pickle_path: str, the path to the pickle files.
-        :param post_processor: a function or callable object, that takes list of batch outputs as input
-
-        """
+    def __init__(self):
         self.batch_size = 1
         self.batch_output = []
-        self.pickle_path = pickle_path
-        self._post_processor = post_processor
-        self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl")
-        self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")
 
     def predict(self, network, data):
         """Perform inference using the trained model.
@@ -37,9 +25,6 @@ class Predictor(object):
         :param data: a DataSet object.
        :return: list of list of strings, [num_examples, tag_seq_length]
         """
-        # transform strings into DataSet object
-        # data = self.prepare_input(data)
-
        # turn on the testing mode; clean up the history
         self.mode(network, test=True)
         batch_output = []
@@ -51,7 +36,7 @@ class Predictor(object):
            prediction = self.data_forward(network, batch_x)
            batch_output.append(prediction)
 
-        return self._post_processor(batch_output, self.label_vocab)
+        return batch_output
 
     def mode(self, network, test=True):
         if test:
@@ -64,37 +49,19 @@ class Predictor(object):
        y = network(**x)
        return y
 
-    def prepare_input(self, data):
-        """Transform two-level list of strings into an DataSet object.
-        In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor.
-
-        :param data: list of list of strings.
-            ::
-                [
-                    [word_11, word_12, ...],
-                    [word_21, word_22, ...],
-                    ...
-                ]
-
-        :return data_set: a DataSet instance.
-        """
-        assert isinstance(data, list)
-        data = convert_seq_dataset(data)
-        data.index_field("word_seq", self.word_vocab)
-
 
 class SeqLabelInfer(Predictor):
     def __init__(self, pickle_path):
         print(
             "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
-        super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor)
+        super(SeqLabelInfer, self).__init__()
 
 
 class ClassificationInfer(Predictor):
     def __init__(self, pickle_path):
         print(
             "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
-        super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor)
+        super(ClassificationInfer, self).__init__()
 
 
 def seq_label_post_processor(batch_outputs, label_vocab):
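With the constructor slimmed down, Predictor no longer reads vocabularies from pickle_path and no longer applies a post-processor: predict() now returns the raw per-batch network outputs and leaves decoding to the caller. A sketch of the new calling convention, assuming a trained model and a DataSet already prepared by a pipeline:

    from fastNLP.core.predictor import Predictor

    predictor = Predictor()  # no pickle_path / post_processor arguments anymore
    batch_outputs = predictor.predict(model, dataset)
    # turning raw outputs into tag strings is now the caller's job, e.g. via
    # seq_label_post_processor(batch_outputs, label_vocab)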


fastNLP/loader/model_loader.py (+9 / -2)

@@ -8,8 +8,8 @@ class ModelLoader(BaseLoader):
     Loader for models.
     """
 
-    def __init__(self, data_path):
-        super(ModelLoader, self).__init__(data_path)
+    def __init__(self):
+        super(ModelLoader, self).__init__()
 
     @staticmethod
     def load_pytorch(empty_model, model_path):
@@ -19,3 +19,10 @@ class ModelLoader(BaseLoader):
         :param model_path: str, the path to the saved model.
         """
         empty_model.load_state_dict(torch.load(model_path))
+
+    @staticmethod
+    def load_pytorch(model_path):
+        """Load the entire model.
+
+        """
+        return torch.load(model_path)
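Note that a Python class body keeps only the last binding of a name, so this second load_pytorch(model_path) shadows the state-dict variant above it; after this commit only whole-model loading is reachable under that name. A sketch of the two styles the file intends, with placeholder paths:

    from fastNLP.loader.model_loader import ModelLoader

    # style 1: parameters only, architecture built first (the shadowed signature)
    ModelLoader.load_pytorch(model, "./model_param.pkl")

    # style 2: the entire pickled model, no architecture needed up front
    model = ModelLoader.load_pytorch("./whole_model.pkl")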

fastNLP/models/sequence_modeling.py (+2 / -1)

@@ -127,7 +127,8 @@ class AdvSeqLabel(SeqLabeling):
         :param word_seq: LongTensor, [batch_size, mex_len]
         :param word_seq_origin_len: list of int.
         :param truth: LongTensor, [batch_size, max_len]
-        :return y:
+        :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
+            If truth is not None, return loss, a scalar. Used in training.
         """
         self.mask = self.make_mask(word_seq, word_seq_origin_len)
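The expanded docstring documents a dual-purpose forward. In sketch form, with hypothetical tensors shaped as the docstring describes and assuming truth defaults to None:

    # training: pass gold labels, get back a scalar loss
    loss = model(word_seq, word_seq_origin_len, truth)

    # testing / predicting: omit truth, get back a list of decode paths
    paths = model(word_seq, word_seq_origin_len)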



fastNLP/saver/model_saver.py (+6 / -2)

@@ -15,10 +15,14 @@ class ModelSaver(object):
         """
         self.save_path = save_path
 
-    def save_pytorch(self, model):
+    def save_pytorch(self, model, param_only=True):
         """Save a pytorch model into .pkl file.
 
         :param model: a PyTorch model
+        :param param_only: bool, whether only to save the model parameters or the entire model.
 
         """
-        torch.save(model.state_dict(), self.save_path)
+        if param_only is True:
+            torch.save(model.state_dict(), self.save_path)
+        else:
+            torch.save(model, self.save_path)
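The default keeps the old parameter-only behaviour, while param_only=False writes the whole pickled module, which is what the new ModelLoader.load_pytorch(model_path) expects to read back. One caveat with whole-model pickles: loading them requires the model's class definition to be importable at load time. A short sketch:

    from fastNLP.saver.model_saver import ModelSaver

    saver = ModelSaver("./saved_model.pkl")
    saver.save_pytorch(model)                    # default: state_dict only
    saver.save_pytorch(model, param_only=False)  # entire model, for load_pytorch(model_path)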

reproduction/pos_tag_model/train_pos_tag.py (+20 / -25)

@@ -59,42 +59,37 @@ def infer():
     print("Inference finished!")
 
 
-def train():
-    # Config Loader
-    train_args = ConfigSection()
-    test_args = ConfigSection()
-    ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args})
+def train():
+    # load config
+    trainer_args = ConfigSection()
+    model_args = ConfigSection()
+    ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args})
 
     # Data Loader
     loader = PeopleDailyCorpusLoader()
     train_data, _ = loader.load()
 
-    # Preprocessor
-    preprocessor = SeqLabelPreprocess()
-    data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3)
-    train_args["vocab_size"] = preprocessor.vocab_size
-    train_args["num_classes"] = preprocessor.num_classes
+    # TODO: define processors
+    # define pipeline
+    pp = Pipeline()
+    # TODO: pp.add_processor()
 
-    # Trainer
-    trainer = SeqLabelTrainer(**train_args.data)
+    # run the pipeline, get data_set
+    train_data = pp(train_data)
 
-    # Model
+    # define a model
     model = AdvSeqLabel(train_args)
     try:
         ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
         print('model parameter loaded!')
     except Exception as e:
         print("No saved model. Continue.")
         pass
 
-    # Start training
+    # call trainer to train
+    trainer = SeqLabelTrainer(train_args)
     trainer.train(model, data_train, data_dev)
     print("Training finished!")
 
-    # Saver
-    saver = ModelSaver("./save/saved_model.pkl")
-    saver.save_pytorch(model)
-    print("Model saved!")
+    # save model
+    ModelSaver("./saved_model.pkl").save_pytorch(model, param_only=False)
+
+    # TODO: save pipeline
 
 
 def test():
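The script is still mid-refactor: trainer_args / model_args are declared, but the calls below keep referencing the old train_args, and the pipeline pp has no processors yet. The remaining "TODO: save pipeline" is the counterpart of POS_tagger.load_pipeline above; a minimal sketch, assuming pp is the fitted Pipeline from train() and using binary mode to match the loader:

    import pickle

    with open("./pipeline.pkl", "wb") as fp:
        pickle.dump(pp, fp)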

