Browse Source

update trainer, tester, example model

tags/v0.2.0
yunfan 6 years ago
parent
commit
f3bb3cb578
3 changed files with 49 additions and 37 deletions
  1. +17
    -13
      fastNLP/core/tester.py
  2. +29
    -21
      fastNLP/core/trainer.py
  3. +3
    -3
      fastNLP/models/cnn_text_classification.py

+ 17
- 13
fastNLP/core/tester.py View File

@@ -10,28 +10,32 @@ from fastNLP.core.utils import _build_args
class Tester(object):
"""A collection of model inference and evaluation of performance, used over validation/dev set and test set. """

def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs):
def __init__(self, data, model, batch_size, use_cuda, save_path="./save/", **kwargs):
super(Tester, self).__init__()

self.use_cuda = use_cuda
self.data = data
self.batch_size = batch_size
self.pickle_path = save_path
self.use_cuda = use_cuda
self._evaluator = evaluator

self._model = None
self.eval_history = [] # evaluation results of all batches

def test(self, network, dev_data):
if torch.cuda.is_available() and self.use_cuda:
self._model = network.cuda()
self._model = model.cuda()
else:
self._model = network
self._model = model
if hasattr(self._model, 'predict'):
assert callable(self._model.predict)
self._predict_func = self._model.predict
else:
self._predict_func = self._model
assert hasattr(model, 'evaluate')
self._evaluator = model.evaluate
self.eval_history = [] # evaluation results of all batches

def test(self):
# turn on the testing mode; clean up the history
network = self._model
self.mode(network, is_test=True)
self.eval_history.clear()
output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False)
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)

with torch.no_grad():
for batch_x, batch_y in data_iterator:
@@ -67,7 +71,7 @@ class Tester(object):
def data_forward(self, network, x):
"""A forward pass of the model. """
x = _build_args(network.forward, **x)
y = network(**x)
y = self._predict_func(**x)
return y

def print_eval_results(self, results):


+ 29
- 21
fastNLP/core/trainer.py View File

@@ -4,9 +4,10 @@ from datetime import datetime
import warnings
from collections import defaultdict
import os
import itertools
import shutil

from tensorboardX import SummaryWriter
import torch

from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
@@ -51,17 +52,18 @@ class Trainer(object):
self.evaluator = self.model.evaluate

if self.dev_data is not None:
valid_args = {"batch_size": self.batch_size, "save_path": self.save_path,
"use_cuda": self.use_cuda, "evaluator": self.evaluator}
self.tester = Tester(**valid_args)
self.tester = Tester(model=self.model,
data=self.dev_data,
batch_size=self.batch_size,
save_path=self.save_path,
use_cuda=self.use_cuda)

for k, v in kwargs.items():
setattr(self, k, v)

self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs')
if os.path.exists(self.tensorboard_path):
os.rmdir(self.tensorboard_path)
self._summary_writer = SummaryWriter(self.tensorboard_path)
shutil.rmtree(self.tensorboard_path)
self._graph_summaried = False
self.step = 0
self.start_time = None # start timestamp
@@ -73,26 +75,32 @@ class Trainer(object):

:return:
"""
if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda()
try:
self._summary_writer = SummaryWriter(self.tensorboard_path)

self.mode(self.model, is_test=False)
if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda()

start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)
self.mode(self.model, is_test=False)

epoch = 1
while epoch <= self.n_epochs:
start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
epoch = 1
while epoch <= self.n_epochs:

self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)

if self.dev_data:
self.do_validation()
self.save_model(self.model, 'training_model_' + self.start_time)
epoch += 1
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)

if self.dev_data:
self.do_validation()
self.save_model(self.model, 'training_model_' + self.start_time)
epoch += 1
finally:
self._summary_writer.close()
del self._summary_writer

def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
"""Training process in one epoch.
@@ -127,7 +135,7 @@ class Trainer(object):
self.step += 1

def do_validation(self):
res = self.tester.test(self.model, self.dev_data)
res = self.tester.test()
for name, num in res.items():
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
self.save_model(self.model, 'best_model_' + self.start_time)


+ 3
- 3
fastNLP/models/cnn_text_classification.py View File

@@ -48,16 +48,16 @@ class CNNText(torch.nn.Module):

def predict(self, word_seq):
output = self(word_seq)
_, predict = output.max(dim=1)
_, predict = output['output'].max(dim=1)
return {'predict': predict}

def get_loss(self, output, label_seq):
return self._loss(output, label_seq)

def evaluate(self, predict, label_seq):
predict, label_seq = torch.stack(predict, dim=0), torch.stack(label_seq, dim=0)
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0)
predict, label_seq = predict.squeeze(), label_seq.squeeze()
correct = (predict == label_seq).long().sum().item()
total = label_seq.size(0)
return 1.0 * correct / total
return {'acc': 1.0 * correct / total}


Loading…
Cancel
Save