Browse Source

combine controller and trainer

tags/v0.1.0
FengZiYjun 7 years ago
parent
commit
83fe6f9f21
4 changed files with 134 additions and 149 deletions
  1. +0
    -11
      fastNLP/action/action.py
  2. +132
    -42
      fastNLP/action/trainer.py
  3. +2
    -1
      fastNLP/loader/config_loader.py
  4. +0
    -95
      fastNLP/models/base_model.py

+ 0
- 11
fastNLP/action/action.py View File

@@ -1,4 +1,3 @@
from saver.logger import Logger


class Action(object):
@@ -8,16 +7,6 @@ class Action(object):

def __init__(self):
super(Action, self).__init__()
self.logger = Logger("logger_output.txt")

def load_config(self, args):
raise NotImplementedError

def load_dataset(self, args):
raise NotImplementedError

def log(self, string):
self.logger.log(string)

def batchify(self, batch_size, X, Y=None):
"""


+ 132
- 42
fastNLP/action/trainer.py View File

@@ -1,36 +1,56 @@
from collections import namedtuple

from .action import Action
from .tester import Tester
import numpy as np
import torch

from fastNLP.action.action import Action
from fastNLP.action.tester import Tester

class Trainer(Action):
"""
Trainer is a common training pipeline shared among all models.

class BaseTrainer(Action):
"""Base trainer for all trainers.
Trainer receives a model and data, and then performs training.

Subclasses must implement the following abstract methods:
- prepare_input
- mode
- define_optimizer
- data_forward
- grad_backward
- get_loss
"""
TrainConfig = namedtuple("config", ["epochs", "validate", "save_when_better",
"log_per_step", "log_validation", "batch_size"])

def __init__(self, train_args):
"""
:param train_args: namedtuple
training parameters
"""
super(Trainer, self).__init__()
super(BaseTrainer, self).__init__()
self.n_epochs = train_args.epochs
self.validate = train_args.validate
self.save_when_better = train_args.save_when_better
self.log_per_step = train_args.log_per_step
self.log_validation = train_args.log_validation
self.batch_size = train_args.batch_size
self.model = None

def train(self, network, train_data, dev_data=None):
"""
:param network: the models controller
"""General training loop.
:param network: a model
:param train_data: raw data for training
:param dev_data: raw data for validation
This method will call all the base methods of network (implemented in models.base_model).

The method is framework independent.
Work by calling the following methods:
- prepare_input
- mode
- define_optimizer
- data_forward
- get_loss
- grad_backward
- update
Subclasses must implement these methods with a specific framework.
"""
train_x, train_y = network.prepare_input(train_data)
self.model = network
train_x, train_y = self.prepare_input(train_data)

iterations, train_batch_generator = self.batchify(self.batch_size, train_x, train_y)

@@ -39,55 +59,125 @@ class Trainer(Action):
evaluator = Tester(test_args)

best_loss = 1e10
loss_history = list()

for epoch in range(self.n_epochs):
network.mode(test=False) # turn on the train mode
self.mode(test=False) # turn on the train mode

network.define_optimizer()
self.define_optimizer()
for step in range(iterations):
batch_x, batch_y = train_batch_generator.__next__()

prediction = network.data_forward(batch_x)

loss = network.get_loss(prediction, batch_y)
network.grad_backward()
prediction = self.data_forward(network, batch_x)

if step % self.log_per_step == 0:
print("step ", step)
loss_history.append(loss)
self.log(self.make_log(epoch, step, loss))
loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()

#################### evaluate over dev set ###################
if self.validate:
if dev_data is None:
raise RuntimeError("No validation data provided.")
# give all controls to tester
evaluator.test(network, dev_data)

if self.log_validation:
self.log(self.make_valid_log(epoch, evaluator.loss))
if evaluator.loss < best_loss:
best_loss = evaluator.loss
if self.save_when_better:
self.save_model(network)

# finish training

def make_log(self, *args):
return "make a log"
def prepare_input(self, data):
"""
Perform data transformation from raw input to vector/matrix inputs.
:param data: raw inputs
:return (X, Y): tuple, input features and labels
"""
raise NotImplementedError

def make_valid_log(self, *args):
return "make a valid log"
def mode(self, test=False):
"""
Tell the network to be trained or not.
:param test: bool
"""
raise NotImplementedError

def save_model(self, model):
model.save()
def define_optimizer(self):
"""
Define framework-specific optimizer specified by the models.
"""
raise NotImplementedError

def load_data(self, data_name):
print("load data")
def update(self):
"""
Perform weight update on a model.

