Browse Source

update trainer, tester, example model

tags/v0.2.0
yunfan 6 years ago
parent
commit
f3bb3cb578
3 changed files with 49 additions and 37 deletions
  1. +17
    -13
      fastNLP/core/tester.py
  2. +29
    -21
      fastNLP/core/trainer.py
  3. +3
    -3
      fastNLP/models/cnn_text_classification.py

+ 17
- 13
fastNLP/core/tester.py View File

@@ -10,28 +10,32 @@ from fastNLP.core.utils import _build_args
class Tester(object): class Tester(object):
"""A collection of model inference and evaluation of performance, used over validation/dev set and test set. """ """A collection of model inference and evaluation of performance, used over validation/dev set and test set. """


def __init__(self, batch_size, evaluator, use_cuda, save_path="./save/", **kwargs):
def __init__(self, data, model, batch_size, use_cuda, save_path="./save/", **kwargs):
super(Tester, self).__init__() super(Tester, self).__init__()

self.use_cuda = use_cuda
self.data = data
self.batch_size = batch_size self.batch_size = batch_size
self.pickle_path = save_path self.pickle_path = save_path
self.use_cuda = use_cuda
self._evaluator = evaluator

self._model = None
self.eval_history = [] # evaluation results of all batches

def test(self, network, dev_data):
if torch.cuda.is_available() and self.use_cuda: if torch.cuda.is_available() and self.use_cuda:
self._model = network.cuda()
self._model = model.cuda()
else: else:
self._model = network
self._model = model
if hasattr(self._model, 'predict'):
assert callable(self._model.predict)
self._predict_func = self._model.predict
else:
self._predict_func = self._model
assert hasattr(model, 'evaluate')
self._evaluator = model.evaluate
self.eval_history = [] # evaluation results of all batches


def test(self):
# turn on the testing mode; clean up the history # turn on the testing mode; clean up the history
network = self._model
self.mode(network, is_test=True) self.mode(network, is_test=True)
self.eval_history.clear() self.eval_history.clear()
output, truths = defaultdict(list), defaultdict(list) output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), as_numpy=False)
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)


with torch.no_grad(): with torch.no_grad():
for batch_x, batch_y in data_iterator: for batch_x, batch_y in data_iterator:
@@ -67,7 +71,7 @@ class Tester(object):
def data_forward(self, network, x): def data_forward(self, network, x):
"""A forward pass of the model. """ """A forward pass of the model. """
x = _build_args(network.forward, **x) x = _build_args(network.forward, **x)
y = network(**x)
y = self._predict_func(**x)
return y return y


def print_eval_results(self, results): def print_eval_results(self, results):


+ 29
- 21
fastNLP/core/trainer.py View File

@@ -4,9 +4,10 @@ from datetime import datetime
import warnings import warnings
from collections import defaultdict from collections import defaultdict
import os import os
import itertools
import shutil


from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
import torch


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss from fastNLP.core.loss import Loss
@@ -51,17 +52,18 @@ class Trainer(object):
self.evaluator = self.model.evaluate self.evaluator = self.model.evaluate


if self.dev_data is not None: if self.dev_data is not None:
valid_args = {"batch_size": self.batch_size, "save_path": self.save_path,
"use_cuda": self.use_cuda, "evaluator": self.evaluator}
self.tester = Tester(**valid_args)
self.tester = Tester(model=self.model,
data=self.dev_data,
batch_size=self.batch_size,
save_path=self.save_path,
use_cuda=self.use_cuda)


for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(self, k, v) setattr(self, k, v)


self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs') self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs')
if os.path.exists(self.tensorboard_path): if os.path.exists(self.tensorboard_path):
os.rmdir(self.tensorboard_path)
self._summary_writer = SummaryWriter(self.tensorboard_path)
shutil.rmtree(self.tensorboard_path)
self._graph_summaried = False self._graph_summaried = False
self.step = 0 self.step = 0
self.start_time = None # start timestamp self.start_time = None # start timestamp
@@ -73,26 +75,32 @@ class Trainer(object):


:return: :return:
""" """
if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda()
try:
self._summary_writer = SummaryWriter(self.tensorboard_path)


self.mode(self.model, is_test=False)
if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda()


start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)
self.mode(self.model, is_test=False)


epoch = 1
while epoch <= self.n_epochs:
start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)


data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
epoch = 1
while epoch <= self.n_epochs:


self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)


if self.dev_data:
self.do_validation()
self.save_model(self.model, 'training_model_' + self.start_time)
epoch += 1
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)

if self.dev_data:
self.do_validation()
self.save_model(self.model, 'training_model_' + self.start_time)
epoch += 1
finally:
self._summary_writer.close()
del self._summary_writer


def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs): def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
"""Training process in one epoch. """Training process in one epoch.
@@ -127,7 +135,7 @@ class Trainer(object):
self.step += 1 self.step += 1


def do_validation(self): def do_validation(self):
res = self.tester.test(self.model, self.dev_data)
res = self.tester.test()
for name, num in res.items(): for name, num in res.items():
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
self.save_model(self.model, 'best_model_' + self.start_time) self.save_model(self.model, 'best_model_' + self.start_time)


+ 3
- 3
fastNLP/models/cnn_text_classification.py View File

@@ -48,16 +48,16 @@ class CNNText(torch.nn.Module):


def predict(self, word_seq): def predict(self, word_seq):
output = self(word_seq) output = self(word_seq)
_, predict = output.max(dim=1)
_, predict = output['output'].max(dim=1)
return {'predict': predict} return {'predict': predict}


def get_loss(self, output, label_seq): def get_loss(self, output, label_seq):
return self._loss(output, label_seq) return self._loss(output, label_seq)


def evaluate(self, predict, label_seq): def evaluate(self, predict, label_seq):
predict, label_seq = torch.stack(predict, dim=0), torch.stack(label_seq, dim=0)
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0)
predict, label_seq = predict.squeeze(), label_seq.squeeze() predict, label_seq = predict.squeeze(), label_seq.squeeze()
correct = (predict == label_seq).long().sum().item() correct = (predict == label_seq).long().sum().item()
total = label_seq.size(0) total = label_seq.size(0)
return 1.0 * correct / total
return {'acc': 1.0 * correct / total}



Loading…
Cancel
Save