* 给callback添加给定几个只读属性 * 通过manager设置这些属性 * 代码优化,减轻@transfer的负担tags/v0.4.10
@@ -17,6 +17,38 @@ class Callback(object): | |||
super(Callback, self).__init__() | |||
self.trainer = None # 在Trainer内部被重新赋值 | |||
# callback只读属性 | |||
self._n_epochs = None | |||
self._n_steps = None | |||
self._batch_size = None | |||
self._model = None | |||
self._pbar = None | |||
self._optimizer = None | |||
@property | |||
def n_epochs(self): | |||
return self._n_epochs | |||
@property | |||
def n_steps(self): | |||
return self._n_steps | |||
@property | |||
def batch_size(self): | |||
return self._batch_size | |||
@property | |||
def model(self): | |||
return self._model | |||
@property | |||
def pbar(self): | |||
return self._pbar | |||
@property | |||
def optimizer(self): | |||
return self._optimizer | |||
def on_train_begin(self): | |||
# before the main training loop | |||
pass | |||
@@ -101,8 +133,6 @@ def transfer(func): | |||
def wrapper(manager, *arg): | |||
returns = [] | |||
for callback in manager.callbacks: | |||
for env_name, env_value in manager.env.items(): | |||
setattr(callback, env_name, env_value) | |||
returns.append(getattr(callback, func.__name__)(*arg)) | |||
return returns | |||
@@ -115,15 +145,15 @@ class CallbackManager(Callback): | |||
""" | |||
def __init__(self, env, callbacks=None): | |||
def __init__(self, env, attr, callbacks=None): | |||
""" | |||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | |||
:param dict attr: read-only attributes for all callbacks | |||
:param Callback callbacks: | |||
""" | |||
super(CallbackManager, self).__init__() | |||
# set attribute of trainer environment | |||
self.env = env | |||
self.callbacks = [] | |||
if callbacks is not None: | |||
@@ -136,6 +166,23 @@ class CallbackManager(Callback): | |||
else: | |||
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||
for env_name, env_val in env.items(): | |||
for callback in self.callbacks: | |||
setattr(callback, env_name, env_val) # Callback.trainer | |||
self.set_property(**attr) | |||
def set_property(self, **kwargs): | |||
"""设置所有callback的只读属性 | |||
:param kwargs: | |||
:return: | |||
""" | |||
for callback in self.callbacks: | |||
for k, v in kwargs.items(): | |||
setattr(callback, "_" + k, v) | |||
@transfer | |||
def on_train_begin(self): | |||
pass | |||
@@ -121,7 +121,6 @@ class Trainer(object): | |||
self.best_dev_perf = None | |||
self.sampler = sampler if sampler is not None else RandomSampler() | |||
self.prefetch = prefetch | |||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
@@ -144,6 +143,12 @@ class Trainer(object): | |||
self.step = 0 | |||
self.start_time = None # start timestamp | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
attr={"n_epochs": self.n_epochs, "n_steps": self.step, | |||
"batch_size": self.batch_size, "model": self.model, | |||
"optimizer": self.optimizer}, | |||
callbacks=callbacks) | |||
def train(self, load_best_model=True): | |||
""" | |||
@@ -236,6 +241,7 @@ class Trainer(object): | |||
avg_loss = 0 | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
prefetch=self.prefetch) | |||
self.callback_manager.set_property(pbar=pbar) | |||
for epoch in range(1, self.n_epochs+1): | |||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
# early stopping | |||
@@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[TensorboardCallback("loss", "metric")]) | |||
trainer.train() | |||
def test_readonly_property(self): | |||
from fastNLP.core.callback import Callback | |||
class MyCallback(Callback): | |||
def __init__(self): | |||
super(MyCallback, self).__init__() | |||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||
print(self.n_epochs, self.n_steps, self.batch_size) | |||
print(self.model) | |||
print(self.optimizer) | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=5, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[MyCallback()]) | |||
trainer.train() |