From 3ee6fc66f5b37d7cbd8ebbdfcc5ab02e002fab09 Mon Sep 17 00:00:00 2001
From: YWMditto
Date: Wed, 13 Apr 2022 15:37:08 +0800
Subject: [PATCH] Add the on_after_optimizers_step and on_after_zero_grad
 callback interfaces
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callbacks/callback.py             | 22 ++++++++++++++++++-
 fastNLP/core/callbacks/callback_events.py      |  4 +++-
 fastNLP/core/callbacks/callback_manager.py     | 10 ++++++++-
 fastNLP/core/controllers/trainer.py            | 21 +++++++++++++-----
 fastNLP/core/controllers/utils/utils.py        | 10 +++++++--
 fastNLP/core/drivers/torch_driver/ddp.py       |  8 -------
 .../drivers/torch_driver/single_device.py      |  8 -------
 .../core/drivers/torch_driver/torch_driver.py  |  8 +++++++
 tests/helpers/callbacks/helper_callbacks.py    | 10 +++++++--
 9 files changed, 72 insertions(+), 29 deletions(-)

diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py
index 96e4372b..0b9020fe 100644
--- a/fastNLP/core/callbacks/callback.py
+++ b/fastNLP/core/callbacks/callback.py
@@ -184,7 +184,7 @@ class Callback:
         """
         pass
 
-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
         """
         Called before the optimizers perform their update step. This hook does not necessarily fire on every forward pass; whether it actually fires is affected by accumulation_steps.
 
@@ -194,6 +194,16 @@ class Callback:
         """
         pass
 
+    def on_after_optimizers_step(self, trainer, optimizers):
+        """
+        Called after the optimizers perform their update step. This hook does not necessarily fire on every forward pass; whether it actually fires is affected by accumulation_steps.
+
+        :param trainer:
+        :param optimizers: The optimizers, i.e. the values passed in when the Trainer was initialized.
+        :return:
+        """
+        pass
+
     def on_before_zero_grad(self, trainer, optimizers):
         """
         Called before the model's gradients are zeroed. This hook does not necessarily fire on every forward pass; whether it actually fires is affected by accumulation_steps.
@@ -204,6 +214,16 @@ class Callback:
         """
         pass
 
+    def on_after_zero_grad(self, trainer, optimizers):
+        """
+        Called after the model's gradients are zeroed. This hook does not necessarily fire on every forward pass; whether it actually fires is affected by accumulation_steps.
+
+        :param trainer:
+        :param optimizers: The optimizers, i.e. the values passed in when the Trainer was initialized.
+        :return:
+        """
+        pass
+
     def on_validate_begin(self, trainer):
         """
         Called when validation is about to run. If the validation frequency is set by a step count or a custom rule, this hook fires after on_train_batch_end
diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py
index 2bfe8e90..1c805ac2 100644
--- a/fastNLP/core/callbacks/callback_events.py
+++ b/fastNLP/core/callbacks/callback_events.py
@@ -92,8 +92,10 @@ class Events(EventEnum):
     ON_LOAD_CHECKPOINT = "on_load_checkpoint"
     ON_BEFORE_BACKWARD = "on_before_backward"
     ON_AFTER_BACKWARD = "on_after_backward"
-    ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step"
+    ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step"
+    ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step"
     ON_BEFORE_ZERO_GRAD = "on_before_zero_grad"
+    ON_AFTER_ZERO_GRAD = "on_after_zero_grad"
     ON_VALIDATE_BEGIN = "on_validate_begin"
     ON_VALIDATE_END = "on_validate_end"
 
diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py
index 8b53c70b..a962fe9f 100644
--- a/fastNLP/core/callbacks/callback_manager.py
+++ b/fastNLP/core/callbacks/callback_manager.py
@@ -278,13 +278,21 @@ class CallbackManager:
         pass
 
     @_transfer
-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
+        pass
+
+    @_transfer
+    def on_after_optimizers_step(self, trainer, optimizers):
         pass
 
     @_transfer
     def on_before_zero_grad(self, trainer, optimizers):
         pass
 
+    @_transfer
+    def on_after_zero_grad(self, trainer, optimizers):
+        pass
+
     @_transfer
     def on_validate_begin(self, trainer):
         pass
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index fb62c3f1..a78af9d8 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -137,6 +137,7 @@ class Trainer(TrainerEventTrigger):
         else:
             self.driver_name = driver.__class__.__name__
         self.device = device
+        self.optimizers = optimizers
         self.fp16 = fp16
         self.input_mapping = input_mapping
         self.output_mapping = output_mapping
@@ -440,9 +441,11 @@ class Trainer(TrainerEventTrigger):
         2. Purpose of this function
             This function checks whether a user-customized batch_step_fn / TrainBatchLoop still calls the callback functions correctly; more precisely, once the user has
-            overridden ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") /
+            overridden ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+            "on_after_zero_grad") /
             ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
-            "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+            "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+            "on_after_zero_grad")
             among these callback_fns, and has also customized batch_step_fn / TrainBatchLoop, they may have forgotten to call
             the callback functions above inside their own batch_step_fn; this function detects whether that has happened;
@@ -452,10 +455,12 @@ class Trainer(TrainerEventTrigger):
             'batch_step_fn'; when False, it means 'TrainBatchLoop' is being checked;
         """
         if check_mode:
-            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+                         "on_before_zero_grad", "on_after_zero_grad")
         else:
             callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
-                         "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+                         "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+                         "on_before_zero_grad", "on_after_zero_grad")
         _not_called_callback_fns = []
         for each_callback_fn in callbacks:
             if each_callback_fn in self.callback_manager.callback_fns:
@@ -699,13 +704,15 @@ class Trainer(TrainerEventTrigger):
 
     def zero_grad(self):
         if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
-            self.on_before_zero_grad(self.driver.optimizers)
+            self.on_before_zero_grad(self.optimizers)
             self.driver.zero_grad(self.set_grad_to_none)
+            self.on_after_zero_grad(self.optimizers)
 
     def step(self):
         if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
-            self.on_before_optimizer_step(self.driver.optimizers)
+            self.on_before_optimizers_step(self.optimizers)
             self.driver.step()
+            self.on_after_optimizers_step(self.optimizers)
 
     def move_data_to_device(self, batch):
         return self.driver.move_data_to_device(batch)
@@ -817,3 +824,5 @@ class Trainer(TrainerEventTrigger):
 
 
 
+
+
diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py
index c3f6aeef..0dce0b27 100644
--- a/fastNLP/core/controllers/utils/utils.py
+++ b/fastNLP/core/controllers/utils/utils.py
@@ -68,12 +68,18 @@ class TrainerEventTrigger:
     def on_after_backward(self):
         self.callback_manager.on_after_backward(self)
 
-    def on_before_optimizer_step(self, optimizers):
-        self.callback_manager.on_before_optimizer_step(self, optimizers)
+    def on_before_optimizers_step(self, optimizers):
+        self.callback_manager.on_before_optimizers_step(self, optimizers)
+
+    def on_after_optimizers_step(self, optimizers):
+        self.callback_manager.on_after_optimizers_step(self, optimizers)
 
     def on_before_zero_grad(self, optimizers):
         self.callback_manager.on_before_zero_grad(self, optimizers)
 
+    def on_after_zero_grad(self, optimizers):
+        self.callback_manager.on_after_zero_grad(self, optimizers)
+
     def on_validate_begin(self):
         self.callback_manager.on_validate_begin(self)
 
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 3537d0b3..11a61dde 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -530,14 +530,6 @@ class TorchDDPDriver(TorchDriver):
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
 
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-        self.grad_scaler.update()
-
     def is_global_zero(self):
         return self.global_rank == 0
 
diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py
index 8cbb7acd..eda438d7 100644
--- a/fastNLP/core/drivers/torch_driver/single_device.py
+++ b/fastNLP/core/drivers/torch_driver/single_device.py
@@ -107,14 +107,6 @@ class TorchSingleDriver(TorchDriver):
         else:
             return self._train_step(batch)
 
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-        self.grad_scaler.update()
-
     def validate_step(self, batch) -> Dict:
         # Because our Tester's logic is to hand all metrics to the tester, which then controls each metric's update and compute, this should
         # return a dict regardless of whether the user implements a validate_step function; which entries get used is decided by each individual metric inside validate_batch_loop;
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index d2ffbac1..c8a086fe 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -72,6 +72,14 @@ class TorchDriver(Driver):
                             p.grad.requires_grad_(False)
                         p.grad.zero_()
 
+    def backward(self, loss):
+        self.grad_scaler.scale(loss).backward()
+
+    def step(self):
+        for optimizer in self.optimizers:
+            self.grad_scaler.step(optimizer)
+        self.grad_scaler.update()
+
     @staticmethod
     def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         if is_train:
diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py
index a1697ab0..751d59f2 100644
--- a/tests/helpers/callbacks/helper_callbacks.py
+++ b/tests/helpers/callbacks/helper_callbacks.py
@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback):
     def on_after_backward(self, trainer):
         print("on_after_backward")
 
-    def on_before_optimizer_step(self, trainer, optimizers):
-        print("on_before_optimizer_step")
+    def on_before_optimizers_step(self, trainer, optimizers):
+        print("on_before_optimizers_step")
+
+    def on_after_optimizers_step(self, trainer, optimizers):
+        print("on_after_optimizers_step")
 
     def on_before_zero_grad(self, trainer, optimizers):
         print("on_before_zero_grad")
 
+    def on_after_zero_grad(self, trainer, optimizers):
+        print("on_after_zero_grad")
+
     def on_validate_begin(self, trainer):
         print("on_validate_begin")
 
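
Usage note (not part of the patch): the sketch below shows how the two new hooks could be consumed from a user-defined callback. The import path and the idea of registering the instance through the Trainer's callbacks argument are assumptions based on the files touched above, not something this patch introduces.

    from fastNLP.core.callbacks.callback import Callback  # assumed import path

    class GradStateLogger(Callback):
        """Hypothetical callback that observes optimizer state around the new hooks."""

        def on_after_optimizers_step(self, trainer, optimizers):
            # Fires right after driver.step(), i.e. only on accumulation boundaries.
            for optimizer in optimizers:
                lrs = [group["lr"] for group in optimizer.param_groups]
                print(f"batch {trainer.global_forward_batches}: lrs={lrs}")

        def on_after_zero_grad(self, trainer, optimizers):
            # Fires right after driver.zero_grad(...); gradients are already cleared here.
            print("gradients zeroed")

Registering it would presumably look like Trainer(..., callbacks=[GradStateLogger()]), mirroring the RecordTrainerEventTriggerCallback test helper updated above.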
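For context on the driver change: backward() and step() were removed from the two concrete drivers and added to the shared TorchDriver base class; both follow the standard torch.cuda.amp.GradScaler pattern. The standalone snippet below is plain PyTorch, not fastNLP code, and only illustrates that pattern; enabled=False keeps it runnable on CPU.

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=False)  # pass enabled=True on CUDA for real fp16 training

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()   # what TorchDriver.backward() does: scale the loss, then backprop
    scaler.step(optimizer)          # what TorchDriver.step() does per optimizer: unscale grads and step
    scaler.update()                 # adjust the scale factor once per iteration
    optimizer.zero_grad()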