Browse Source

refactor Tester; Tester + Trainer for seq modeling work

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
a73087e913
3 changed files with 125 additions and 74 deletions
  1. +105
    -56
      fastNLP/action/tester.py
  2. +19
    -17
      fastNLP/action/trainer.py
  3. +1
    -1
      test/test_POS_pipeline.py

+ 105
- 56
fastNLP/action/tester.py View File

@@ -1,87 +1,136 @@
from collections import namedtuple
import _pickle

import numpy as np
import torch

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
from fastNLP.modules.utils import seq_mask


class Tester(Action):
class BaseTester(Action):
"""docstring for Tester"""

TestConfig = namedtuple("config", ["validate_in_training", "save_dev_input", "save_output",
"save_loss", "batch_size"])

def __init__(self, test_args):
"""
:param test_args: named tuple
"""
super(Tester, self).__init__()
self.validate_in_training = test_args.validate_in_training
self.save_dev_input = test_args.save_dev_input
super(BaseTester, self).__init__()
self.validate_in_training = test_args["validate_in_training"]
self.valid_x = None
self.valid_y = None
self.save_output = test_args.save_output
self.save_output = test_args["save_output"]
self.output = None
self.save_loss = test_args.save_loss
self.save_loss = test_args["save_loss"]
self.mean_loss = None
self.batch_size = test_args.batch_size

def test(self, network, data):
print("testing")
network.mode(test=True) # turn on the testing mode
if self.save_dev_input:
if self.valid_x is None:
valid_x, valid_y = network.prepare_input(data)
self.valid_x = valid_x
self.valid_y = valid_y
else:
valid_x = self.valid_x
valid_y = self.valid_y
else:
valid_x, valid_y = network.prepare_input(data)
self.batch_size = test_args["batch_size"]
self.pickle_path = test_args["pickle_path"]
self.iterator = None

# split into batches by self.batch_size
iterations, test_batch_generator = self.batchify(self.batch_size, valid_x, valid_y)
def test(self, network):
# print("--------------testing----------------")
self.mode(network, test=True)

batch_output = list()
loss_history = list()
# turn on the testing mode of the network
network.mode(test=True)
dev_data = self.prepare_input(self.pickle_path)

for step in range(iterations):
batch_x, batch_y = test_batch_generator.__next__()
self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))

# forward pass from test input to predicted output
prediction = network.data_forward(batch_x)
batch_output = list()
eval_history = list()
num_iter = len(dev_data) // self.batch_size

for step in range(num_iter):
batch_x, batch_y = self.batchify(dev_data)

loss = network.get_loss(prediction, batch_y)
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)

if self.save_output:
batch_output.append(prediction.data)
batch_output.append(prediction)
if self.save_loss:
loss_history.append(loss)
self.log(self.make_log(step, loss))

if self.save_loss:
self.mean_loss = np.mean(np.array(loss_history))
if self.save_output:
self.output = self.make_output(batch_output)
eval_history.append(eval_results)

@property
def loss(self):
return self.mean_loss
def prepare_input(self, data_path):
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
return data_dev

@property
def result(self):
return self.output
def batchify(self, data):
"""
1. Perform batching from data and produce a batch of training data.
2. Add padding.
:param data: list. Each entry is a sample, which is also a list of features and label(s).
E.g.
[
[[word_11, word_12, word_13], [label_11. label_12]], # sample 1
[[word_21, word_22, word_23], [label_21. label_22]], # sample 2
...
]
:return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
"""
indices = next(self.iterator)
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)
return batch_x, batch_y

@staticmethod
def make_output(batch_outputs):
# construct full prediction with batch outputs
return np.concatenate(batch_outputs, axis=0)
def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + [fill * (max_length - len(sample))]
return batch

def load_config(self, args):
def data_forward(self, network, data):
raise NotImplementedError

def load_dataset(self, args):
def evaluate(self, predict, truth):
raise NotImplementedError

@property
def matrices(self):
raise NotImplementedError

def mode(self, model, test=True):
"""To do: combine this function with Trainer"""
if test:
model.eval()
else:
model.train()


