Browse Source

将tesorboardX处理为callback, 从trainer移除tensorboardX相关代码

tags/v0.3.1^2
FengZiYjun 6 years ago
parent
commit
f3cb812554
4 changed files with 99 additions and 36 deletions
  1. +69
    -2
      fastNLP/core/callback.py
  2. +4
    -27
      fastNLP/core/trainer.py
  3. +8
    -6
      test/core/test_batch.py
  4. +18
    -1
      test/core/test_callbacks.py

+ 69
- 2
fastNLP/core/callback.py View File

@@ -1,4 +1,7 @@
import os

import torch import torch
from tensorboardX import SummaryWriter


from fastNLP.io.model_io import ModelSaver, ModelLoader from fastNLP.io.model_io import ModelSaver, ModelLoader


@@ -12,6 +15,7 @@ class Callback(object):


def __init__(self): def __init__(self):
super(Callback, self).__init__() super(Callback, self).__init__()
self.trainer = None # 在Trainer内部被重新赋值


def before_train(self): def before_train(self):
# before the main training loop # before the main training loop
@@ -333,8 +337,6 @@ class SmoothValue(object):




class LRFinder(Callback): class LRFinder(Callback):
"""fastai lr_finder"""

def __init__(self, n_batch, start_lr=1e-6, end_lr=10): def __init__(self, n_batch, start_lr=1e-6, end_lr=10):
"""用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 """用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它


@@ -395,6 +397,71 @@ class LRFinder(Callback):
print("Model reset. \nFind best lr={}".format(self.best_lr)) print("Model reset. \nFind best lr={}".format(self.best_lr))




class TensorboardCallback(Callback):
"""
接受以下一个或多个字符串作为参数:
- "model"
- "loss"
- "metric"
"""

def __init__(self, *options):
super(TensorboardCallback, self).__init__()
args = {"model", "loss", "metric"}
for opt in options:
if opt not in args:
raise ValueError("Unrecognized argument {}. Expect one of {}".format(opt, args))
self.options = options
self._summary_writer = None
self.graph_added = False

def before_train(self):
save_dir = self.trainer.save_path
if save_dir is None:
path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
else:
path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
self._summary_writer = SummaryWriter(path)

def before_batch(self, batch_x, batch_y, indices):
if "model" in self.options and self.graph_added is False:
# tesorboardX 这里有大bug,暂时没法画模型图
# from fastNLP.core.utils import _build_args
# inputs = _build_args(self.trainer.model, **batch_x)
# args = tuple([value for value in inputs.values()])
# args = args[0] if len(args) == 1 else args
# self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
self.graph_added = True

def before_backward(self, loss, model):
if "loss" in self.options:
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)

if "model" in self.options:
for name, param in self.trainer.model.named_parameters():
if param.requires_grad:
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.trainer.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step)
self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
global_step=self.trainer.step)

def after_valid(self, eval_result, metric_key, optimizer):
if "metric" in self.options:
for name, metric in eval_result.items():
for metric_key, metric_val in metric.items():
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
global_step=self.trainer.step)

def after_train(self, model):
self._summary_writer.close()
del self._summary_writer

def on_exception(self, exception, model):
if hasattr(self, "_summary_writer"):
self._summary_writer.close()
del self._summary_writer


if __name__ == "__main__": if __name__ == "__main__":
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
manager.before_train(10, 11, 12) manager.before_train(10, 11, 12)


+ 4
- 27
fastNLP/core/trainer.py View File

@@ -5,7 +5,6 @@ from datetime import timedelta


import numpy as np import numpy as np
import torch import torch
from tensorboardX import SummaryWriter
from torch import nn from torch import nn


try: try:
@@ -195,21 +194,9 @@ class Trainer(object):
self._model_device = self.model.parameters().__next__().device self._model_device = self.model.parameters().__next__().device
self._mode(self.model, is_test=False) self._mode(self.model, is_test=False)


self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
start_time = time.time() start_time = time.time()
print("training epochs started " + self.start_time, flush=True) print("training epochs started " + self.start_time, flush=True)
if self.save_path is None:
class psudoSW:
def __getattr__(self, item):
def pass_func(*args, **kwargs):
pass

