From ef0c6e936d0503cf025287b36735bf99fd58f6a3 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Wed, 20 Mar 2019 09:49:01 +0800 Subject: [PATCH] =?UTF-8?q?Changes=20to=20Callbacks:=20*=20=E7=BB=99callba?= =?UTF-8?q?ck=E6=B7=BB=E5=8A=A0=E7=BB=99=E5=AE=9A=E5=87=A0=E4=B8=AA?= =?UTF-8?q?=E5=8F=AA=E8=AF=BB=E5=B1=9E=E6=80=A7=20*=20=E9=80=9A=E8=BF=87ma?= =?UTF-8?q?nager=E8=AE=BE=E7=BD=AE=E8=BF=99=E4=BA=9B=E5=B1=9E=E6=80=A7=20*?= =?UTF-8?q?=20=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96=EF=BC=8C=E5=87=8F?= =?UTF-8?q?=E8=BD=BB@transfer=E7=9A=84=E8=B4=9F=E6=8B=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 55 ++++++++++++++++++++++++++++++++++--- fastNLP/core/trainer.py | 8 +++++- test/core/test_callbacks.py | 25 +++++++++++++++++ 3 files changed, 83 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index d941c235..e3b4f36e 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -17,6 +17,38 @@ class Callback(object): super(Callback, self).__init__() self.trainer = None # 在Trainer内部被重新赋值 + # callback只读属性 + self._n_epochs = None + self._n_steps = None + self._batch_size = None + self._model = None + self._pbar = None + self._optimizer = None + + @property + def n_epochs(self): + return self._n_epochs + + @property + def n_steps(self): + return self._n_steps + + @property + def batch_size(self): + return self._batch_size + + @property + def model(self): + return self._model + + @property + def pbar(self): + return self._pbar + + @property + def optimizer(self): + return self._optimizer + def on_train_begin(self): # before the main training loop pass @@ -101,8 +133,6 @@ def transfer(func): def wrapper(manager, *arg): returns = [] for callback in manager.callbacks: - for env_name, env_value in manager.env.items(): - setattr(callback, env_name, env_value) returns.append(getattr(callback, func.__name__)(*arg)) return returns @@ -115,15 +145,15 @@ class CallbackManager(Callback): """ - def __init__(self, env, callbacks=None): + def __init__(self, env, attr, callbacks=None): """ :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. + :param dict attr: read-only attributes for all callbacks :param Callback callbacks: """ super(CallbackManager, self).__init__() # set attribute of trainer environment - self.env = env self.callbacks = [] if callbacks is not None: @@ -136,6 +166,23 @@ class CallbackManager(Callback): else: raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") + for env_name, env_val in env.items(): + for callback in self.callbacks: + setattr(callback, env_name, env_val) # Callback.trainer + + self.set_property(**attr) + + def set_property(self, **kwargs): + """设置所有callback的只读属性 + + :param kwargs: + :return: + """ + for callback in self.callbacks: + for k, v in kwargs.items(): + setattr(callback, "_" + k, v) + + @transfer def on_train_begin(self): pass diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 8880291d..743570fd 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -121,7 +121,6 @@ class Trainer(object): self.best_dev_perf = None self.sampler = sampler if sampler is not None else RandomSampler() self.prefetch = prefetch - self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer @@ -144,6 +143,12 @@ class Trainer(object): self.step = 0 self.start_time = None # start timestamp + self.callback_manager = CallbackManager(env={"trainer": self}, + attr={"n_epochs": self.n_epochs, "n_steps": self.step, + "batch_size": self.batch_size, "model": self.model, + "optimizer": self.optimizer}, + callbacks=callbacks) + def train(self, load_best_model=True): """ @@ -236,6 +241,7 @@ class Trainer(object): avg_loss = 0 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, prefetch=self.prefetch) + self.callback_manager.set_property(pbar=pbar) for epoch in range(1, self.n_epochs+1): pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 74ce4876..7d66620c 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -136,3 +136,28 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), callbacks=[TensorboardCallback("loss", "metric")]) trainer.train() + + def test_readonly_property(self): + from fastNLP.core.callback import Callback + class MyCallback(Callback): + def __init__(self): + super(MyCallback, self).__init__() + + def on_epoch_begin(self, cur_epoch, total_epoch): + print(self.n_epochs, self.n_steps, self.batch_size) + print(self.model) + print(self.optimizer) + + data_set, model = prepare_env() + trainer = Trainer(data_set, model, + loss=BCELoss(pred="predict", target="y"), + n_epochs=5, + batch_size=32, + print_every=50, + optimizer=SGD(lr=0.1), + check_code_level=2, + use_tqdm=False, + dev_data=data_set, + metrics=AccuracyMetric(pred="predict", target="y"), + callbacks=[MyCallback()]) + trainer.train()