diff --git a/fastNLP/action/action.py b/fastNLP/action/action.py
index c85a74df..5512c7b1 100644
--- a/fastNLP/action/action.py
+++ b/fastNLP/action/action.py
@@ -1,4 +1,3 @@
-from saver.logger import Logger


 class Action(object):
@@ -8,16 +7,6 @@ class Action(object):

     def __init__(self):
         super(Action, self).__init__()
-        self.logger = Logger("logger_output.txt")
-
-    def load_config(self, args):
-        raise NotImplementedError
-
-    def load_dataset(self, args):
-        raise NotImplementedError
-
-    def log(self, string):
-        self.logger.log(string)

     def batchify(self, batch_size, X, Y=None):
         """
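The trainer below consumes batchify as "iterations, generator = self.batchify(batch_size, X, Y)". Its body is elided from this diff, so the following is only a sketch of the contract those call sites imply, not the actual implementation:

    # Assumed contract of Action.batchify, inferred from its call sites in
    # trainer.py -- illustrative only; the real body is not shown in this diff.
    def batchify(self, batch_size, X, Y=None):
        iterations = len(X) // batch_size  # number of full batches per epoch

        def batches():
            for i in range(iterations):
                start = i * batch_size
                if Y is None:
                    yield X[start:start + batch_size]
                else:
                    yield X[start:start + batch_size], Y[start:start + batch_size]

        return iterations, batches()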
""" - train_x, train_y = network.prepare_input(train_data) + self.model = network + train_x, train_y = self.prepare_input(train_data) iterations, train_batch_generator = self.batchify(self.batch_size, train_x, train_y) @@ -39,55 +59,125 @@ class Trainer(Action): evaluator = Tester(test_args) best_loss = 1e10 - loss_history = list() for epoch in range(self.n_epochs): - network.mode(test=False) # turn on the train mode + self.mode(test=False) # turn on the train mode - network.define_optimizer() + self.define_optimizer() for step in range(iterations): batch_x, batch_y = train_batch_generator.__next__() - prediction = network.data_forward(batch_x) - - loss = network.get_loss(prediction, batch_y) - network.grad_backward() + prediction = self.data_forward(network, batch_x) - if step % self.log_per_step == 0: - print("step ", step) - loss_history.append(loss) - self.log(self.make_log(epoch, step, loss)) + loss = self.get_loss(prediction, batch_y) + self.grad_backward(loss) + self.update() - #################### evaluate over dev set ################### if self.validate: if dev_data is None: raise RuntimeError("No validation data provided.") - # give all controls to tester evaluator.test(network, dev_data) - - if self.log_validation: - self.log(self.make_valid_log(epoch, evaluator.loss)) if evaluator.loss < best_loss: best_loss = evaluator.loss - if self.save_when_better: - self.save_model(network) # finish training - def make_log(self, *args): - return "make a log" + def prepare_input(self, data): + """ + Perform data transformation from raw input to vector/matrix inputs. + :param data: raw inputs + :return (X, Y): tuple, input features and labels + """ + raise NotImplementedError - def make_valid_log(self, *args): - return "make a valid log" + def mode(self, test=False): + """ + Tell the network to be trained or not. + :param test: bool + """ + raise NotImplementedError - def save_model(self, model): - model.save() + def define_optimizer(self): + """ + Define framework-specific optimizer specified by the models. + """ + raise NotImplementedError - def load_data(self, data_name): - print("load data") + def update(self): + """ + Perform weight update on a model. - def load_config(self, args): + For PyTorch, just call optimizer to update. + """ raise NotImplementedError - def load_dataset(self, args): + def data_forward(self, network, *x): + """ + Forward pass of the data. + :param network: a model + :param x: input feature matrix and label vector + :return: output by the models + + For PyTorch, just do "network(*x)" + """ raise NotImplementedError + + def grad_backward(self, loss): + """ + Compute gradient with link rules. + :param loss: a scalar where back-prop starts + + For PyTorch, just do "loss.backward()" + """ + raise NotImplementedError + + def get_loss(self, predict, truth): + """ + Compute loss given prediction and ground truth. 
diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py
index fa1d446d..0f40ec51 100644
--- a/fastNLP/loader/config_loader.py
+++ b/fastNLP/loader/config_loader.py
@@ -1,4 +1,4 @@
-from loader.base_loader import BaseLoader
+from fastNLP.loader.base_loader import BaseLoader


 class ConfigLoader(BaseLoader):
@@ -11,3 +11,4 @@ class ConfigLoader(BaseLoader):
     @staticmethod
     def parse(string):
         raise NotImplementedError
+
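ConfigLoader.parse is still a stub, and this patch does not commit to a config format. Purely as an illustration, a subclass could fill the hook like this; JsonConfigLoader and the choice of JSON are assumptions, not the project's actual design:

    import json

    from fastNLP.loader.config_loader import ConfigLoader

    class JsonConfigLoader(ConfigLoader):
        """Hypothetical: assumes config files are JSON, which this patch does not specify."""

        @staticmethod
        def parse(string):
            # turn the raw text of a config file into a dict
            return json.loads(string)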
diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py
index 1a2782c3..9249e2e3 100644
--- a/fastNLP/models/base_model.py
+++ b/fastNLP/models/base_model.py
@@ -1,4 +1,3 @@
-import numpy as np
 import torch


@@ -30,100 +29,6 @@ class BaseModel(torch.nn.Module):
         raise NotImplementedError


-class BaseController(object):
-    """Base Controller for all controllers.
-    This class and its subclasses are actually "controllers" of the PyTorch models.
-    They act as an interface between Trainer and the PyTorch models.
-    This controller provides the following methods to be called by Trainer.
-    - prepare_input
-    - mode
-    - define_optimizer
-    - data_forward
-    - grad_backward
-    - get_loss
-    """
-
-    def __init__(self):
-        """
-        Define PyTorch model parameters here.
-        """
-        pass
-
-    def prepare_input(self, data):
-        """
-        Perform data transformation from raw input to vector/matrix inputs.
-        :param data: raw inputs
-        :return (X, Y): tuple, input features and labels
-        """
-        raise NotImplementedError
-
-    def mode(self, test=False):
-        """
-        Tell the network to be trained or not, required by PyTorch.
-        :param test: bool
-        """
-        raise NotImplementedError
-
-    def define_optimizer(self):
-        """
-        Define PyTorch optimizer specified by the models.
-        """
-        raise NotImplementedError
-
-    def data_forward(self, *x):
-        """
-        Forward pass of the data.
-        :param x: input feature matrix and label vector
-        :return: output by the models
-        """
-        # required by PyTorch nn
-        raise NotImplementedError
-
-    def grad_backward(self):
-        """
-        Perform gradient descent to update the models parameters.
-        """
-        raise NotImplementedError
-
-    def get_loss(self, pred, truth):
-        """
-        Compute loss given models prediction and ground truth. Loss function specified by the models.
-        :param pred: prediction label vector
-        :param truth: ground truth label vector
-        :return: a scalar
-        """
-        raise NotImplementedError
-
-
-class ToyController(BaseController):
-    """This is for code testing."""
-
-    def __init__(self):
-        super(ToyController, self).__init__()
-        self.test_mode = False
-        self.weight = np.random.rand(5, 1)
-        self.bias = np.random.rand()
-        self._loss = 0
-
-    def prepare_input(self, data):
-        return data[:, :-1], data[:, -1]
-
-    def mode(self, test=False):
-        self.test_mode = test
-
-    def data_forward(self, x):
-        return np.matmul(x, self.weight) + self.bias
-
-    def grad_backward(self):
-        print("loss gradient backward")
-
-    def get_loss(self, pred, truth):
-        self._loss = np.mean(np.square(pred - truth))
-        return self._loss
-
-    def define_optimizer(self):
-        pass
-

 class Vocabulary(object):
     """A look-up table that allows you to access `Lexeme` objects. The `Vocab`