diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index d941c235..e3b4f36e 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -17,6 +17,38 @@ class Callback(object): super(Callback, self).__init__() self.trainer = None # 在Trainer内部被重新赋值 + # callback只读属性 + self._n_epochs = None + self._n_steps = None + self._batch_size = None + self._model = None + self._pbar = None + self._optimizer = None + + @property + def n_epochs(self): + return self._n_epochs + + @property + def n_steps(self): + return self._n_steps + + @property + def batch_size(self): + return self._batch_size + + @property + def model(self): + return self._model + + @property + def pbar(self): + return self._pbar + + @property + def optimizer(self): + return self._optimizer + def on_train_begin(self): # before the main training loop pass @@ -101,8 +133,6 @@ def transfer(func): def wrapper(manager, *arg): returns = [] for callback in manager.callbacks: - for env_name, env_value in manager.env.items(): - setattr(callback, env_name, env_value) returns.append(getattr(callback, func.__name__)(*arg)) return returns @@ -115,15 +145,15 @@ class CallbackManager(Callback): """ - def __init__(self, env, callbacks=None): + def __init__(self, env, attr, callbacks=None): """ :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. + :param dict attr: read-only attributes for all callbacks :param Callback callbacks: """ super(CallbackManager, self).__init__() # set attribute of trainer environment - self.env = env self.callbacks = [] if callbacks is not None: @@ -136,6 +166,23 @@ class CallbackManager(Callback): else: raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") + for env_name, env_val in env.items(): + for callback in self.callbacks: + setattr(callback, env_name, env_val) # Callback.trainer + + self.set_property(**attr) + + def set_property(self, **kwargs): + """设置所有callback的只读属性 + + :param kwargs: + :return: + """ + for callback in self.callbacks: + for k, v in kwargs.items(): + setattr(callback, "_" + k, v) + + @transfer def on_train_begin(self): pass diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 8880291d..743570fd 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -121,7 +121,6 @@ class Trainer(object): self.best_dev_perf = None self.sampler = sampler if sampler is not None else RandomSampler() self.prefetch = prefetch - self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer @@ -144,6 +143,12 @@ class Trainer(object): self.step = 0 self.start_time = None # start timestamp + self.callback_manager = CallbackManager(env={"trainer": self}, + attr={"n_epochs": self.n_epochs, "n_steps": self.step, + "batch_size": self.batch_size, "model": self.model, + "optimizer": self.optimizer}, + callbacks=callbacks) + def train(self, load_best_model=True): """ @@ -236,6 +241,7 @@ class Trainer(object): avg_loss = 0 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, prefetch=self.prefetch) + self.callback_manager.set_property(pbar=pbar) for epoch in range(1, self.n_epochs+1): pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 74ce4876..7d66620c 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), callbacks=[TensorboardCallback("loss", "metric")]) trainer.train() + + def test_readonly_property(self): + from fastNLP.core.callback import Callback + class MyCallback(Callback): + def __init__(self): + super(MyCallback, self).__init__() + + def on_epoch_begin(self, cur_epoch, total_epoch): + print(self.n_epochs, self.n_steps, self.batch_size) + print(self.model) + print(self.optimizer) + + data_set, model = prepare_env() + trainer = Trainer(data_set, model, + loss=BCELoss(pred="predict", target="y"), + n_epochs=5, + batch_size=32, + print_every=50, + optimizer=SGD(lr=0.1), + check_code_level=2, + use_tqdm=False, + dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), + callbacks=[MyCallback()]) + trainer.train()