@@ -2,7 +2,6 @@ import _pickle
import os
import os
import time
import time
from datetime import timedelta
from datetime import timedelta
from time import time
import numpy as np
import numpy as np
import torch
import torch
@@ -12,9 +11,11 @@ from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.modules import utils
from fastNLP.modules import utils
from fastNLP.saver.logger import create_logger
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.saver.model_saver import ModelSaver
DEFAULT_QUEUE_SIZE = 300
DEFAULT_QUEUE_SIZE = 300
logger = create_logger(__name__, "./train_test.log")
class BaseTrainer(object):
class BaseTrainer(object):
@@ -73,6 +74,7 @@ class BaseTrainer(object):
self.model = network
self.model = network
data_train = self.load_train_data(self.pickle_path)
data_train = self.load_train_data(self.pickle_path)
logger.info("training data loaded")
# define tester over dev data
# define tester over dev data
if self.validate:
if self.validate:
@@ -80,33 +82,42 @@ class BaseTrainer(object):
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda}
"use_cuda": self.use_cuda}
validator = self._create_validator(default_valid_args)
validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(validator)))
self.define_optimizer()
self.define_optimizer()
logger.info("optimizer defined as {}".format(str(self.optimizer)))
# main training epochs
# main training epochs
start = time.time()
n_samples = len(data_train)
n_samples = len(data_train)
n_batches = n_samples // self.batch_size
n_batches = n_samples // self.batch_size
n_print = 1
n_print = 1
start = time.time()
logger.info("training epochs started")
for epoch in range(1, self.n_epochs + 1):
for epoch in range(1, self.n_epochs + 1):
logger.info("training epoch {}".format(epoch))
# turn on network training mode
# turn on network training mode
self.mode(network, test=False)
self.mode(network, test=False)
# prepare mini-batch iterator
# prepare mini-batch iterator
data_iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))
data_iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))
logger.info("prepared data iterator")
self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch)
self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch)
if self.validate:
if self.validate:
logger.info("validation started")
validator.test(network)
validator.test(network)
if self.save_best_dev and self.best_eval_result(validator):
if self.save_best_dev and self.best_eval_result(validator):
self.save_model(network)
self.save_model(network)
print("saved better model selected by dev")
print("saved better model selected by dev")
logger.info("saved better model selected by dev")
print("[epoch {}]".format(epoch), end=" ")
print(validator.show_matrices())
valid_results = validator.show_matrices()
print("[epoch {}] {}".format(epoch, valid_results))
logger.info("[epoch {}] {}".format(epoch, valid_results))
def _train_step(self, data_iterator, network, **kwargs):
def _train_step(self, data_iterator, network, **kwargs):
"""Training process in one epoch."""
"""Training process in one epoch."""
@@ -122,8 +133,10 @@ class BaseTrainer(object):
if step % kwargs["n_print"] == 0:
if step % kwargs["n_print"] == 0:
end = time.time()
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
diff = timedelta(seconds=round(end - kwargs["start"]))
print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
kwargs["epoch"], step, loss.data, diff))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
kwargs["epoch"], step, loss.data, diff)
print(print_output)
logger.info(print_output)
step += 1
step += 1
def load_train_data(self, pickle_path):
def load_train_data(self, pickle_path):
@@ -137,6 +150,7 @@ class BaseTrainer(object):
with open(file_path, 'rb') as f:
with open(file_path, 'rb') as f:
data = _pickle.load(f)
data = _pickle.load(f)
else:
else:
logger.error("cannot find training data {}. invalid input path for training data.".format(file_path))
raise RuntimeError("cannot find training data {}".format(file_path))
raise RuntimeError("cannot find training data {}".format(file_path))
return data
return data
@@ -182,7 +196,9 @@ class BaseTrainer(object):
if self.loss_func is None:
if self.loss_func is None:
if hasattr(self.model, "loss"):
if hasattr(self.model, "loss"):
self.loss_func = self.model.loss
self.loss_func = self.model.loss
logger.info("The model has a loss function, use it.")
else:
else:
logger.info("The model didn't define loss, use Trainer's loss.")
self.define_loss()
self.define_loss()
return self.loss_func(predict, truth)
return self.loss_func(predict, truth)