
Changes to Callbacks:

* Add a set of read-only properties to Callback
* Set these properties through the CallbackManager
* Code cleanup: lighten the work done inside @transfer (a condensed sketch follows the changed-file list)
tags/v0.4.10
FengZiYjun 5 years ago
commit ef0c6e936d
3 changed files with 83 additions and 5 deletions
  1. fastNLP/core/callback.py (+51, -4)
  2. fastNLP/core/trainer.py (+7, -1)
  3. test/core/test_callbacks.py (+25, -0)

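Taken together, the bullets describe one refactor: the loop that copied the Trainer environment onto every callback used to run inside @transfer on every single hook dispatch; after this commit it runs exactly once, in CallbackManager.__init__, and @transfer only forwards the call. A condensed, standalone restatement of the post-change decorator (based on the callback.py hunks below; not a verbatim copy of the library source):

def transfer(func):
    # Forward a CallbackManager-level hook call to every registered callback.
    def wrapper(manager, *args):
        returns = []
        for callback in manager.callbacks:
            # The loop that re-applied manager.env to each callback used to sit here;
            # after this commit it runs once, when CallbackManager is constructed.
            returns.append(getattr(callback, func.__name__)(*args))
        return returns
    return wrapper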
fastNLP/core/callback.py (+51, -4)

@@ -17,6 +17,38 @@ class Callback(object):
super(Callback, self).__init__()
self.trainer = None # reassigned inside the Trainer

# read-only callback properties
self._n_epochs = None
self._n_steps = None
self._batch_size = None
self._model = None
self._pbar = None
self._optimizer = None

@property
def n_epochs(self):
return self._n_epochs

@property
def n_steps(self):
return self._n_steps

@property
def batch_size(self):
return self._batch_size

@property
def model(self):
return self._model

@property
def pbar(self):
return self._pbar

@property
def optimizer(self):
return self._optimizer

def on_train_begin(self):
# before the main training loop
pass
@@ -101,8 +133,6 @@ def transfer(func):
def wrapper(manager, *arg):
returns = []
for callback in manager.callbacks:
for env_name, env_value in manager.env.items():
setattr(callback, env_name, env_value)
returns.append(getattr(callback, func.__name__)(*arg))
return returns

@@ -115,15 +145,15 @@ class CallbackManager(Callback):

"""

def __init__(self, env, callbacks=None):
def __init__(self, env, attr, callbacks=None):
"""

:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself.
:param dict attr: read-only attributes for all callbacks
:param Callback callbacks:
"""
super(CallbackManager, self).__init__()
# set attribute of trainer environment
self.env = env

self.callbacks = []
if callbacks is not None:
@@ -136,6 +166,23 @@ class CallbackManager(Callback):
else:
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")

for env_name, env_val in env.items():
for callback in self.callbacks:
setattr(callback, env_name, env_val) # Callback.trainer

self.set_property(**attr)

def set_property(self, **kwargs):
"""设置所有callback的只读属性

:param kwargs:
:return:
"""
for callback in self.callbacks:
for k, v in kwargs.items():
setattr(callback, "_" + k, v)


@transfer
def on_train_begin(self):
pass

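The new attributes are exposed through properties with no setters, so a user-defined callback can read them but cannot overwrite them; only CallbackManager writes the underscore-prefixed backing fields through set_property. A minimal standalone sketch of that pattern (illustrative only; the class below is not part of fastNLP):

class ReadOnlyDemo(object):
    def __init__(self):
        self._n_epochs = None   # backing field, written only by the manager

    @property
    def n_epochs(self):         # public, read-only view
        return self._n_epochs

demo = ReadOnlyDemo()
setattr(demo, "_n_epochs", 10)  # what CallbackManager.set_property effectively does
print(demo.n_epochs)            # -> 10
try:
    demo.n_epochs = 3           # rejected: the property defines no setter
except AttributeError:
    print("n_epochs is read-only")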

fastNLP/core/trainer.py (+7, -1)

@@ -121,7 +121,6 @@ class Trainer(object):
self.best_dev_perf = None
self.sampler = sampler if sampler is not None else RandomSampler()
self.prefetch = prefetch
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)

if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
@@ -144,6 +143,12 @@ class Trainer(object):
self.step = 0
self.start_time = None # start timestamp

self.callback_manager = CallbackManager(env={"trainer": self},
attr={"n_epochs": self.n_epochs, "n_steps": self.step,
"batch_size": self.batch_size, "model": self.model,
"optimizer": self.optimizer},
callbacks=callbacks)

def train(self, load_best_model=True):
"""

@@ -236,6 +241,7 @@ class Trainer(object):
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
self.callback_manager.set_property(pbar=pbar)
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
# early stopping


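One property cannot be filled in when the Trainer is constructed: the tqdm progress bar only exists once training has actually started, which is why the diff above calls set_property(pbar=pbar) right before the epoch loop instead of passing it through attr. A self-contained sketch of that two-phase flow (DemoCallback and DemoManager are stand-ins, not fastNLP classes):

class DemoCallback(object):
    def __init__(self):
        self._pbar = None

    @property
    def pbar(self):
        return self._pbar

class DemoManager(object):
    def __init__(self, callbacks):
        self.callbacks = callbacks

    def set_property(self, **kwargs):
        # mirrors CallbackManager.set_property: write the private backing fields
        for callback in self.callbacks:
            for k, v in kwargs.items():
                setattr(callback, "_" + k, v)

cb = DemoCallback()
manager = DemoManager([cb])
print(cb.pbar)                         # None: no bar exists at construction time
manager.set_property(pbar="tqdm-bar")  # done inside train(), once the bar is created
print(cb.pbar)                         # callbacks can now drive the progress bar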
test/core/test_callbacks.py (+25, -0)

@@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[TensorboardCallback("loss", "metric")])
trainer.train()

def test_readonly_property(self):
from fastNLP.core.callback import Callback
class MyCallback(Callback):
def __init__(self):
super(MyCallback, self).__init__()

def on_epoch_begin(self, cur_epoch, total_epoch):
print(self.n_epochs, self.n_steps, self.batch_size)
print(self.model)
print(self.optimizer)

data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=5,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[MyCallback()])
trainer.train()
