From 7ea015c0f96b27bcb6091154adfac4ffae563766 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Wed, 4 Jul 2018 23:28:48 +0800 Subject: [PATCH] update trainer: loading data with _pickle; add arguments comments. --- fastNLP/action/trainer.py | 51 ++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py index 8b9eb717..437ab7d2 100644 --- a/fastNLP/action/trainer.py +++ b/fastNLP/action/trainer.py @@ -1,4 +1,4 @@ -import pickle +import _pickle from collections import namedtuple import numpy as np @@ -21,8 +21,7 @@ class BaseTrainer(Action): - grad_backward - get_loss """ - TrainConfig = namedtuple("config", ["epochs", "validate", "save_when_better", - "log_per_step", "log_validation", "batch_size"]) + TrainConfig = namedtuple("config", ["epochs", "validate", "batch_size", "pickle_path"]) def __init__(self, train_args): """ @@ -32,6 +31,7 @@ class BaseTrainer(Action): self.n_epochs = train_args.epochs self.validate = train_args.validate self.batch_size = train_args.batch_size + self.pickle_path = train_args.pickle_path self.model = None self.iterator = None self.loss_func = None @@ -39,8 +39,6 @@ class BaseTrainer(Action): def train(self, network): """General training loop. :param network: a model - :param train_data: raw data for training - :param dev_data: raw data for validation The method is framework independent. Work by calling the following methods: @@ -54,7 +52,7 @@ class BaseTrainer(Action): Subclasses must implement these methods with a specific framework. """ self.model = network - data_train, data_dev, data_test, embedding = self.prepare_input("./save/") + data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) test_args = Tester.TestConfig(save_output=True, validate_in_training=True, save_dev_input=True, save_loss=True, batch_size=self.batch_size) @@ -89,10 +87,10 @@ class BaseTrainer(Action): """ To do: Load pkl files of train/dev/test and embedding """ - data_train = pickle.load(open(data_path + "data_train.pkl", "rb")) - data_dev = pickle.load(open(data_path + "data_dev.pkl", "rb")) - data_test = pickle.load(open(data_path + "data_test.pkl", "rb")) - embedding = pickle.load(open(data_path + "embedding.pkl", "rb")) + data_train = _pickle.load(open(data_path + "data_train.pkl", "rb")) + data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) + data_test = _pickle.load(open(data_path + "data_test.pkl", "rb")) + embedding = _pickle.load(open(data_path + "embedding.pkl", "rb")) return data_train, data_dev, data_test, embedding def mode(self, test=False): @@ -147,20 +145,30 @@ class BaseTrainer(Action): if hasattr(self.model, "loss"): self.loss_func = self.model.loss else: - self.loss_func = self.define_loss() + self.define_loss() return self.loss_func(predict, truth) def define_loss(self): + """ + Assign an instance of loss function to self.loss_func + E.g. self.loss_func = nn.CrossEntropyLoss() + """ raise NotImplementedError def batchify(self, batch_size, data): """ - Perform batching from data and produce a batch of training data. - Add padding. - :param batch_size: - :param data: - :param pad: - :return: batch_x, batch_y + 1. Perform batching from data and produce a batch of training data. + 2. Add padding. + :param batch_size: int, the size of a batch + :param data: list. Each entry is a sample, which is also a list of features and label(s). + E.g. + [ + [[feature_1, feature_2, feature_3], [label_1. label_2]], # sample 1 + [[feature_1, feature_2, feature_3], [label_1. label_2]], # sample 2 + ... + ] + :return batch_x: list. Each entry is a list of features of a sample. + batch_y: list. Each entry is a list of labels of a sample. """ if self.iterator is None: self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True)) @@ -306,8 +314,7 @@ class WordSegTrainer(BaseTrainer): if __name__ == "__name__": - Config = namedtuple("config", ["epochs", "validate", "save_when_better", "log_per_step", - "log_validation", "batch_size"]) - train_config = Config(epochs=5, validate=True, save_when_better=True, log_per_step=10, log_validation=True, - batch_size=32) - trainer = ToyTrainer(train_config) + train_args = BaseTrainer.TrainConfig(epochs=1, validate=False, batch_size=3, pickle_path="./") + trainer = BaseTrainer(train_args) + data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] + trainer.batchify(batch_size=3, data=data_train)