def load_config(self, args):
For PyTorch, just call optimizer to update.
"""
raise NotImplementedError

def load_dataset(self, args):
def data_forward(self, network, *x):
"""
Forward pass of the data.
:param network: a model
:param x: input feature matrix and label vector
:return: output by the models

For PyTorch, just do "network(*x)"
"""
raise NotImplementedError

def grad_backward(self, loss):
"""
Compute gradient with link rules.
:param loss: a scalar where back-prop starts

For PyTorch, just do "loss.backward()"
"""
raise NotImplementedError

def get_loss(self, predict, truth):
"""
Compute loss given prediction and ground truth.
:param predict: prediction label vector
:param truth: ground truth label vector
:return: a scalar
"""
raise NotImplementedError


class ToyTrainer(BaseTrainer):
    """A toy trainer used to exercise the trainer hooks.

    The forward pass and the loss are a NumPy linear model
    (``matmul(x, weight) + bias`` scored with mean squared error), while the
    optimizer hooks delegate to ``torch.optim``.

    NOTE(review): ``get_loss`` returns a NumPy value, which has no
    ``backward()`` method, so ``grad_backward`` presumably only works when it
    is handed a torch loss -- confirm before using this class end to end.
    """

    def __init__(self, train_args):
        """
        :param train_args: namedtuple
            training parameters, forwarded unchanged to the base trainer
        """
        super(ToyTrainer, self).__init__(train_args)
        self.test_mode = False
        # Parameters of the toy linear model: 5 input features -> 1 output.
        self.weight = np.random.rand(5, 1)
        self.bias = np.random.rand()
        self._loss = 0
        self._optimizer = None

    def prepare_input(self, data):
        """
        Split raw data into features and labels.
        :param data: indexable with ``data[:, k]`` (e.g. a 2-D array)
        :return (X, Y): all columns but the last, and the last column
        """
        return data[:, :-1], data[:, -1]

    def mode(self, test=False):
        """
        Forward the train/test switch to the attached model.
        :param test: bool
        """
        self.model.mode(test)

    def data_forward(self, network, *x):
        """
        Forward pass of the toy linear model.
        :param network: a model (unused here; the toy weights live on the trainer)
        :param x: positional inputs; the first one is the feature matrix
        :return: ``matmul(features, weight) + bias``
        """
        # ``*x`` collects the inputs into a tuple. Handing that tuple to
        # matmul would stack it into an array with an extra leading axis,
        # so pick out the feature matrix explicitly.
        features = x[0]
        return np.matmul(features, self.weight) + self.bias

    def grad_backward(self, loss):
        """
        Start back-propagation from the loss.
        :param loss: an object exposing ``backward()``
        """
        loss.backward()

    def get_loss(self, pred, truth):
        """
        Mean squared error between prediction and ground truth.
        :param pred: prediction
        :param truth: ground truth
        :return: the mean of the element-wise squared difference

        NOTE(review): if ``pred`` is (B, 1) and ``truth`` is (B,), the
        subtraction broadcasts to (B, B) -- confirm this is intended.
        """
        self._loss = np.mean(np.square(pred - truth))
        return self._loss

    def define_optimizer(self):
        """Create an SGD optimizer (lr=0.01) over the attached model's parameters."""
        self._optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

    def update(self):
        """Apply one optimizer step."""
        self._optimizer.step()


