
update trainer: add sampling and padding in batchify, add pkl loading in prepare_input, check model loss in get_loss

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit ceffed6a16
2 changed files with 119 additions and 38 deletions
  1. +61 -25  fastNLP/action/action.py
  2. +58 -13  fastNLP/action/trainer.py

fastNLP/action/action.py  (+61 -25)

@@ -1,3 +1,4 @@
+import numpy as np
 
 
 class Action(object):
@@ -8,28 +9,63 @@ class Action(object):
     def __init__(self):
         super(Action, self).__init__()
 
-    def batchify(self, batch_size, X, Y=None):
-        """
-        :param batch_size: int
-        :param X: feature matrix of size [n_sample, m_feature]
-        :param Y: label vector of size [n_sample, 1] (optional)
-        :return num_iter: int, the number of steps in each epoch
-                generator: a generator that yields batch inputs
-        """
-        n_samples = X.shape[0]
-        num_iter = n_samples // batch_size
-        if Y is None:
-            generator = self._batch_generate(batch_size, num_iter, X)
-        else:
-            generator = self._batch_generate(batch_size, num_iter, X, Y)
-        return num_iter, generator
-
-    @staticmethod
-    def _batch_generate(batch_size, num_iter, *data):
-        for step in range(num_iter):
-            start = batch_size * step
-            end = batch_size * (step + 1)
-            yield tuple([x[start:end] for x in data])
-
-    def make_log(self, *args):
-        return "log"
+
+
+class BaseSampler(object):
+    """
+    Base class for all samplers.
+    """
+
+    def __init__(self, data_set):
+        self.data_set_length = len(data_set)
+
+    def __len__(self):
+        return self.data_set_length
+
+    def __iter__(self):
+        raise NotImplementedError
+
+
+class SequentialSampler(BaseSampler):
+    """
+    Sample data in the original order.
+    """
+
+    def __init__(self, data_set):
+        super(SequentialSampler, self).__init__(data_set)
+
+    def __iter__(self):
+        return iter(range(self.data_set_length))
+
+
+class RandomSampler(BaseSampler):
+    """
+    Sample data in random permutation order.
+    """
+
+    def __init__(self, data_set):
+        super(RandomSampler, self).__init__(data_set)
+
+    def __iter__(self):
+        return iter(np.random.permutation(self.data_set_length))
+
+
+class Batchifier(object):
+    """
+    Wrap a random or sequential sampler to generate mini-batches.
+    """
+
+    def __init__(self, sampler, batch_size, drop_last=True):
+        super(Batchifier, self).__init__()
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        batch = []
+        for idx in self.sampler:
+            batch.append(idx)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+        if 0 < len(batch) < self.batch_size and self.drop_last is False:
+            yield batch
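
As a quick orientation (not part of the commit), here is a minimal sketch of how the two new pieces fit together: RandomSampler yields a permutation of dataset indices, and Batchifier groups those indices into fixed-size index batches. The toy dataset below is invented for illustration; the import path matches the one trainer.py uses in this commit.

    from fastNLP.action.action import Batchifier, RandomSampler

    # Five (word_indices, label) samples; the contents are made up.
    data = [([1, 2, 3], 0), ([4, 5], 1), ([6], 0), ([7, 8], 1), ([9], 0)]

    # Each iteration yields a list of batch_size dataset indices.
    for indices in Batchifier(RandomSampler(data), batch_size=2, drop_last=True):
        batch = [data[i] for i in indices]

With drop_last=True (the default, and what the trainer uses), the trailing incomplete batch of one sample is skipped, so an epoch here covers at most four of the five samples.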

fastNLP/action/trainer.py  (+58 -13)

@@ -1,9 +1,11 @@
+import pickle
 from collections import namedtuple
 
 import numpy as np
 import torch
 
 from fastNLP.action.action import Action
+from fastNLP.action.action import RandomSampler, Batchifier
 from fastNLP.action.tester import Tester




@@ -31,8 +33,10 @@ class BaseTrainer(Action):
         self.validate = train_args.validate
         self.batch_size = train_args.batch_size
         self.model = None
+        self.iterator = None
+        self.loss_func = None
 
-    def train(self, network, train_data, dev_data=None):
+    def train(self, network):
         """General training loop.
         :param network: a model
         :param train_data: raw data for training
@@ -50,22 +54,21 @@ class BaseTrainer(Action):
         Subclasses must implement these methods with a specific framework.
         """
         self.model = network
-        train_x, train_y = self.prepare_input(train_data)
-
-        iterations, train_batch_generator = self.batchify(self.batch_size, train_x, train_y)
+        data_train, data_dev, data_test, embedding = self.prepare_input("./save/")
 
         test_args = Tester.TestConfig(save_output=True, validate_in_training=True,
                                       save_dev_input=True, save_loss=True, batch_size=self.batch_size)
         evaluator = Tester(test_args)
 
         best_loss = 1e10
+        iterations = len(data_train) // self.batch_size
 
         for epoch in range(self.n_epochs):
-            self.mode(test=False)  # turn on the train mode
+            self.mode(test=False)
 
             self.define_optimizer()
             for step in range(iterations):
-                batch_x, batch_y = train_batch_generator.__next__()
+                batch_x, batch_y = self.batchify(self.batch_size, data_train)
 
                 prediction = self.data_forward(network, batch_x)


@@ -74,21 +77,23 @@ class BaseTrainer(Action):
                 self.update()
 
             if self.validate:
-                if dev_data is None:
+                if data_dev is None:
                     raise RuntimeError("No validation data provided.")
-                evaluator.test(network, dev_data)
+                evaluator.test(network, data_dev)
                 if evaluator.loss < best_loss:
                     best_loss = evaluator.loss
 
         # finish training
 
-    def prepare_input(self, data):
+    def prepare_input(self, data_path):
         """
-        Perform data transformation from raw input to vector/matrix inputs.
-        :param data: raw inputs
-        :return (X, Y): tuple, input features and labels
+        To do: Load pkl files of train/dev/test and embedding
         """
-        raise NotImplementedError
+        data_train = pickle.load(open(data_path + "data_train.pkl", "rb"))
+        data_dev = pickle.load(open(data_path + "data_dev.pkl", "rb"))
+        data_test = pickle.load(open(data_path + "data_test.pkl", "rb"))
+        embedding = pickle.load(open(data_path + "embedding.pkl", "rb"))
+        return data_train, data_dev, data_test, embedding
 
     def mode(self, test=False):
         """
@@ -138,8 +143,48 @@ class BaseTrainer(Action):
         :param truth: ground truth label vector
         :return: a scalar
         """
+        if self.loss_func is None:
+            if hasattr(self.model, "loss"):
+                self.loss_func = self.model.loss
+            else:
+                self.loss_func = self.define_loss()
+        return self.loss_func(predict, truth)
+
+    def define_loss(self):
         raise NotImplementedError
 
+    def batchify(self, batch_size, data):
+        """
+        Perform batching from data and produce a batch of training data.
+        Add padding.
+        :param batch_size: int
+        :param data: list of samples, each a (feature sequence, label) pair
+        :return batch_x, batch_y: padded feature sequences and their labels
+        """
+        if self.iterator is None:
+            self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
+        indices = next(self.iterator)
+        batch = [data[idx] for idx in indices]
+        batch_x = [sample[0] for sample in batch]
+        batch_y = [sample[1] for sample in batch]
+        batch_x = self.pad(batch_x)
+        return batch_x, batch_y
+
+    @staticmethod
+    def pad(batch, fill=0):
+        """
+        Pad a batch of samples to the maximum length in the batch.
+        :param batch: list of list
+        :param fill: word index to pad with, default 0.
+        :return: the padded batch
+        """
+        max_length = max([len(x) for x in batch])
+        for idx, sample in enumerate(batch):
+            if len(sample) < max_length:
+                batch[idx] = sample + [fill] * (max_length - len(sample))
+        return batch
 
 
 class ToyTrainer(BaseTrainer):
     """A simple trainer for a PyTorch model."""

