
update trainer: add sampling and padding in batchify, add pkl loading in prepare_input, check model loss in get_loss

tags/v0.1.0
FengZiYjun, 6 years ago
parent commit ceffed6a16

2 changed files with 119 additions and 38 deletions:
  1. +61 -25 fastNLP/action/action.py
  2. +58 -13 fastNLP/action/trainer.py

+61 -25 fastNLP/action/action.py

@@ -1,3 +1,4 @@
+import numpy as np
 
 
 class Action(object):
@@ -8,28 +9,63 @@ class Action(object):
     def __init__(self):
         super(Action, self).__init__()
 
-    def batchify(self, batch_size, X, Y=None):
-        """
-        :param batch_size: int
-        :param X: feature matrix of size [n_sample, m_feature]
-        :param Y: label vector of size [n_sample, 1] (optional)
-        :return iteration: int, the number of steps in each epoch
-                generator: generator, to generate batch inputs
-        """
-        n_samples = X.shape[0]
-        num_iter = n_samples // batch_size
-        if Y is None:
-            generator = self._batch_generate(batch_size, num_iter, X)
-        else:
-            generator = self._batch_generate(batch_size, num_iter, X, Y)
-        return num_iter, generator
-
-    @staticmethod
-    def _batch_generate(batch_size, num_iter, *data):
-        for step in range(num_iter):
-            start = batch_size * step
-            end = batch_size * (step + 1)
-            yield tuple([x[start:end] for x in data])
-
-    def make_log(self, *args):
-        return "log"
+
+class BaseSampler(object):
+    """
+    Base class for all samplers.
+    """
+
+    def __init__(self, data_set):
+        self.data_set_length = len(data_set)
+
+    def __len__(self):
+        return self.data_set_length
+
+    def __iter__(self):
+        raise NotImplementedError
+
+
+class SequentialSampler(BaseSampler):
+    """
+    Sample data in the original order.
+    """
+
+    def __init__(self, data_set):
+        super(SequentialSampler, self).__init__(data_set)
+
+    def __iter__(self):
+        return iter(range(self.data_set_length))
+
+
+class RandomSampler(BaseSampler):
+    """
+    Sample data in a random permutation order.
+    """
+
+    def __init__(self, data_set):
+        super(RandomSampler, self).__init__(data_set)
+
+    def __iter__(self):
+        return iter(np.random.permutation(self.data_set_length))
+
+
+class Batchifier(object):
+    """
+    Wrap a random or sequential sampler to generate mini-batches of indices.
+    """
+
+    def __init__(self, sampler, batch_size, drop_last=True):
+        super(Batchifier, self).__init__()
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        batch = []
+        for idx in self.sampler:
+            batch.append(idx)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+        # yield the trailing batch only when it is non-empty, so that an
+        # empty list is not produced when len(data) divides batch_size
+        if 0 < len(batch) < self.batch_size and not self.drop_last:
+            yield batch

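A quick usage sketch (not part of the commit) of the new sampler and batchifier classes: RandomSampler yields a shuffled permutation of indices, and Batchifier groups them into fixed-size index batches, dropping any trailing partial batch when drop_last=True. The toy dataset below is invented for illustration.

# Illustration only: iterate over shuffled index batches of a toy dataset.
from fastNLP.action.action import RandomSampler, Batchifier

data = [([1, 2, 3], 0), ([4, 5], 1), ([6], 0), ([7, 8], 1), ([9], 0)]
batches = Batchifier(RandomSampler(data), batch_size=2, drop_last=True)
for indices in batches:
    samples = [data[i] for i in indices]
    print(indices, samples)
# e.g. [4, 1] -> [([9], 0), ([4, 5], 1)]; with drop_last=True the fifth
# sample is dropped because it cannot fill a complete batch of 2.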
+58 -13 fastNLP/action/trainer.py

@@ -1,9 +1,11 @@
+import pickle
 from collections import namedtuple
 
 import numpy as np
 import torch
 
 from fastNLP.action.action import Action
+from fastNLP.action.action import RandomSampler, Batchifier
 from fastNLP.action.tester import Tester
 

@@ -31,8 +33,10 @@ class BaseTrainer(Action):
         self.validate = train_args.validate
         self.batch_size = train_args.batch_size
         self.model = None
+        self.iterator = None
+        self.loss_func = None
 
-    def train(self, network, train_data, dev_data=None):
+    def train(self, network):
         """General training loop.
         :param network: a model
         :param train_data: raw data for training
@@ -50,22 +54,21 @@
         Subclasses must implement these methods with a specific framework.
         """
         self.model = network
-        train_x, train_y = self.prepare_input(train_data)
-
-        iterations, train_batch_generator = self.batchify(self.batch_size, train_x, train_y)
+        data_train, data_dev, data_test, embedding = self.prepare_input("./save/")
 
         test_args = Tester.TestConfig(save_output=True, validate_in_training=True,
                                       save_dev_input=True, save_loss=True, batch_size=self.batch_size)
         evaluator = Tester(test_args)
 
         best_loss = 1e10
+        iterations = len(data_train) // self.batch_size
 
         for epoch in range(self.n_epochs):
-            self.mode(test=False)  # turn on the train mode
+            self.mode(test=False)
 
             self.define_optimizer()
             for step in range(iterations):
-                batch_x, batch_y = train_batch_generator.__next__()
+                batch_x, batch_y = self.batchify(self.batch_size, data_train)
 
                 prediction = self.data_forward(network, batch_x)
 
@@ -74,21 +77,23 @@
                 self.update()
 
             if self.validate:
-                if dev_data is None:
+                if data_dev is None:
                     raise RuntimeError("No validation data provided.")
-                evaluator.test(network, dev_data)
+                evaluator.test(network, data_dev)
                 if evaluator.loss < best_loss:
                     best_loss = evaluator.loss
 
         # finish training
 
-    def prepare_input(self, data):
+    def prepare_input(self, data_path):
         """
-        Perform data transformation from raw input to vector/matrix inputs.
-        :param data: raw inputs
-        :return (X, Y): tuple, input features and labels
+        Load the pkl files of train/dev/test data and the embedding matrix.
         """
-        raise NotImplementedError
+        data_train = pickle.load(open(data_path + "data_train.pkl", "rb"))
+        data_dev = pickle.load(open(data_path + "data_dev.pkl", "rb"))
+        data_test = pickle.load(open(data_path + "data_test.pkl", "rb"))
+        embedding = pickle.load(open(data_path + "embedding.pkl", "rb"))
+        return data_train, data_dev, data_test, embedding
 
     def mode(self, test=False):
         """
@@ -138,8 +143,48 @@ class BaseTrainer(Action):
         :param truth: ground truth label vector
         :return: a scalar
         """
+        if self.loss_func is None:
+            if hasattr(self.model, "loss"):
+                self.loss_func = self.model.loss
+            else:
+                self.loss_func = self.define_loss()
+        return self.loss_func(predict, truth)
+
+    def define_loss(self):
+        raise NotImplementedError
+
+    def batchify(self, batch_size, data):
+        """
+        Sample a mini-batch from data and pad it to uniform length.
+        :param batch_size: int, the size of a mini-batch
+        :param data: list of samples, where each sample is an (x, y) pair
+        :return: batch_x, batch_y
+        """
+        if self.iterator is None:
+            self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
+        indices = next(self.iterator)
+        batch = [data[idx] for idx in indices]
+        batch_x = [sample[0] for sample in batch]
+        batch_y = [sample[1] for sample in batch]
+        batch_x = self.pad(batch_x)
+        return batch_x, batch_y
+
+    @staticmethod
+    def pad(batch, fill=0):
+        """
+        Pad a batch of samples to the maximum length in the batch.
+        :param batch: list of list
+        :param fill: word index used for padding, default 0.
+        :return: a padded batch
+        """
+        max_length = max([len(x) for x in batch])
+        for idx, sample in enumerate(batch):
+            if len(sample) < max_length:
+                # right-pad with copies of the fill index
+                batch[idx] = sample + [fill] * (max_length - len(sample))
+        return batch
 
 
 class ToyTrainer(BaseTrainer):
     """A simple trainer for a PyTorch model."""
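The new get_loss resolves the loss function lazily: it prefers a loss attribute on the model and otherwise falls back to define_loss, caching the result in self.loss_func. A minimal sketch of the two paths; ToyModel and CrossEntropyTrainer are hypothetical names, not part of this commit:

# Illustration only: the two ways get_loss can find a loss function.
import torch.nn as nn

class ToyModel(nn.Module):
    """get_loss will find and reuse this model-supplied loss attribute."""
    def __init__(self):
        super(ToyModel, self).__init__()
        self.loss = nn.CrossEntropyLoss()

class CrossEntropyTrainer(BaseTrainer):
    """When the model has no loss attribute, define_loss supplies one."""
    def define_loss(self):
        return nn.CrossEntropyLoss()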


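And a quick check of the padding helper (illustration only, assuming BaseTrainer is importable from fastNLP/action/trainer.py): pad right-pads every sequence with the fill index up to the longest length in the batch, so batchify can return a rectangular batch_x.

batch = [[1, 2, 3], [4, 5], [6]]
print(BaseTrainer.pad(batch, fill=0))
# -> [[1, 2, 3], [4, 5, 0], [6, 0, 0]]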