class POSTester(BaseTester):
"""
Tester for sequence labeling.
"""

def __init__(self, test_args):
super(POSTester, self).__init__(test_args)
self.max_len = None
self.mask = None

def data_forward(self, network, x):
"""To Do: combine with Trainer

:param network: the PyTorch model
:param x: list of list, [batch_size, max_len]
:return y: [batch_size, num_classes]
"""
seq_len = [len(seq) for seq in x]
x = torch.Tensor(x).long()
self.batch_size = x.size(0)
self.max_len = x.size(1)
self.mask = seq_mask(seq_len, self.max_len)
y = network(x)
return y

def evaluate(self, predict, truth):
"""To Do: """
return 0

+ 19
- 17
fastNLP/action/trainer.py View File

@@ -5,7 +5,7 @@ import torch

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
from fastNLP.action.tester import Tester
from fastNLP.action.tester import POSTester
from fastNLP.modules.utils import seq_mask


@@ -43,7 +43,7 @@ class BaseTrainer(Action):
self.optimizer = None

def train(self, network):
"""General training loop.
"""General Training Steps
:param network: a model

The method is framework independent.
@@ -57,23 +57,27 @@ class BaseTrainer(Action):
- update
Subclasses must implement these methods with a specific framework.
"""
# prepare model and data
self.model = network
data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)

test_args = Tester.TestConfig(save_output=True, validate_in_training=True,
save_dev_input=True, save_loss=True, batch_size=self.batch_size)
evaluator = Tester(test_args)
# define tester over dev data
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path}
validator = POSTester(valid_args)

best_loss = 1e10
# main training epochs
iterations = len(data_train) // self.batch_size

for epoch in range(self.n_epochs):

# turn on network training mode; define optimizer; prepare batch iterator
self.mode(test=False)
self.define_optimizer()
self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))

# training iterations in one epoch
for step in range(iterations):
batch_x, batch_y = self.batchify(self.batch_size, data_train)
batch_x, batch_y = self.batchify(data_train)

prediction = self.data_forward(network, batch_x)

@@ -84,9 +88,7 @@ class BaseTrainer(Action):
if self.validate:
if data_dev is None:
raise RuntimeError("No validation data provided.")
evaluator.test(network, data_dev)
if evaluator.loss < best_loss:
best_loss = evaluator.loss
validator.test(network)

# finish training

@@ -162,11 +164,10 @@ class BaseTrainer(Action):
"""
raise NotImplementedError

def batchify(self, batch_size, data):
def batchify(self, data):
"""
1. Perform batching from data and produce a batch of training data.
2. Add padding.
:param batch_size: int, the size of a batch
:param data: list. Each entry is a sample, which is also a list of features and label(s).
E.g.
[
@@ -200,7 +201,9 @@ class BaseTrainer(Action):


class ToyTrainer(BaseTrainer):
"""A simple trainer for a PyTorch model."""
"""
deprecated
"""

def __init__(self, train_args):
super(ToyTrainer, self).__init__(train_args)
@@ -235,7 +238,7 @@ class ToyTrainer(BaseTrainer):

class WordSegTrainer(BaseTrainer):
"""
reserve for changes
deprecated
"""

def __init__(self, train_args):
@@ -319,7 +322,6 @@ class WordSegTrainer(BaseTrainer):
self.optimizer.step()



class POSTrainer(BaseTrainer):
"""
Trainer for Sequence Modeling
@@ -391,4 +393,4 @@ if __name__ == "__name__":
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
trainer = BaseTrainer(train_args)
data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10]
trainer.batchify(batch_size=3, data=data_train)
trainer.batchify(data=data_train)

+ 1
- 1
test/test_POS_pipeline.py View File

@@ -23,7 +23,7 @@ if __name__ == "__main__":

# Trainer
train_args = {"epochs": 20, "batch_size": 1, "num_classes": num_classes,
"vocab_size": vocab_size, "pickle_path": pickle_path, "validate": False}
"vocab_size": vocab_size, "pickle_path": pickle_path, "validate": True}
trainer = POSTrainer(train_args)

# Model


Loading…
Cancel
Save