
Added the on_after_optimizers_step and on_after_zero_grad callback interfaces

tags/v1.0.0alpha
YWMditto committed 2 years ago · commit 3ee6fc66f5
9 changed files with 72 additions and 29 deletions
  1. fastNLP/core/callbacks/callback.py (+21 −1)
  2. fastNLP/core/callbacks/callback_events.py (+3 −1)
  3. fastNLP/core/callbacks/callback_manager.py (+9 −1)
  4. fastNLP/core/controllers/trainer.py (+15 −6)
  5. fastNLP/core/controllers/utils/utils.py (+8 −2)
  6. fastNLP/core/drivers/torch_driver/ddp.py (+0 −8)
  7. fastNLP/core/drivers/torch_driver/single_device.py (+0 −8)
  8. fastNLP/core/drivers/torch_driver/torch_driver.py (+8 −0)
  9. tests/helpers/callbacks/helper_callbacks.py (+8 −2)

fastNLP/core/callbacks/callback.py (+21 −1)

@@ -184,7 +184,7 @@ class Callback:
        """
        pass

-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
        """
        Called right before the optimizers perform their update step. This hook is not necessarily triggered after every forward pass; whether it actually fires depends on accumulation_steps.

@@ -194,6 +194,16 @@ class Callback:
        """
        pass

+    def on_after_optimizers_step(self, trainer, optimizers):
+        """
+        Called right after the optimizers perform their update step. This hook is not necessarily triggered after every forward pass; whether it actually fires depends on accumulation_steps.
+
+        :param trainer:
+        :param optimizers: the optimizers, i.e. the values passed to the Trainer at initialization.
+        :return:
+        """
+        pass
+
    def on_before_zero_grad(self, trainer, optimizers):
        """
        Called right before the model gradients are zeroed. This hook is not necessarily triggered after every forward pass; whether it actually fires depends on accumulation_steps.

@@ -204,6 +214,16 @@ class Callback:
        """
        pass

+    def on_after_zero_grad(self, trainer, optimizers):
+        """
+        Called right after the model gradients are zeroed. This hook is not necessarily triggered after every forward pass; whether it actually fires depends on accumulation_steps.
+
+        :param trainer:
+        :param optimizers: the optimizers, i.e. the values passed to the Trainer at initialization.
+        :return:
+        """
+        pass
+
    def on_validate_begin(self, trainer):
        """
        Called when validation is about to run. If the validation frequency is determined by step count or by a custom rule, this hook is called after on_train_batch_end
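
As a quick orientation to where the new hooks land (not part of the commit itself), a minimal user callback overriding them might look like the sketch below; the class name and the printed messages are illustrative only:

from fastNLP.core.callbacks.callback import Callback

class PrintStepCallback(Callback):
    # both optimizer hooks fire only on gradient-accumulation boundaries,
    # immediately before/after driver.step()
    def on_before_optimizers_step(self, trainer, optimizers):
        print(f"about to step at global forward batch {trainer.global_forward_batches}")

    def on_after_optimizers_step(self, trainer, optimizers):
        print("optimizers stepped")

    # fires right after driver.zero_grad() has cleared the gradients
    def on_after_zero_grad(self, trainer, optimizers):
        print("gradients zeroed")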


fastNLP/core/callbacks/callback_events.py (+3 −1)

@@ -92,8 +92,10 @@ class Events(EventEnum):
    ON_LOAD_CHECKPOINT = "on_load_checkpoint"
    ON_BEFORE_BACKWARD = "on_before_backward"
    ON_AFTER_BACKWARD = "on_after_backward"
-    ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step"
+    ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step"
+    ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step"
    ON_BEFORE_ZERO_GRAD = "on_before_zero_grad"
+    ON_AFTER_ZERO_GRAD = "on_after_zero_grad"
    ON_VALIDATE_BEGIN = "on_validate_begin"
    ON_VALIDATE_END = "on_validate_end"




fastNLP/core/callbacks/callback_manager.py (+9 −1)

@@ -278,13 +278,21 @@ class CallbackManager:
        pass

    @_transfer
-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
+        pass
+
+    @_transfer
+    def on_after_optimizers_step(self, trainer, optimizers):
        pass

    @_transfer
    def on_before_zero_grad(self, trainer, optimizers):
        pass

+    @_transfer
+    def on_after_zero_grad(self, trainer, optimizers):
+        pass
+
    @_transfer
    def on_validate_begin(self, trainer):
        pass
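
The method bodies above stay empty because the @_transfer decorator (defined elsewhere in callback_manager.py and not shown in this diff) is what fans each call out to every registered callback. A simplified, purely illustrative sketch of such a dispatch decorator follows; the callbacks attribute name below is an assumption, not fastNLP's actual one:

from functools import wraps

def _transfer(func):
    @wraps(func)
    def wrapper(manager, *args, **kwargs):
        # call the same-named hook on every registered callback
        for callback in manager.callbacks:  # attribute name is hypothetical
            getattr(callback, func.__name__)(*args, **kwargs)
    return wrapper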


fastNLP/core/controllers/trainer.py (+15 −6)

@@ -137,6 +137,7 @@ class Trainer(TrainerEventTrigger):
        else:
            self.driver_name = driver.__class__.__name__
        self.device = device
+        self.optimizers = optimizers
        self.fp16 = fp16
        self.input_mapping = input_mapping
        self.output_mapping = output_mapping

@@ -440,9 +441,11 @@ class Trainer(TrainerEventTrigger):

        2. What this function does
        This function checks whether a user-customized batch_step_fn / TrainBatchLoop still invokes the callback functions correctly. More precisely, when the user has actually
-        customized any of ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") /
+        customized any of ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+        "on_after_zero_grad") /
        ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
-        "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+        "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+        "on_after_zero_grad")
        as callback_fns and has also customized batch_step_fn / TrainBatchLoop, they may have forgotten to call these callback functions inside their own
        batch_step_fn; the job of this function is to detect whether that has happened;

@@ -452,10 +455,12 @@ class Trainer(TrainerEventTrigger):
            'batch_step_fn'; when it is False, the 'TrainBatchLoop' is checked;
        """
        if check_mode:
-            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+                         "on_before_zero_grad", "on_after_zero_grad")
        else:
            callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
-                         "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+                         "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+                         "on_before_zero_grad", "on_after_zero_grad")
        _not_called_callback_fns = []
        for each_callback_fn in callbacks:
            if each_callback_fn in self.callback_manager.callback_fns:

@@ -699,13 +704,15 @@ class Trainer(TrainerEventTrigger):

    def zero_grad(self):
        if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
-            self.on_before_zero_grad(self.driver.optimizers)
+            self.on_before_zero_grad(self.optimizers)
            self.driver.zero_grad(self.set_grad_to_none)
+            self.on_after_zero_grad(self.optimizers)

    def step(self):
        if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
-            self.on_before_optimizer_step(self.driver.optimizers)
+            self.on_before_optimizers_step(self.optimizers)
            self.driver.step()
+            self.on_after_optimizers_step(self.optimizers)

    def move_data_to_device(self, batch):
        return self.driver.move_data_to_device(batch)

@@ -817,3 +824,5 @@ class Trainer(TrainerEventTrigger):
+
+
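For orientation (not part of the commit): because zero_grad() and step() share the same guard, all four before/after hooks fire only on gradient-accumulation boundaries. For example, with accumulation_steps = 4 they fire when global_forward_batches is 3, 7, 11, ...:

# standalone sketch of the boundary condition used in zero_grad() / step()
accumulation_steps = 4
fired = [b for b in range(12) if (b + 1) % accumulation_steps == 0]
print(fired)  # [3, 7, 11]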


fastNLP/core/controllers/utils/utils.py (+8 −2)

@@ -68,12 +68,18 @@ class TrainerEventTrigger:
    def on_after_backward(self):
        self.callback_manager.on_after_backward(self)

-    def on_before_optimizer_step(self, optimizers):
-        self.callback_manager.on_before_optimizer_step(self, optimizers)
+    def on_before_optimizers_step(self, optimizers):
+        self.callback_manager.on_before_optimizers_step(self, optimizers)
+
+    def on_after_optimizers_step(self, optimizers):
+        self.callback_manager.on_after_optimizers_step(self, optimizers)

    def on_before_zero_grad(self, optimizers):
        self.callback_manager.on_before_zero_grad(self, optimizers)

+    def on_after_zero_grad(self, optimizers):
+        self.callback_manager.on_after_zero_grad(self, optimizers)
+
    def on_validate_begin(self):
        self.callback_manager.on_validate_begin(self)




fastNLP/core/drivers/torch_driver/ddp.py (+0 −8)

@@ -530,14 +530,6 @@ class TorchDDPDriver(TorchDriver):
        else:
            raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")

-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
-
    def is_global_zero(self):
        return self.global_rank == 0




fastNLP/core/drivers/torch_driver/single_device.py (+0 −8)

@@ -107,14 +107,6 @@ class TorchSingleDriver(TorchDriver):
        else:
            return self._train_step(batch)

-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
-
    def validate_step(self, batch) -> Dict:
        # Because our Tester's logic is to hand all metrics to the tester, which then controls each metric's update and compute, validate_step should
        # return a dict regardless of whether the user implements it; which items get used is then picked out by each metric inside the validate_batch_loop.


fastNLP/core/drivers/torch_driver/torch_driver.py (+8 −0)

@@ -72,6 +72,14 @@ class TorchDriver(Driver):
                p.grad.requires_grad_(False)
                p.grad.zero_()

+    def backward(self, loss):
+        self.grad_scaler.scale(loss).backward()
+
+    def step(self):
+        for optimizer in self.optimizers:
+            self.grad_scaler.step(optimizer)
+            self.grad_scaler.update()
+
    @staticmethod
    def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
        if is_train:
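
These are the same backward/step bodies deleted from ddp.py and single_device.py above, hoisted into the shared TorchDriver base class. They follow the standard torch.cuda.amp loss-scaling pattern; a self-contained sketch of that pattern outside fastNLP (assumes a CUDA device; the model, data and optimizer are placeholders):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x, y = torch.randn(8, 4, device="cuda"), torch.randn(8, 1, device="cuda")
with autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()   # what TorchDriver.backward does
scaler.step(optimizer)          # what TorchDriver.step does for each optimizer
scaler.update()
optimizer.zero_grad()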


tests/helpers/callbacks/helper_callbacks.py (+8 −2)

@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback):
    def on_after_backward(self, trainer):
        print("on_after_backward")

-    def on_before_optimizer_step(self, trainer, optimizers):
-        print("on_before_optimizer_step")
+    def on_before_optimizers_step(self, trainer, optimizers):
+        print("on_before_optimizers_step")
+
+    def on_after_optimizers_step(self, trainer, optimizers):
+        print("on_after_optimizers_step")

    def on_before_zero_grad(self, trainer, optimizers):
        print("on_before_zero_grad")

+    def on_after_zero_grad(self, trainer, optimizers):
+        print("on_after_zero_grad")
+
    def on_validate_begin(self, trainer):
        print("on_validate_begin")



