
update trainer

tags/v0.2.0
yunfan 6 years ago
commit 1c34a0b732
2 changed files with 29 additions and 39 deletions:
  1. fastNLP/core/tester.py  (+1, -3)
  2. fastNLP/core/trainer.py (+28, -36)

fastNLP/core/tester.py (+1, -3)

@@ -10,12 +10,11 @@ from fastNLP.core.utils import _build_args
 class Tester(object):
     """A collection of model inference and evaluation of performance, used over validation/dev set and test set. """

-    def __init__(self, data, model, batch_size, use_cuda, save_path="./save/", **kwargs):
+    def __init__(self, data, model, batch_size=16, use_cuda=False):
         super(Tester, self).__init__()
         self.use_cuda = use_cuda
         self.data = data
         self.batch_size = batch_size
-        self.pickle_path = save_path
         if torch.cuda.is_available() and self.use_cuda:
             self._model = model.cuda()
         else:
@@ -53,7 +52,6 @@ class Tester(object):
         eval_results = self._evaluator(**args)
         print("[tester] {}".format(self.print_eval_results(eval_results)))
         self.mode(network, is_test=False)
-        self.metrics = eval_results
         return eval_results

     def mode(self, model, is_test=False):
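
Note: after this change Tester no longer accepts save_path or extra **kwargs, and it no longer keeps results on self.metrics. A minimal usage sketch of the new signature (model and dev_data stand for a fastNLP model and DataSet built elsewhere):

    from fastNLP.core.tester import Tester

    # batch_size and use_cuda now default to 16 and False; save_path is gone.
    tester = Tester(data=dev_data, model=model)
    eval_results = tester.test()  # the metrics dict is returned, no longer stored on the Tester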


fastNLP/core/trainer.py (+28, -36)

@@ -27,7 +27,7 @@ class Trainer(object):
     """
     def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
+                 optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), need_check_code=True,
                  **kwargs):
         super(Trainer, self).__init__()
@@ -37,9 +37,13 @@
         self.n_epochs = int(n_epochs)
         self.batch_size = int(batch_size)
         self.use_cuda = bool(use_cuda)
-        self.save_path = str(save_path)
+        self.save_path = save_path
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
         self._best_accuracy = 0
+
+        if need_check_code:
+            _check_code(dataset=train_data, model=model, dev_data=dev_data)
+
         model_name = model.__class__.__name__
         assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
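
Note: need_check_code defaults to True, so every Trainer now dry-runs the model on a couple of tiny batches (via _check_code) before real training starts. A sketch of opting out, assuming train_data, model, and dev_data are built elsewhere:

    from fastNLP.core.trainer import Trainer

    # Skip the pre-flight forward/backward check, e.g. for a model already known to be wired correctly.
    trainer = Trainer(train_data=train_data, model=model, dev_data=dev_data,
                      need_check_code=False)
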
@@ -56,16 +60,11 @@
         self.tester = Tester(model=self.model,
                              data=self.dev_data,
                              batch_size=self.batch_size,
-                             save_path=self.save_path,
                              use_cuda=self.use_cuda)

         for k, v in kwargs.items():
             setattr(self, k, v)

-        self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs')
-        if os.path.exists(self.tensorboard_path):
-            shutil.rmtree(self.tensorboard_path)
-        self._graph_summaried = False
         self.step = 0
         self.start_time = None  # start timestamp

@@ -77,8 +76,6 @@
         :return:
         """
         try:
-            self._summary_writer = SummaryWriter(self.tensorboard_path)
-
             if torch.cuda.is_available() and self.use_cuda:
                 self.model = self.model.cuda()
@@ -87,6 +84,9 @@
            start = time.time()
            self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
            print("training epochs started " + self.start_time)
+           if self.save_path is not None:
+               path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
+               self._summary_writer = SummaryWriter(path)

            epoch = 1
            while epoch <= self.n_epochs:
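
Note: the SummaryWriter is now created inside train() rather than __init__, and only when save_path is set. Each run logs to its own timestamped directory, which is why the old shutil.rmtree cleanup could be dropped. A standalone sketch of the path construction ('./save' is the constructor default):

    import os
    from datetime import datetime

    start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
    # e.g. './save/tensorboard_logs_2018-12-07-10-00-00' -- one directory per run
    log_dir = os.path.join('./save', 'tensorboard_logs_{}'.format(start_time))
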
@@ -143,7 +143,8 @@
                res = self.tester.test()
                for name, num in res.items():
                    self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
-               self.save_model(self.model, 'best_model_' + self.start_time)
+               if self.save_path is not None and self.best_eval_result(res):
+                   self.save_model(self.model, 'best_model_' + self.start_time)

    def mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.
@@ -166,9 +167,6 @@
    def data_forward(self, network, x):
        x = _build_args(network.forward, **x)
        y = network(**x)
-       if not self._graph_summaried:
-           # self._summary_writer.add_graph(network, x, verbose=False)
-           self._graph_summaried = True
        return y

    def grad_backward(self, loss):
@@ -199,28 +197,27 @@
        else:
            torch.save(model, model_name)

-       def best_eval_result(self, metrics):
-           """Check if the current epoch yields better validation results.
-
-           :return: bool, True means current results on dev set is the best.
-           """
-           if isinstance(metrics, tuple):
-               loss, metrics = metrics
-
-           if isinstance(metrics, dict):
-               if len(metrics) == 1:
-                   accuracy = list(metrics.values())[0]
-               else:
-                   accuracy = metrics[self.eval_sort_key]
-           else:
-               accuracy = metrics
-
-           if accuracy > self._best_accuracy:
-               self._best_accuracy = accuracy
-               return True
-           else:
-               return False
-
+   def best_eval_result(self, metrics):
+       """Check if the current epoch yields better validation results.
+
+       :return: bool, True means current results on dev set is the best.
+       """
+       if isinstance(metrics, tuple):
+           loss, metrics = metrics
+
+       if isinstance(metrics, dict):
+           if len(metrics) == 1:
+               accuracy = list(metrics.values())[0]
+           else:
+               accuracy = metrics[self.eval_sort_key]
+       else:
+           accuracy = metrics
+
+       if accuracy > self._best_accuracy:
+           self._best_accuracy = accuracy
+           return True
+       else:
+           return False


DEFAULT_CHECK_BATCH_SIZE = 2
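
Note: best_eval_result reduces whatever the evaluator returns to a single comparable scalar: it unwraps a (loss, metrics) tuple, takes the sole value of a one-entry dict, indexes a larger dict with self.eval_sort_key (presumably set through Trainer's **kwargs), or uses a bare number as-is. A standalone sketch of that selection logic:

    def pick_scalar(metrics, eval_sort_key='acc'):
        # Mirrors best_eval_result; eval_sort_key stands in for self.eval_sort_key.
        if isinstance(metrics, tuple):
            _, metrics = metrics  # discard the loss half of (loss, metrics)
        if isinstance(metrics, dict):
            if len(metrics) == 1:
                return list(metrics.values())[0]
            return metrics[eval_sort_key]
        return metrics

    assert pick_scalar({'acc': 0.9}) == 0.9
    assert pick_scalar({'acc': 0.9, 'f1': 0.8}, eval_sort_key='f1') == 0.8
    assert pick_scalar((0.35, {'acc': 0.9})) == 0.9
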
@@ -268,9 +265,6 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None
        loss.backward()
        if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
            break
-   if check_level > IGNORE_CHECK_LEVEL:
-       print('Finish checking training process.', flush=True)
-

    if dev_data is not None:
        if not hasattr(model, 'evaluate'):
@@ -310,8 +304,6 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None
        func_signature = get_func_signature(model.evaluate)
        assert isinstance(metrics, dict), "The return value of {} should be dict.". \
            format(func_signature)
-       if check_level > IGNORE_CHECK_LEVEL:
-           print("Finish checking evaluate process.", flush=True)


def _check_forward_error(model_func, check_level, batch_x):