# Script entry point: build a toy trainer from a hand-written configuration.
# The guard must compare against "__main__"; comparing against "__name__"
# is never true, so the block below would never run.
if __name__ == "__main__":
    Config = namedtuple("config", ["epochs", "validate", "save_when_better", "log_per_step",
                                   "log_validation", "batch_size"])
    train_config = Config(epochs=5, validate=True, save_when_better=True, log_per_step=10, log_validation=True,
                          batch_size=32)
    trainer = ToyTrainer(train_config)

+ 2
- 1
fastNLP/loader/config_loader.py View File

@@ -1,4 +1,4 @@
from loader.base_loader import BaseLoader
from fastNLP.loader.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
@@ -11,3 +11,4 @@ class ConfigLoader(BaseLoader):
@staticmethod
def parse(string):
raise NotImplementedError


+ 0
- 95
fastNLP/models/base_model.py View File

@@ -1,4 +1,3 @@
import numpy as np
import torch


@@ -30,100 +29,6 @@ class BaseModel(torch.nn.Module):
raise NotImplementedError


class BaseController(object):
    """Abstract interface that sits between a Trainer and a PyTorch model.

    Concrete controllers wrap a model and expose the hooks a Trainer calls:
    - prepare_input
    - mode
    - define_optimizer
    - data_forward
    - grad_backward
    - get_loss

    Every hook here raises NotImplementedError; subclasses override them.
    """

    def __init__(self):
        """Subclasses define their PyTorch model parameters here."""

    def prepare_input(self, data):
        """
        Turn raw input into vector/matrix inputs.
        :param data: raw inputs
        :return (X, Y): tuple of input features and labels
        """
        raise NotImplementedError

    def mode(self, test=False):
        """
        Switch the network between training and testing (required by PyTorch).
        :param test: bool
        """
        raise NotImplementedError

    def define_optimizer(self):
        """
        Build the PyTorch optimizer chosen by the model.
        """
        raise NotImplementedError

    def data_forward(self, *x):
        """
        Run the forward pass (required by PyTorch nn).
        :param x: input feature matrix and label vector
        :return: the model's output
        """
        raise NotImplementedError

    def grad_backward(self):
        """
        Run gradient descent to update the model parameters.
        """
        raise NotImplementedError

    def get_loss(self, pred, truth):
        """
        Compute the model-specific loss of a prediction against the ground truth.
        :param pred: prediction label vector
        :param truth: ground truth label vector
        :return: a scalar
        """
        raise NotImplementedError


class ToyController(BaseController):
    """Minimal controller used for code testing.

    Implements a NumPy linear model with 5 input features and 1 output.
    """

    def __init__(self):
        super(ToyController, self).__init__()
        self.test_mode = False
        # Draw weight before bias so the random stream is consumed in a fixed order.
        self.weight = np.random.rand(5, 1)
        self.bias = np.random.rand()
        self._loss = 0

    def prepare_input(self, data):
        """Return (features, labels): all columns but the last, and the last column."""
        features = data[:, :-1]
        labels = data[:, -1]
        return features, labels

    def mode(self, test=False):
        """Record whether the controller is in test mode."""
        self.test_mode = test

    def data_forward(self, x):
        """Return ``matmul(x, weight) + bias``."""
        projected = np.matmul(x, self.weight)
        return projected + self.bias

    def grad_backward(self):
        """Placeholder backward pass: only prints a message."""
        print("loss gradient backward")

    def get_loss(self, pred, truth):
        """Store and return the mean of the squared difference ``pred - truth``."""
        residual = pred - truth
        self._loss = np.mean(np.square(residual))
        return self._loss

    def define_optimizer(self):
        """No optimizer is needed for the toy controller."""
        return None


class Vocabulary(object):
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab`


Loading…
Cancel
Save