@@ -2,7 +2,6 @@ import _pickle
import os
import time
from datetime import timedelta
from time import time
import numpy as np
import torch
@@ -12,9 +11,11 @@ from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.modules import utils
from fastNLP.saver.logger import create_logger
from fastNLP.saver.model_saver import ModelSaver
DEFAULT_QUEUE_SIZE = 300
logger = create_logger(__name__, "./train_test.log")
class BaseTrainer(object):
@@ -73,6 +74,7 @@ class BaseTrainer(object):
self.model = network
data_train = self.load_train_data(self.pickle_path)
logger.info("training data loaded")
# define tester over dev data
if self.validate:
@@ -80,33 +82,42 @@ class BaseTrainer(object):
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda}
validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(validator)))
self.define_optimizer()
logger.info("optimizer defined as {}".format(str(self.optimizer)))
# main training epochs
start = time.time()
n_samples = len(data_train)
n_batches = n_samples // self.batch_size
n_print = 1
start = time.time()
logger.info("training epochs started")
for epoch in range(1, self.n_epochs + 1):
logger.info("training epoch {}".format(epoch))
# turn on network training mode
self.mode(network, test=False)
# prepare mini-batch iterator
data_iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))
logger.info("prepared data iterator")
self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch)
if self.validate:
logger.info("validation started")
validator.test(network)
if self.save_best_dev and self.best_eval_result(validator):
self.save_model(network)
print("saved better model selected by dev")
logger.info("saved better model selected by dev")
print("[epoch {}]".format(epoch), end=" ")
print(validator.show_matrices())
valid_results = validator.show_matrices()
print("[epoch {}] {}".format(epoch, valid_results))
logger.info("[epoch {}] {}".format(epoch, valid_results))
def _train_step(self, data_iterator, network, **kwargs):
"""Training process in one epoch."""
@@ -122,8 +133,10 @@ class BaseTrainer(object):
if step % kwargs["n_print"] == 0:
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
kwargs["epoch"], step, loss.data, diff))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
kwargs["epoch"], step, loss.data, diff)
print(print_output)
logger.info(print_output)
step += 1
def load_train_data(self, pickle_path):
@@ -137,6 +150,7 @@ class BaseTrainer(object):
with open(file_path, 'rb') as f:
data = _pickle.load(f)
else:
logger.error("cannot find training data {}. invalid input path for training data.".format(file_path))
raise RuntimeError("cannot find training data {}".format(file_path))
return data
@@ -182,7 +196,9 @@ class BaseTrainer(object):
if self.loss_func is None:
if hasattr(self.model, "loss"):
self.loss_func = self.model.loss
logger.info("The model has a loss function, use it.")
else:
logger.info("The model didn't define loss, use Trainer's loss.")
self.define_loss()
return self.loss_func(predict, truth)