@@ -1,4 +1,7 @@
+import os
 import torch
+from tensorboardX import SummaryWriter
 from fastNLP.io.model_io import ModelSaver, ModelLoader
@@ -12,6 +15,7 @@ class Callback(object):
     def __init__(self):
         super(Callback, self).__init__()
+        self.trainer = None  # reassigned inside the Trainer

     def before_train(self):
         # before the main training loop
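
The `self.trainer` attribute added above is the handle through which every callback in this diff reaches trainer state (`save_path`, `start_time`, `step`, `model`). A minimal sketch of the wiring that the comment describes, with illustrative names rather than the actual CallbackManager code:

    # Illustrative only: the Trainer hands itself to each registered callback,
    # so hooks such as before_train() can read trainer attributes.
    for callback in callbacks:
        callback.trainer = trainer  # "reassigned inside the Trainer"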
@@ -333,8 +337,6 @@ class SmoothValue(object):

 class LRFinder(Callback):
-    """fastai lr_finder"""
     def __init__(self, n_batch, start_lr=1e-6, end_lr=10):
         """Use the first epoch to find the best learning rate, and apply it from the second epoch on.
@@ -395,6 +397,71 @@ class LRFinder(Callback):
         print("Model reset. \nFind best lr={}".format(self.best_lr))
+
+
+class TensorboardCallback(Callback):
+    """
+    Accepts one or more of the following strings as arguments:
+        - "model"
+        - "loss"
+        - "metric"
+    """
+
+    def __init__(self, *options):
+        super(TensorboardCallback, self).__init__()
+        args = {"model", "loss", "metric"}
+        for opt in options:
+            if opt not in args:
+                raise ValueError("Unrecognized argument {}. Expect one of {}".format(opt, args))
+        self.options = options
+        self._summary_writer = None
+        self.graph_added = False
+
+    def before_train(self):
+        save_dir = self.trainer.save_path
+        if save_dir is None:
+            path = os.path.join("./", 'tensorboard_logs_{}'.format(self.trainer.start_time))
+        else:
+            path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
+        self._summary_writer = SummaryWriter(path)
+
+    def before_batch(self, batch_x, batch_y, indices):
+        if "model" in self.options and self.graph_added is False:
+            # tensorboardX has a serious bug here, so drawing the model graph is disabled for now
+            # from fastNLP.core.utils import _build_args
+            # inputs = _build_args(self.trainer.model, **batch_x)
+            # args = tuple([value for value in inputs.values()])
+            # args = args[0] if len(args) == 1 else args
+            # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
+            self.graph_added = True
+
+    def before_backward(self, loss, model):
+        if "loss" in self.options:
+            self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)
+        if "model" in self.options:
+            for name, param in self.trainer.model.named_parameters():
+                if param.requires_grad:
+                    self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.trainer.step)
+                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.trainer.step)
+                    self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
+                                                    global_step=self.trainer.step)
+
+    def after_valid(self, eval_result, metric_key, optimizer):
+        if "metric" in self.options:
+            for name, metric in eval_result.items():
+                for metric_key, metric_val in metric.items():
+                    self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
+                                                    global_step=self.trainer.step)
+
+    def after_train(self, model):
+        self._summary_writer.close()
+        del self._summary_writer
+
+    def on_exception(self, exception, model):
+        if hasattr(self, "_summary_writer"):
+            self._summary_writer.close()
+            del self._summary_writer
+
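
The class above takes over the logging that the Trainer hunks below delete: "loss" re-creates the per-step loss scalar, "model" the parameter and gradient means, and "metric" the validation scalars. A minimal usage sketch, modeled on test_TensorboardCallback at the end of this diff:

    trainer = Trainer(data_set, model,
                      loss=BCELoss(pred="predict", target="y"),
                      dev_data=data_set,
                      metrics=AccuracyMetric(pred="predict", target="y"),
                      callbacks=[TensorboardCallback("loss", "metric")])
    trainer.train()
    # Event files land in <save_path>/tensorboard_logs_<start_time>/
    # (or ./tensorboard_logs_<start_time>/ when save_path is None) and can be
    # inspected with: tensorboard --logdir <that directory>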
 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
     manager.before_train(10, 11, 12)
@@ -5,7 +5,6 @@ from datetime import timedelta

 import numpy as np
 import torch
-from tensorboardX import SummaryWriter
 from torch import nn

 try:
@@ -195,21 +194,9 @@ class Trainer(object):
         self._model_device = self.model.parameters().__next__().device
         self._mode(self.model, is_test=False)

-        self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
+        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
         start_time = time.time()
         print("training epochs started " + self.start_time, flush=True)
-        if self.save_path is None:
-            class psudoSW:
-                def __getattr__(self, item):
-                    def pass_func(*args, **kwargs):
-                        pass
-
-                    return pass_func
-
-            self._summary_writer = psudoSW()
-        else:
-            path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
-            self._summary_writer = SummaryWriter(path)

         try:
             self.callback_manager.before_train()
@@ -232,8 +219,7 @@ class Trainer(object):
                 else:
                     print("Fail to reload best model.")
         finally:
-            self._summary_writer.close()
-            del self._summary_writer
+            pass

         results['seconds'] = round(time.time() - start_time, 2)
         return results
@@ -261,7 +247,7 @@ class Trainer(object):
                 # negative sampling; replace unknown; re-weight batch_y
                 self.callback_manager.before_batch(batch_x, batch_y, indices)
                 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
-                                           non_blocking=self.pin_memory)  # pin_memory, use non_blockling.
+                                           non_blocking=self.pin_memory)  # pin_memory, use non_blocking.
                 prediction = self._data_forward(self.model, batch_x)

                 # edit prediction
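
Beyond the comment-typo fix, the `non_blocking=self.pin_memory` coupling above is deliberate: `non_blocking=True` only pays off when the source tensors sit in pinned (page-locked) host memory. A sketch of the general PyTorch pattern that `_move_dict_value_to_device` applies per tensor (the standard idiom, not the fastNLP implementation):

    # With pinned host memory, the host-to-device copy can overlap with
    # computation already queued on the GPU.
    batch_gpu = {name: t.to(device, non_blocking=True) for name, t in batch_cpu.items()}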
@@ -279,12 +265,6 @@ class Trainer(object):
                 # lr scheduler; lr_finder; one_cycle
                 self.callback_manager.after_step(self.optimizer)

-                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
-                for name, param in self.model.named_parameters():
-                    if param.requires_grad:
-                        self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
-                        # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
-                        # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                 if (self.step+1) % self.print_every == 0:
                     if self.use_tqdm:
                         print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
@@ -319,10 +299,7 @@ class Trainer(object):
     def _do_validation(self, epoch, step):
         res = self.tester.test()
-        for name, metric in res.items():
-            for metric_key, metric_val in metric.items():
-                self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
-                                                global_step=self.step)
         if self._better_eval_result(res):
             if self.save_path is not None:
                 self._save_model(self.model,
@@ -1,3 +1,4 @@
+import time
 import unittest

 import numpy as np
@@ -8,7 +9,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.dataset import construct_dataset
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
-import time


 def generate_fake_dataset(num_samples=1000):
     """
@@ -161,12 +162,13 @@ class TestCase1(unittest.TestCase):
         dataset = generate_fake_dataset(num_samples)

         batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True)
-        for batch_x, batch_y in batch:
-            time.sleep(pause_seconds)
+        # OOM happens here
+        # for batch_x, batch_y in batch:
+        #     time.sleep(pause_seconds)

         num_workers = 2
         batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers,
                       pin_memory=True)
-        for batch_x, batch_y in batch:
-            time.sleep(pause_seconds)
+        # OOM happens here
+        # for batch_x, batch_y in batch:
+        #     time.sleep(pause_seconds)
@@ -3,7 +3,9 @@ import unittest
 import numpy as np
 import torch

-from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, LRFinder
+from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
+    LRFinder, \
+    TensorboardCallback
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
@@ -119,3 +121,18 @@ class TestCallback(unittest.TestCase):
                           use_tqdm=False,
                           callbacks=[LRFinder(len(data_set) // 32)])
         trainer.train()
+
+    def test_TensorboardCallback(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=5,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[TensorboardCallback("loss", "metric")])
+        trainer.train()
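
The new test drives the writer through a full Trainer run, exercising before_train, before_backward, after_valid, and after_train. Assuming the repository's usual test layout (the path below is a guess), it can be run in isolation with:

    python -m pytest test/core/test_callback.py::TestCallback::test_TensorboardCallback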