
Added on_after_optimizers_step and on_after_zero_grad callback interfaces

tags/v1.0.0alpha
YWMditto committed 2 years ago (commit 3ee6fc66f5)
9 changed files with 72 additions and 29 deletions
  1. fastNLP/core/callbacks/callback.py (+21, -1)
  2. fastNLP/core/callbacks/callback_events.py (+3, -1)
  3. fastNLP/core/callbacks/callback_manager.py (+9, -1)
  4. fastNLP/core/controllers/trainer.py (+15, -6)
  5. fastNLP/core/controllers/utils/utils.py (+8, -2)
  6. fastNLP/core/drivers/torch_driver/ddp.py (+0, -8)
  7. fastNLP/core/drivers/torch_driver/single_device.py (+0, -8)
  8. fastNLP/core/drivers/torch_driver/torch_driver.py (+8, -0)
  9. tests/helpers/callbacks/helper_callbacks.py (+8, -2)

fastNLP/core/callbacks/callback.py (+21, -1)

@@ -184,7 +184,7 @@ class Callback:
"""
pass

- def on_before_optimizer_step(self, trainer, optimizers):
+ def on_before_optimizers_step(self, trainer, optimizers):
"""
Called before the optimizers run their optimization step. This hook is not necessarily triggered on every forward pass; whether it actually fires depends on accumulation_steps.

@@ -194,6 +194,16 @@ class Callback:
"""
pass

+ def on_after_optimizers_step(self, trainer, optimizers):
+ """
+ Called after the optimizers have run their optimization step. This hook is not necessarily triggered on every forward pass; whether it actually fires depends on accumulation_steps.

+ :param trainer:
+ :param optimizers: the optimizers, i.e. the values passed in when the Trainer was initialized.
+ :return:
+ """
+ pass

def on_before_zero_grad(self, trainer, optimizers):
"""
Called before the model's gradients are zeroed. This hook is not necessarily triggered on every forward pass; whether it actually fires depends on accumulation_steps.
@@ -204,6 +214,16 @@ class Callback:
"""
pass

+ def on_after_zero_grad(self, trainer, optimizers):
+ """
+ Called after the model's gradients have been zeroed. This hook is not necessarily triggered on every forward pass; whether it actually fires depends on accumulation_steps.

+ :param trainer:
+ :param optimizers: the optimizers, i.e. the values passed in when the Trainer was initialized.
+ :return:
+ """
+ pass

def on_validate_begin(self, trainer):
"""
Called when validation is about to start. If the validation frequency is determined by a step count or by a custom rule, this hook is called after on_train_batch_end
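
To make the new interface concrete, here is a minimal, hypothetical callback that overrides the hooks added above. It is not part of this commit; the import path is taken from the file shown in this diff, and trainer.global_forward_batches is an attribute the Trainer already exposes elsewhere in this commit.

    from fastNLP.core.callbacks.callback import Callback  # module path as shown in this diff

    class PrintStepCallback(Callback):
        """Hypothetical example callback built on the hooks added in this commit."""

        def on_before_optimizers_step(self, trainer, optimizers):
            # Fires just before driver.step(), only on accumulation boundaries.
            print("about to step the optimizers")

        def on_after_optimizers_step(self, trainer, optimizers):
            # New hook: fires right after the optimizers have stepped.
            print("stepped at forward batch", trainer.global_forward_batches)

        def on_after_zero_grad(self, trainer, optimizers):
            # New hook: fires right after the gradients were zeroed.
            print("gradients zeroed")

Registering such a callback with the Trainer is assumed to work exactly as for the existing hooks; only the two new method names are new.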


fastNLP/core/callbacks/callback_events.py (+3, -1)

@@ -92,8 +92,10 @@ class Events(EventEnum):
ON_LOAD_CHECKPOINT = "on_load_checkpoint"
ON_BEFORE_BACKWARD = "on_before_backward"
ON_AFTER_BACKWARD = "on_after_backward"
- ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step"
+ ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step"
+ ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step"
ON_BEFORE_ZERO_GRAD = "on_before_zero_grad"
+ ON_AFTER_ZERO_GRAD = "on_after_zero_grad"
ON_VALIDATE_BEGIN = "on_validate_begin"
ON_VALIDATE_END = "on_validate_end"



fastNLP/core/callbacks/callback_manager.py (+9, -1)

@@ -278,13 +278,21 @@ class CallbackManager:
pass

@_transfer
- def on_before_optimizer_step(self, trainer, optimizers):
+ def on_before_optimizers_step(self, trainer, optimizers):
pass

+ @_transfer
+ def on_after_optimizers_step(self, trainer, optimizers):
+ pass

@_transfer
def on_before_zero_grad(self, trainer, optimizers):
pass

+ @_transfer
+ def on_after_zero_grad(self, trainer, optimizers):
+ pass

@_transfer
def on_validate_begin(self, trainer):
pass
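
The pass bodies above do no work themselves; it is the _transfer decorator that fans each call out to every registered callback. A generic sketch of that dispatch pattern (the attribute name manager.callbacks and the exact forwarding behaviour are assumptions for illustration, not fastNLP's actual _transfer):

    # Generic callback-dispatch sketch; not fastNLP's real implementation.
    def _transfer(func):
        def wrapper(manager, *args, **kwargs):
            for callback in manager.callbacks:                    # assumed attribute
                getattr(callback, func.__name__)(*args, **kwargs)
        return wrapper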


fastNLP/core/controllers/trainer.py (+15, -6)

@@ -137,6 +137,7 @@ class Trainer(TrainerEventTrigger):
else:
self.driver_name = driver.__class__.__name__
self.device = device
+ self.optimizers = optimizers
self.fp16 = fp16
self.input_mapping = input_mapping
self.output_mapping = output_mapping
@@ -440,9 +441,11 @@ class Trainer(TrainerEventTrigger):

2. Purpose of this function
This function checks whether a user-customized batch_step_fn / TrainBatchLoop calls the callback functions correctly. More precisely, when the user has actually
- customized ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") /
+ customized ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+ "on_after_zero_grad") /
("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
- "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+ "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+ "on_after_zero_grad")
these callback_fns, and has also customized batch_step_fn / TrainBatchLoop, then they may have forgotten to call
the above callback functions inside their own batch_step_fn; this function detects whether the user has done that;

@@ -452,10 +455,12 @@ class Trainer(TrainerEventTrigger):
'batch_step_fn'; when False, it checks 'TrainBatchLoop';
"""
if check_mode:
- callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+ callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+ "on_before_zero_grad", "on_after_zero_grad")
else:
callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
- "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+ "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+ "on_before_zero_grad", "on_after_zero_grad")
_not_called_callback_fns = []
for each_callback_fn in callbacks:
if each_callback_fn in self.callback_manager.callback_fns:
@@ -699,13 +704,15 @@ class Trainer(TrainerEventTrigger):

def zero_grad(self):
if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
- self.on_before_zero_grad(self.driver.optimizers)
+ self.on_before_zero_grad(self.optimizers)
self.driver.zero_grad(self.set_grad_to_none)
+ self.on_after_zero_grad(self.optimizers)

def step(self):
if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
- self.on_before_optimizer_step(self.driver.optimizers)
+ self.on_before_optimizers_step(self.optimizers)
self.driver.step()
+ self.on_after_optimizers_step(self.optimizers)

def move_data_to_device(self, batch):
return self.driver.move_data_to_device(batch)
@@ -817,3 +824,5 @@ class Trainer(TrainerEventTrigger):






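The hook ordering produced by the zero_grad/step changes above can be sketched in isolation. The snippet below only illustrates the accumulation_steps gating; driver and events are hypothetical stand-ins rather than fastNLP objects:

    # Sketch of the gating used by Trainer.step() and Trainer.zero_grad() above:
    # hooks and the real work fire only on accumulation boundaries.
    def step(global_forward_batches: int, accumulation_steps: int, driver, events) -> None:
        if (global_forward_batches + 1) % accumulation_steps == 0:
            events.on_before_optimizers_step()   # existing hook (renamed in this commit)
            driver.step()                        # all optimizers step via the driver
            events.on_after_optimizers_step()    # new hook added by this commit

    def zero_grad(global_forward_batches: int, accumulation_steps: int, driver, events) -> None:
        if (global_forward_batches + 1) % accumulation_steps == 0:
            events.on_before_zero_grad()
            driver.zero_grad()
            events.on_after_zero_grad()          # new hook added by this commit

With accumulation_steps = 1 every forward batch is a boundary, so the new after-hooks fire after every parameter update.
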
fastNLP/core/controllers/utils/utils.py (+8, -2)

@@ -68,12 +68,18 @@ class TrainerEventTrigger:
def on_after_backward(self):
self.callback_manager.on_after_backward(self)

- def on_before_optimizer_step(self, optimizers):
- self.callback_manager.on_before_optimizer_step(self, optimizers)
+ def on_before_optimizers_step(self, optimizers):
+ self.callback_manager.on_before_optimizers_step(self, optimizers)

+ def on_after_optimizers_step(self, optimizers):
+ self.callback_manager.on_after_optimizers_step(self, optimizers)

def on_before_zero_grad(self, optimizers):
self.callback_manager.on_before_zero_grad(self, optimizers)

+ def on_after_zero_grad(self, optimizers):
+ self.callback_manager.on_after_zero_grad(self, optimizers)

def on_validate_begin(self):
self.callback_manager.on_validate_begin(self)



fastNLP/core/drivers/torch_driver/ddp.py (+0, -8)

@@ -530,14 +530,6 @@ class TorchDDPDriver(TorchDriver):
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")

- def backward(self, loss):
- self.grad_scaler.scale(loss).backward()

- def step(self):
- for optimizer in self.optimizers:
- self.grad_scaler.step(optimizer)
- self.grad_scaler.update()

def is_global_zero(self):
return self.global_rank == 0



fastNLP/core/drivers/torch_driver/single_device.py (+0, -8)

@@ -107,14 +107,6 @@ class TorchSingleDriver(TorchDriver):
else:
return self._train_step(batch)

- def backward(self, loss):
- self.grad_scaler.scale(loss).backward()

- def step(self):
- for optimizer in self.optimizers:
- self.grad_scaler.step(optimizer)
- self.grad_scaler.update()

def validate_step(self, batch) -> Dict:
# Because our Tester's logic is to pass all metrics to the tester, which then controls each metric's update and compute; therefore, regardless of whether the user
# implements a validate_step function, it should always return a dict; which items are actually used is decided by each individual metric inside validate_batch_loop;


fastNLP/core/drivers/torch_driver/torch_driver.py (+8, -0)

@@ -72,6 +72,14 @@ class TorchDriver(Driver):
p.grad.requires_grad_(False)
p.grad.zero_()

+ def backward(self, loss):
+ self.grad_scaler.scale(loss).backward()

+ def step(self):
+ for optimizer in self.optimizers:
+ self.grad_scaler.step(optimizer)
+ self.grad_scaler.update()

@staticmethod
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
if is_train:

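The backward()/step() bodies that moved up into the shared TorchDriver follow PyTorch's standard GradScaler pattern. A minimal standalone equivalent in plain PyTorch (not fastNLP code; the scaler is created disabled so the example also runs on CPU):

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=False)  # enabled=True requires CUDA + fp16

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()  # same call chain as TorchDriver.backward(loss)
    scaler.step(optimizer)         # same as the per-optimizer loop in TorchDriver.step()
    scaler.update()
    optimizer.zero_grad()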

tests/helpers/callbacks/helper_callbacks.py (+8, -2)

@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback):
def on_after_backward(self, trainer):
print("on_after_backward")

- def on_before_optimizer_step(self, trainer, optimizers):
- print("on_before_optimizer_step")
+ def on_before_optimizers_step(self, trainer, optimizers):
+ print("on_before_optimizers_step")

+ def on_after_optimizers_step(self, trainer, optimizers):
+ print("on_after_optimizers_step")

def on_before_zero_grad(self, trainer, optimizers):
print("on_before_zero_grad")

+ def on_after_zero_grad(self, trainer, optimizers):
+ print("on_after_zero_grad")

def on_validate_begin(self, trainer):
print("on_validate_begin")