return pass_func

self._summary_writer = psudoSW()
else:
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
self._summary_writer = SummaryWriter(path)


try: try:
self.callback_manager.before_train() self.callback_manager.before_train()
@@ -232,8 +219,7 @@ class Trainer(object):
else: else:
print("Fail to reload best model.") print("Fail to reload best model.")
finally: finally:
self._summary_writer.close()
del self._summary_writer
pass
results['seconds'] = round(time.time() - start_time, 2) results['seconds'] = round(time.time() - start_time, 2)


return results return results
@@ -261,7 +247,7 @@ class Trainer(object):
# negative sampling; replace unknown; re-weight batch_y # negative sampling; replace unknown; re-weight batch_y
self.callback_manager.before_batch(batch_x, batch_y, indices) self.callback_manager.before_batch(batch_x, batch_y, indices)
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device, _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
non_blocking=self.pin_memory) # pin_memory, use non_blockling.
non_blocking=self.pin_memory) # pin_memory, use non_blocking.
prediction = self._data_forward(self.model, batch_x) prediction = self._data_forward(self.model, batch_x)


# edit prediction # edit prediction
@@ -279,12 +265,6 @@ class Trainer(object):
# lr scheduler; lr_finder; one_cycle # lr scheduler; lr_finder; one_cycle
self.callback_manager.after_step(self.optimizer) self.callback_manager.after_step(self.optimizer)


self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
if param.requires_grad:
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if (self.step+1) % self.print_every == 0: if (self.step+1) % self.print_every == 0:
if self.use_tqdm: if self.use_tqdm:
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
@@ -319,10 +299,7 @@ class Trainer(object):


def _do_validation(self, epoch, step): def _do_validation(self, epoch, step):
res = self.tester.test() res = self.tester.test()
for name, metric in res.items():
for metric_key, metric_val in metric.items():
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
global_step=self.step)

if self._better_eval_result(res): if self._better_eval_result(res):
if self.save_path is not None: if self.save_path is not None:
self._save_model(self.model, self._save_model(self.model,


+ 8
- 6
test/core/test_batch.py View File

@@ -1,3 +1,4 @@
import time
import unittest import unittest


import numpy as np import numpy as np
@@ -8,7 +9,7 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.dataset import construct_dataset from fastNLP.core.dataset import construct_dataset
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance
from fastNLP.core.sampler import SequentialSampler from fastNLP.core.sampler import SequentialSampler
import time


def generate_fake_dataset(num_samples=1000): def generate_fake_dataset(num_samples=1000):
""" """
@@ -161,12 +162,13 @@ class TestCase1(unittest.TestCase):
dataset = generate_fake_dataset(num_samples) dataset = generate_fake_dataset(num_samples)


batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True) batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True)
for batch_x, batch_y in batch:
time.sleep(pause_seconds)
# 这里发生OOM
# for batch_x, batch_y in batch:
# time.sleep(pause_seconds)


num_workers = 2 num_workers = 2
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers, batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers,
pin_memory=True) pin_memory=True)
for batch_x, batch_y in batch:
time.sleep(pause_seconds)
# 这里发生OOM
# for batch_x, batch_y in batch:
# time.sleep(pause_seconds)

+ 18
- 1
test/core/test_callbacks.py View File

@@ -3,7 +3,9 @@ import unittest
import numpy as np import numpy as np
import torch import torch


from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, LRFinder
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
LRFinder, \
TensorboardCallback
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss from fastNLP.core.losses import BCELoss
@@ -119,3 +121,18 @@ class TestCallback(unittest.TestCase):
use_tqdm=False, use_tqdm=False,
callbacks=[LRFinder(len(data_set) // 32)]) callbacks=[LRFinder(len(data_set) // 32)])
trainer.train() trainer.train()

def test_TensorboardCallback(self):
data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=5,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[TensorboardCallback("loss", "metric")])
trainer.train()

Loading…
Cancel
Save