* 给callback添加给定几个只读属性 * 通过manager设置这些属性 * 代码优化,减轻@transfer的负担tags/v0.4.10
@@ -17,6 +17,38 @@ class Callback(object): | |||||
super(Callback, self).__init__() | super(Callback, self).__init__() | ||||
self.trainer = None # 在Trainer内部被重新赋值 | self.trainer = None # 在Trainer内部被重新赋值 | ||||
# callback只读属性 | |||||
self._n_epochs = None | |||||
self._n_steps = None | |||||
self._batch_size = None | |||||
self._model = None | |||||
self._pbar = None | |||||
self._optimizer = None | |||||
@property | |||||
def n_epochs(self): | |||||
return self._n_epochs | |||||
@property | |||||
def n_steps(self): | |||||
return self._n_steps | |||||
@property | |||||
def batch_size(self): | |||||
return self._batch_size | |||||
@property | |||||
def model(self): | |||||
return self._model | |||||
@property | |||||
def pbar(self): | |||||
return self._pbar | |||||
@property | |||||
def optimizer(self): | |||||
return self._optimizer | |||||
def on_train_begin(self): | def on_train_begin(self): | ||||
# before the main training loop | # before the main training loop | ||||
pass | pass | ||||
@@ -101,8 +133,6 @@ def transfer(func): | |||||
def wrapper(manager, *arg): | def wrapper(manager, *arg): | ||||
returns = [] | returns = [] | ||||
for callback in manager.callbacks: | for callback in manager.callbacks: | ||||
for env_name, env_value in manager.env.items(): | |||||
setattr(callback, env_name, env_value) | |||||
returns.append(getattr(callback, func.__name__)(*arg)) | returns.append(getattr(callback, func.__name__)(*arg)) | ||||
return returns | return returns | ||||
@@ -115,15 +145,15 @@ class CallbackManager(Callback): | |||||
""" | """ | ||||
def __init__(self, env, callbacks=None): | |||||
def __init__(self, env, attr, callbacks=None): | |||||
""" | """ | ||||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | ||||
:param dict attr: read-only attributes for all callbacks | |||||
:param Callback callbacks: | :param Callback callbacks: | ||||
""" | """ | ||||
super(CallbackManager, self).__init__() | super(CallbackManager, self).__init__() | ||||
# set attribute of trainer environment | # set attribute of trainer environment | ||||
self.env = env | |||||
self.callbacks = [] | self.callbacks = [] | ||||
if callbacks is not None: | if callbacks is not None: | ||||
@@ -136,6 +166,23 @@ class CallbackManager(Callback): | |||||
else: | else: | ||||
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | ||||
for env_name, env_val in env.items(): | |||||
for callback in self.callbacks: | |||||
setattr(callback, env_name, env_val) # Callback.trainer | |||||
self.set_property(**attr) | |||||
def set_property(self, **kwargs): | |||||
"""设置所有callback的只读属性 | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
for callback in self.callbacks: | |||||
for k, v in kwargs.items(): | |||||
setattr(callback, "_" + k, v) | |||||
@transfer | @transfer | ||||
def on_train_begin(self): | def on_train_begin(self): | ||||
pass | pass | ||||
@@ -121,7 +121,6 @@ class Trainer(object): | |||||
self.best_dev_perf = None | self.best_dev_perf = None | ||||
self.sampler = sampler if sampler is not None else RandomSampler() | self.sampler = sampler if sampler is not None else RandomSampler() | ||||
self.prefetch = prefetch | self.prefetch = prefetch | ||||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
@@ -144,6 +143,12 @@ class Trainer(object): | |||||
self.step = 0 | self.step = 0 | ||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||||
attr={"n_epochs": self.n_epochs, "n_steps": self.step, | |||||
"batch_size": self.batch_size, "model": self.model, | |||||
"optimizer": self.optimizer}, | |||||
callbacks=callbacks) | |||||
def train(self, load_best_model=True): | def train(self, load_best_model=True): | ||||
""" | """ | ||||
@@ -236,6 +241,7 @@ class Trainer(object): | |||||
avg_loss = 0 | avg_loss = 0 | ||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | ||||
prefetch=self.prefetch) | prefetch=self.prefetch) | ||||
self.callback_manager.set_property(pbar=pbar) | |||||
for epoch in range(1, self.n_epochs+1): | for epoch in range(1, self.n_epochs+1): | ||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | ||||
# early stopping | # early stopping | ||||
@@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[TensorboardCallback("loss", "metric")]) | callbacks=[TensorboardCallback("loss", "metric")]) | ||||
trainer.train() | trainer.train() | ||||
def test_readonly_property(self): | |||||
from fastNLP.core.callback import Callback | |||||
class MyCallback(Callback): | |||||
def __init__(self): | |||||
super(MyCallback, self).__init__() | |||||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||||
print(self.n_epochs, self.n_steps, self.batch_size) | |||||
print(self.model) | |||||
print(self.optimizer) | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[MyCallback()]) | |||||
trainer.train() |