From ceffed6a1615cfbb7fe1520bdd6fd3f0d9670473 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Tue, 3 Jul 2018 09:00:29 +0800 Subject: [PATCH] update trainer: add sampling and padding in batchify, add pkl loading in prepare_input, check model loss in get_loss --- fastNLP/action/action.py | 86 +++++++++++++++++++++++++++------------ fastNLP/action/trainer.py | 71 ++++++++++++++++++++++++++------ 2 files changed, 119 insertions(+), 38 deletions(-) diff --git a/fastNLP/action/action.py b/fastNLP/action/action.py index 5512c7b1..ea12a37e 100644 --- a/fastNLP/action/action.py +++ b/fastNLP/action/action.py @@ -1,3 +1,4 @@ +import numpy as np class Action(object): @@ -8,28 +9,63 @@ class Action(object): def __init__(self): super(Action, self).__init__() - def batchify(self, batch_size, X, Y=None): - """ - :param batch_size: int - :param X: feature matrix of size [n_sample, m_feature] - :param Y: label vector of size [n_sample, 1] (optional) - :return iteration:int, the number of step in each epoch - generator:generator, to generate batch inputs - """ - n_samples = X.shape[0] - num_iter = n_samples // batch_size - if Y is None: - generator = self._batch_generate(batch_size, num_iter, X) - else: - generator = self._batch_generate(batch_size, num_iter, X, Y) - return num_iter, generator - - @staticmethod - def _batch_generate(batch_size, num_iter, *data): - for step in range(num_iter): - start = batch_size * step - end = batch_size * (step + 1) - yield tuple([x[start:end] for x in data]) - - def make_log(self, *args): - return "log" + +class BaseSampler(object): + """ + Base class for all samplers. + """ + + def __init__(self, data_set): + self.data_set_length = len(data_set) + + def __len__(self): + return self.data_set_length + + def __iter__(self): + raise NotImplementedError + + +class SequentialSampler(BaseSampler): + """ + Sample data in the original order. 
+    """
+
+    def __init__(self, data_set):
+        super(SequentialSampler, self).__init__(data_set)
+
+    def __iter__(self):
+        return iter(range(self.data_set_length))
+
+
+class RandomSampler(BaseSampler):
+    """
+    Sample data in random permutation order.
+    """
+
+    def __init__(self, data_set):
+        super(RandomSampler, self).__init__(data_set)
+
+    def __iter__(self):
+        return iter(np.random.permutation(self.data_set_length))
+
+
+class Batchifier(object):
+    """
+    Wrap random or sequential sampler to generate a mini-batch.
+    """
+
+    def __init__(self, sampler, batch_size, drop_last=True):
+        super(Batchifier, self).__init__()
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        batch = []
+        for idx in self.sampler:
+            batch.append(idx)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+        if 0 < len(batch) < self.batch_size and self.drop_last is False:
+            yield batch
diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py
index 0bbcccd7..8b9eb717 100644
--- a/fastNLP/action/trainer.py
+++ b/fastNLP/action/trainer.py
@@ -1,9 +1,11 @@
+import pickle
 from collections import namedtuple
 
 import numpy as np
 import torch
 
 from fastNLP.action.action import Action
+from fastNLP.action.action import RandomSampler, Batchifier
 from fastNLP.action.tester import Tester
 
 
@@ -31,8 +33,10 @@ class BaseTrainer(Action):
         self.validate = train_args.validate
         self.batch_size = train_args.batch_size
         self.model = None
+        self.iterator = None
+        self.loss_func = None
 
-    def train(self, network, train_data, dev_data=None):
+    def train(self, network):
         """General training loop.
         :param network: a model
         :param train_data: raw data for training
@@ -50,22 +54,21 @@
         Subclasses must implement these methods with a specific framework.
""" self.model = network - train_x, train_y = self.prepare_input(train_data) - - iterations, train_batch_generator = self.batchify(self.batch_size, train_x, train_y) + data_train, data_dev, data_test, embedding = self.prepare_input("./save/") test_args = Tester.TestConfig(save_output=True, validate_in_training=True, save_dev_input=True, save_loss=True, batch_size=self.batch_size) evaluator = Tester(test_args) best_loss = 1e10 + iterations = len(data_train) // self.batch_size for epoch in range(self.n_epochs): - self.mode(test=False) # turn on the train mode + self.mode(test=False) self.define_optimizer() for step in range(iterations): - batch_x, batch_y = train_batch_generator.__next__() + batch_x, batch_y = self.batchify(self.batch_size, data_train) prediction = self.data_forward(network, batch_x) @@ -74,21 +77,23 @@ class BaseTrainer(Action): self.update() if self.validate: - if dev_data is None: + if data_dev is None: raise RuntimeError("No validation data provided.") - evaluator.test(network, dev_data) + evaluator.test(network, data_dev) if evaluator.loss < best_loss: best_loss = evaluator.loss # finish training - def prepare_input(self, data): + def prepare_input(self, data_path): """ - Perform data transformation from raw input to vector/matrix inputs. 
-        :param data: raw inputs
-        :return (X, Y): tuple, input features and labels
+        Load pickled train/dev/test data sets and the embedding table.
         """
-        raise NotImplementedError
+        data_train = pickle.load(open(data_path + "data_train.pkl", "rb"))
+        data_dev = pickle.load(open(data_path + "data_dev.pkl", "rb"))
+        data_test = pickle.load(open(data_path + "data_test.pkl", "rb"))
+        embedding = pickle.load(open(data_path + "embedding.pkl", "rb"))
+        return data_train, data_dev, data_test, embedding
 
     def mode(self, test=False):
         """
@@ -138,8 +143,48 @@ class BaseTrainer(Action):
         :param truth: ground truth label vector
         :return: a scalar
         """
+        if self.loss_func is None:
+            if hasattr(self.model, "loss"):
+                self.loss_func = self.model.loss
+            else:
+                self.loss_func = self.define_loss()
+        return self.loss_func(predict, truth)
+
+    def define_loss(self):
         raise NotImplementedError
 
+    def batchify(self, batch_size, data):
+        """
+        Perform batching from data and produce a batch of training data.
+        Add padding.
+        :param batch_size: int, number of samples in one mini-batch
+        :param data: list of (x, y) sample pairs
+
+        :return: batch_x, batch_y
+        """
+        if self.iterator is None:
+            self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
+        indices = next(self.iterator)  # NOTE(review): self.iterator is never reset, so a second epoch raises StopIteration — confirm intended
+        batch = [data[idx] for idx in indices]
+        batch_x = [sample[0] for sample in batch]
+        batch_y = [sample[1] for sample in batch]
+        batch_x = self.pad(batch_x)
+        return batch_x, batch_y
+
+    @staticmethod
+    def pad(batch, fill=0):
+        """
+        Pad a batch of samples to maximum length.
+        :param batch: list of list
+        :param fill: word index to pad, default 0.
+        :return: a padded batch
+        """
+        max_length = max([len(x) for x in batch])
+        for idx, sample in enumerate(batch):
+            if len(sample) < max_length:
+                batch[idx] = sample + [fill] * (max_length - len(sample))
+        return batch
+
 
 class ToyTrainer(BaseTrainer):
     """A simple trainer for a PyTorch model."""