From a73087e913ea6c7faad53a104983f87b0a8b2bef Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Tue, 10 Jul 2018 22:00:24 +0800
Subject: [PATCH] refactor Tester; Tester + Trainer for seq modeling work

---
 fastNLP/action/tester.py  | 161 +++++++++++++++++++++++++-------------
 fastNLP/action/trainer.py |  36 +++++++++----
 test/test_POS_pipeline.py |   2 +-
 3 files changed, 125 insertions(+), 74 deletions(-)

diff --git a/fastNLP/action/tester.py b/fastNLP/action/tester.py
index 7f660bb0..2a71cf4d 100644
--- a/fastNLP/action/tester.py
+++ b/fastNLP/action/tester.py
@@ -1,87 +1,136 @@
-from collections import namedtuple
+import _pickle
 
-import numpy as np
+import torch
 
 from fastNLP.action.action import Action
+from fastNLP.action.action import RandomSampler, Batchifier
+from fastNLP.modules.utils import seq_mask
 
 
-class Tester(Action):
+class BaseTester(Action):
     """docstring for Tester"""
 
-    TestConfig = namedtuple("config", ["validate_in_training", "save_dev_input", "save_output",
-                                       "save_loss", "batch_size"])
-
     def __init__(self, test_args):
         """
-        :param test_args: named tuple
+        :param test_args: a dict of testing settings
         """
-        super(Tester, self).__init__()
-        self.validate_in_training = test_args.validate_in_training
-        self.save_dev_input = test_args.save_dev_input
+        super(BaseTester, self).__init__()
+        self.validate_in_training = test_args["validate_in_training"]
         self.valid_x = None
         self.valid_y = None
-        self.save_output = test_args.save_output
+        self.save_output = test_args["save_output"]
         self.output = None
-        self.save_loss = test_args.save_loss
+        self.save_loss = test_args["save_loss"]
         self.mean_loss = None
-        self.batch_size = test_args.batch_size
-
-    def test(self, network, data):
-        print("testing")
-        network.mode(test=True)  # turn on the testing mode
-        if self.save_dev_input:
-            if self.valid_x is None:
-                valid_x, valid_y = network.prepare_input(data)
-                self.valid_x = valid_x
-                self.valid_y = valid_y
-            else:
-                valid_x = self.valid_x
-                valid_y = self.valid_y
-        else:
-            valid_x, valid_y = network.prepare_input(data)
+        self.batch_size = test_args["batch_size"]
+        self.pickle_path = test_args["pickle_path"]
+        self.iterator = None
 
-        # split into batches by self.batch_size
-        iterations, test_batch_generator = self.batchify(self.batch_size, valid_x, valid_y)
+    def test(self, network):
+        self.mode(network, test=True)
 
-        batch_output = list()
-        loss_history = list()
-        # turn on the testing mode of the network
-        network.mode(test=True)
+        dev_data = self.prepare_input(self.pickle_path)
 
-        for step in range(iterations):
-            batch_x, batch_y = test_batch_generator.__next__()
+        self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
 
-            # forward pass from test input to predicted output
-            prediction = network.data_forward(batch_x)
+        batch_output = list()
+        eval_history = list()
+        num_iter = len(dev_data) // self.batch_size
+
+        for step in range(num_iter):
+            batch_x, batch_y = self.batchify(dev_data)
 
-            loss = network.get_loss(prediction, batch_y)
+            prediction = self.data_forward(network, batch_x)
+            eval_results = self.evaluate(prediction, batch_y)
 
             if self.save_output:
-                batch_output.append(prediction.data)
+                batch_output.append(prediction)
             if self.save_loss:
-                loss_history.append(loss)
-                self.log(self.make_log(step, loss))
-
-        if self.save_loss:
-            self.mean_loss = np.mean(np.array(loss_history))
-        if self.save_output:
-            self.output = self.make_output(batch_output)
+                eval_history.append(eval_results)
 
-    @property
-    def loss(self):
-        return self.mean_loss
+    def prepare_input(self, data_path):
+        data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb"))
+        return data_dev
 
-    @property
-    def result(self):
-        return self.output
+    def batchify(self, data):
+        """
+        1. Perform batching from data and produce a batch of evaluation data.
+        2. Add padding.
+        :param data: list. Each entry is a sample, which is also a list of features and label(s).
+            E.g.
+                [
+                    [[word_11, word_12, word_13], [label_11, label_12]],  # sample 1
+                    [[word_21, word_22, word_23], [label_21, label_22]],  # sample 2
+                    ...
+                ]
+        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
+                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
+        """
+        indices = next(self.iterator)
+        batch = [data[idx] for idx in indices]
+        batch_x = [sample[0] for sample in batch]
+        batch_y = [sample[1] for sample in batch]
+        batch_x = self.pad(batch_x)
+        return batch_x, batch_y
 
     @staticmethod
-    def make_output(batch_outputs):
-        # construct full prediction with batch outputs
-        return np.concatenate(batch_outputs, axis=0)
+    def pad(batch, fill=0):
+        """
+        Pad a batch of samples to maximum length.
+        :param batch: list of list
+        :param fill: word index to pad with, default 0.
+        :return: a padded batch
+        """
+        max_length = max([len(x) for x in batch])
+        for idx, sample in enumerate(batch):
+            if len(sample) < max_length:
+                batch[idx] = sample + [fill] * (max_length - len(sample))
+        return batch
 
-    def load_config(self, args):
+    def data_forward(self, network, data):
         raise NotImplementedError
 
-    def load_dataset(self, args):
+    def evaluate(self, predict, truth):
         raise NotImplementedError
+
+    @property
+    def metrics(self):
+        raise NotImplementedError
+
+    def mode(self, model, test=True):
+        """TODO: combine this function with Trainer"""
+        if test:
+            model.eval()
+        else:
+            model.train()
+
+
+class POSTester(BaseTester):
+    """
+    Tester for sequence labeling.
+    """
+
+    def __init__(self, test_args):
+        super(POSTester, self).__init__(test_args)
+        self.max_len = None
+        self.mask = None
+
+    def data_forward(self, network, x):
+        """TODO: combine with Trainer
+
+        :param network: the PyTorch model
+        :param x: list of list, [batch_size, max_len]
+        :return y: [batch_size, num_classes]
+        """
+        seq_len = [len(seq) for seq in x]
+        x = torch.Tensor(x).long()
+        self.batch_size = x.size(0)
+        self.max_len = x.size(1)
+        self.mask = seq_mask(seq_len, self.max_len)
+        y = network(x)
+        return y
+
+    def evaluate(self, predict, truth):
+        """TODO: implement a real evaluation metric; returns a constant placeholder for now."""
+        return 0
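A quick illustration of the padding contract introduced above: BaseTester.pad right-pads every sample with `fill` up to the length of the longest sample in the batch. A minimal standalone sketch in plain Python (it mirrors the patched method rather than importing it from fastNLP, and index 0 is assumed to be the vocabulary's padding symbol):

def pad(batch, fill=0):
    # Right-pad each sample with `fill` until it matches the longest sample.
    max_length = max(len(sample) for sample in batch)
    return [sample + [fill] * (max_length - len(sample)) for sample in batch]

assert pad([[4, 5, 6], [7, 8]]) == [[4, 5, 6], [7, 8, 0]]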
diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py
index 6f51435a..034b46ca 100644
--- a/fastNLP/action/trainer.py
+++ b/fastNLP/action/trainer.py
@@ -5,7 +5,7 @@
 import torch
 
 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
-from fastNLP.action.tester import Tester
+from fastNLP.action.tester import POSTester
 from fastNLP.modules.utils import seq_mask
 
@@ -43,7 +43,7 @@ class BaseTrainer(Action):
         self.optimizer = None
 
     def train(self, network):
-        """General training loop.
+        """General training steps:
         :param network: a model
 
         The method is framework independent.
@@ -57,23 +57,27 @@ class BaseTrainer(Action):
             - forward
             - backward
             - update
         Subclasses must implement these methods with a specific framework.
         """
+        # prepare model and data
         self.model = network
         data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)
 
-        test_args = Tester.TestConfig(save_output=True, validate_in_training=True,
-                                      save_dev_input=True, save_loss=True, batch_size=self.batch_size)
-        evaluator = Tester(test_args)
+        # define a validator over the dev set
+        valid_args = {"save_output": True, "validate_in_training": True,
+                      "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path}
+        validator = POSTester(valid_args)
 
-        best_loss = 1e10
+        # main training epochs
         iterations = len(data_train) // self.batch_size
-
         for epoch in range(self.n_epochs):
+
+            # turn on network training mode; define optimizer; prepare batch iterator
             self.mode(test=False)
             self.define_optimizer()
             self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))
 
+            # training iterations in one epoch
             for step in range(iterations):
-                batch_x, batch_y = self.batchify(self.batch_size, data_train)
+                batch_x, batch_y = self.batchify(data_train)
 
                 prediction = self.data_forward(network, batch_x)
 
@@ -84,9 +88,7 @@ class BaseTrainer(Action):
             if self.validate:
                 if data_dev is None:
                     raise RuntimeError("No validation data provided.")
-                evaluator.test(network, data_dev)
-                if evaluator.loss < best_loss:
-                    best_loss = evaluator.loss
+                validator.test(network)
 
         # finish training
 
@@ -162,11 +164,10 @@ class BaseTrainer(Action):
         """
         raise NotImplementedError
 
-    def batchify(self, batch_size, data):
+    def batchify(self, data):
         """
         1. Perform batching from data and produce a batch of training data.
         2. Add padding.
-        :param batch_size: int, the size of a batch
         :param data: list. Each entry is a sample, which is also a list of features and label(s).
             E.g.
                 [
@@ -200,7 +201,9 @@ class BaseTrainer(Action):
 
 
 class ToyTrainer(BaseTrainer):
-    """A simple trainer for a PyTorch model."""
+    """
+    deprecated
+    """
 
     def __init__(self, train_args):
         super(ToyTrainer, self).__init__(train_args)
 
@@ -235,7 +238,7 @@ class ToyTrainer(BaseTrainer):
 
 class WordSegTrainer(BaseTrainer):
     """
-        reserve for changes
+    deprecated
     """
 
     def __init__(self, train_args):
 
@@ -319,7 +322,6 @@ class WordSegTrainer(BaseTrainer):
         self.optimizer.step()
 
 
-
 class POSTrainer(BaseTrainer):
     """
     Trainer for Sequence Modeling
 
@@ -391,4 +393,4 @@ if __name__ == "__main__":
     train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
     trainer = BaseTrainer(train_args)
-    data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10]
-    trainer.batchify(batch_size=3, data=data_train)
+    data_train = [[[1, 2, 3, 4], [0]]] * 10 + [[[1, 3, 5, 2], [1]]] * 10
+    trainer.batchify(data=data_train)
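Both the training and testing loops above rely on the same batching contract: `next(self.iterator)` yields a list of sample indices, and `batchify` gathers and pads the corresponding samples. Below is a self-contained sketch of that contract, with `random.shuffle` standing in for fastNLP's RandomSampler/Batchifier (assumed behavior, not their actual implementation):

import random

def index_batches(num_samples, batch_size, drop_last=True):
    # Shuffle all sample indices, then yield them in fixed-size chunks,
    # mimicking iter(Batchifier(RandomSampler(data), batch_size, drop_last=True)).
    indices = list(range(num_samples))
    random.shuffle(indices)
    end = num_samples - num_samples % batch_size if drop_last else num_samples
    for start in range(0, end, batch_size):
        yield indices[start:start + batch_size]

data = [[[1, 2, 3, 4], [0]]] * 10 + [[[1, 3, 5, 2], [1]]] * 10  # 20 toy samples
for idx_batch in index_batches(len(data), batch_size=3):
    batch_x = [data[i][0] for i in idx_batch]  # word index sequences
    batch_y = [data[i][1] for i in idx_batch]  # label sequences

With drop_last=True, the trailing 20 % 3 = 2 samples are discarded, which is why both loops run exactly `len(data) // batch_size` iterations.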
""" + # prepare model and data self.model = network data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) - test_args = Tester.TestConfig(save_output=True, validate_in_training=True, - save_dev_input=True, save_loss=True, batch_size=self.batch_size) - evaluator = Tester(test_args) + # define tester over dev data + valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, + "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path} + validator = POSTester(valid_args) - best_loss = 1e10 + # main training epochs iterations = len(data_train) // self.batch_size - for epoch in range(self.n_epochs): + + # turn on network training mode; define optimizer; prepare batch iterator self.mode(test=False) self.define_optimizer() self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) + # training iterations in one epoch for step in range(iterations): - batch_x, batch_y = self.batchify(self.batch_size, data_train) + batch_x, batch_y = self.batchify(data_train) prediction = self.data_forward(network, batch_x) @@ -84,9 +88,7 @@ class BaseTrainer(Action): if self.validate: if data_dev is None: raise RuntimeError("No validation data provided.") - evaluator.test(network, data_dev) - if evaluator.loss < best_loss: - best_loss = evaluator.loss + validator.test(network) # finish training @@ -162,11 +164,10 @@ class BaseTrainer(Action): """ raise NotImplementedError - def batchify(self, batch_size, data): + def batchify(self, data): """ 1. Perform batching from data and produce a batch of training data. 2. Add padding. - :param batch_size: int, the size of a batch :param data: list. Each entry is a sample, which is also a list of features and label(s). E.g. [ @@ -200,7 +201,9 @@ class BaseTrainer(Action): class ToyTrainer(BaseTrainer): - """A simple trainer for a PyTorch model.""" + """ + deprecated + """ def __init__(self, train_args): super(ToyTrainer, self).__init__(train_args) @@ -235,7 +238,7 @@ class ToyTrainer(BaseTrainer): class WordSegTrainer(BaseTrainer): """ - reserve for changes + deprecated """ def __init__(self, train_args): @@ -319,7 +322,6 @@ class WordSegTrainer(BaseTrainer): self.optimizer.step() - class POSTrainer(BaseTrainer): """ Trainer for Sequence Modeling @@ -391,4 +393,4 @@ if __name__ == "__name__": train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} trainer = BaseTrainer(train_args) data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] - trainer.batchify(batch_size=3, data=data_train) + trainer.batchify(data=data_train) diff --git a/test/test_POS_pipeline.py b/test/test_POS_pipeline.py index c6e3fd83..af22e3b9 100644 --- a/test/test_POS_pipeline.py +++ b/test/test_POS_pipeline.py @@ -23,7 +23,7 @@ if __name__ == "__main__": # Trainer train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes, - "vocab_size": vocab_size, "pickle_path": pickle_path, "validate": False} + "vocab_size": vocab_size, "pickle_path": pickle_path, "validate": True} trainer = POSTrainer(train_args) # Model