@@ -184,7 +184,7 @@ class Callback:
         """
         pass
-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
         """
         Called right before the optimizers perform their update step. This hook is not necessarily triggered on every forward pass; the actual invocation is affected by accumulation_steps.
@@ -194,6 +194,16 @@ class Callback:
         """
         pass
+    def on_after_optimizers_step(self, trainer, optimizers):
+        """
+        Called right after the optimizers perform their update step. This hook is not necessarily triggered on every forward pass; the actual invocation is affected by accumulation_steps.
+        :param trainer:
+        :param optimizers: the optimizers; this is the value passed in when the Trainer was initialized.
+        :return:
+        """
+        pass
     def on_before_zero_grad(self, trainer, optimizers):
         """
         Called right before the model gradients are zeroed. This hook is not necessarily triggered on every forward pass; the actual invocation is affected by accumulation_steps.
@@ -204,6 +214,16 @@ class Callback:
         """
         pass
+    def on_after_zero_grad(self, trainer, optimizers):
+        """
+        Called right after the model gradients are zeroed. This hook is not necessarily triggered on every forward pass; the actual invocation is affected by accumulation_steps.
+        :param trainer:
+        :param optimizers: the optimizers; this is the value passed in when the Trainer was initialized.
+        :return:
+        """
+        pass
     def on_validate_begin(self, trainer):
         """
         Called when validation is about to run. If the frequency of validation is decided by a step count or by a custom rule, this hook is triggered after on_train_batch_end
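Taken together, the before/after pairs give callbacks a symmetric view around every optimizer step and gradient reset. Below is a minimal sketch of a user callback built on the new hooks; it assumes `Callback` is importable as shown above, treats `optimizers` as an iterable, and the gradient-norm logging itself is purely illustrative:

```python
class GradNormLogger(Callback):  # Callback as defined above; import path assumed
    def on_before_optimizers_step(self, trainer, optimizers):
        # Fires only on accumulation boundaries, right before driver.step().
        sq_sum = 0.0
        for optimizer in optimizers:
            for group in optimizer.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        sq_sum += p.grad.detach().norm().item() ** 2
        print(f"grad norm before step: {sq_sum ** 0.5:.4f}")

    def on_after_zero_grad(self, trainer, optimizers):
        # Fires right after driver.zero_grad(); gradients are now zero or None.
        print("gradients cleared")
```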
@@ -92,8 +92,10 @@ class Events(EventEnum):
     ON_LOAD_CHECKPOINT = "on_load_checkpoint"
     ON_BEFORE_BACKWARD = "on_before_backward"
     ON_AFTER_BACKWARD = "on_after_backward"
-    ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step"
+    ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step"
+    ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step"
     ON_BEFORE_ZERO_GRAD = "on_before_zero_grad"
+    ON_AFTER_ZERO_GRAD = "on_after_zero_grad"
     ON_VALIDATE_BEGIN = "on_validate_begin"
     ON_VALIDATE_END = "on_validate_end"
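Each enum value mirrors the name of the corresponding `Callback` method, which is what allows events to be dispatched by string. A hypothetical one-liner to illustrate the correspondence (`fire` is not part of the library):

```python
def fire(callback, event, *args):
    # e.g. fire(cb, Events.ON_AFTER_OPTIMIZERS_STEP, trainer, optimizers)
    getattr(callback, event.value)(*args)
```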
@@ -278,13 +278,21 @@ class CallbackManager:
         pass
     @_transfer
-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
         pass
+    @_transfer
+    def on_after_optimizers_step(self, trainer, optimizers):
+        pass
     @_transfer
     def on_before_zero_grad(self, trainer, optimizers):
         pass
+    @_transfer
+    def on_after_zero_grad(self, trainer, optimizers):
+        pass
     @_transfer
     def on_validate_begin(self, trainer):
         pass
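The `@_transfer` decorator itself lies outside this diff; it turns each empty `CallbackManager` method into a fan-out over all registered callbacks. A from-scratch sketch of that pattern, where `manager.callbacks` is an assumed attribute name:

```python
from functools import wraps

def _transfer(fn):
    # Hypothetical reimplementation: forward the call, by method name, to every
    # registered callback instead of executing the empty manager method body.
    @wraps(fn)
    def wrapper(manager, *args, **kwargs):
        for callback in manager.callbacks:
            getattr(callback, fn.__name__)(*args, **kwargs)
    return wrapper
```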
@@ -137,6 +137,7 @@ class Trainer(TrainerEventTrigger):
         else:
             self.driver_name = driver.__class__.__name__
         self.device = device
+        self.optimizers = optimizers
         self.fp16 = fp16
         self.input_mapping = input_mapping
         self.output_mapping = output_mapping
@@ -440,9 +441,11 @@ class Trainer(TrainerEventTrigger):
         2. Purpose of this function
             The purpose of this function is to check whether a user-customized batch_step_fn / TrainBatchLoop invokes the callback functions correctly. More precisely, when the user has actually
-            customized ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") /
+            customized ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+            "on_after_zero_grad") /
             ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
-            "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+            "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+            "on_after_zero_grad")
             any of these callback_fns, and has also customized batch_step_fn / TrainBatchLoop, they may have forgotten to invoke
             those callback functions inside their own batch_step_fn; the job of this function is to detect exactly that;
@@ -452,10 +455,12 @@ class Trainer(TrainerEventTrigger):
            'batch_step_fn'; when False, it means 'TrainBatchLoop' is checked;
         """
         if check_mode:
-            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+                         "on_before_zero_grad", "on_after_zero_grad")
         else:
             callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
-                         "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+                         "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+                         "on_before_zero_grad", "on_after_zero_grad")
         _not_called_callback_fns = []
         for each_callback_fn in callbacks:
             if each_callback_fn in self.callback_manager.callback_fns:
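The remainder of the check (truncated here) collects hooks that user callbacks override but that a custom batch_step_fn never triggers. A from-scratch illustration of the idea, not the library's actual code:

```python
def find_forgotten_hooks(expected_hooks, user_defined_hooks, fired_hooks):
    # expected_hooks: one of the tuples built above; user_defined_hooks: hooks
    # overridden by registered callbacks; fired_hooks: hooks seen in one run.
    return [name for name in expected_hooks
            if name in user_defined_hooks and name not in fired_hooks]

# The user overrode on_after_optimizers_step, but their custom loop only fired
# the backward hooks:
print(find_forgotten_hooks(
    ("on_before_backward", "on_after_backward", "on_before_optimizers_step",
     "on_after_optimizers_step", "on_before_zero_grad", "on_after_zero_grad"),
    {"on_after_optimizers_step"},
    {"on_before_backward", "on_after_backward"},
))  # -> ['on_after_optimizers_step']
```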
@@ -699,13 +704,15 @@ class Trainer(TrainerEventTrigger):
     def zero_grad(self):
         if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
-            self.on_before_zero_grad(self.driver.optimizers)
+            self.on_before_zero_grad(self.optimizers)
             self.driver.zero_grad(self.set_grad_to_none)
+            self.on_after_zero_grad(self.optimizers)
     def step(self):
         if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
-            self.on_before_optimizer_step(self.driver.optimizers)
+            self.on_before_optimizers_step(self.optimizers)
             self.driver.step()
+            self.on_after_optimizers_step(self.optimizers)
     def move_data_to_device(self, batch):
         return self.driver.move_data_to_device(batch)
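The guard `(global_forward_batches + 1) % accumulation_steps == 0` is why the docstrings above warn that these hooks do not fire on every forward pass. A worked example of the arithmetic:

```python
accumulation_steps = 4
for global_forward_batches in range(8):
    fires = (global_forward_batches + 1) % accumulation_steps == 0
    # step()/zero_grad() and their before/after hooks run only on forward
    # batches 3, 7, 11, ...; all other batches merely accumulate gradients.
    print(global_forward_batches, "step + hooks" if fires else "accumulate only")
```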
@@ -817,3 +824,5 @@ class Trainer(TrainerEventTrigger):
@@ -68,12 +68,18 @@ class TrainerEventTrigger:
     def on_after_backward(self):
         self.callback_manager.on_after_backward(self)
-    def on_before_optimizer_step(self, optimizers):
-        self.callback_manager.on_before_optimizer_step(self, optimizers)
+    def on_before_optimizers_step(self, optimizers):
+        self.callback_manager.on_before_optimizers_step(self, optimizers)
+    def on_after_optimizers_step(self, optimizers):
+        self.callback_manager.on_after_optimizers_step(self, optimizers)
     def on_before_zero_grad(self, optimizers):
         self.callback_manager.on_before_zero_grad(self, optimizers)
+    def on_after_zero_grad(self, optimizers):
+        self.callback_manager.on_after_zero_grad(self, optimizers)
     def on_validate_begin(self):
         self.callback_manager.on_validate_begin(self)
@@ -530,14 +530,6 @@ class TorchDDPDriver(TorchDriver):
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
     def is_global_zero(self):
         return self.global_rank == 0
@@ -107,14 +107,6 @@ class TorchSingleDriver(TorchDriver):
         else:
             return self._train_step(batch)
-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-            self.grad_scaler.update()
     def validate_step(self, batch) -> Dict:
         # Because our Tester's logic is to hand all the metrics to the tester, which then controls each metric's update and compute, this method
         # should return a dict whether or not the user implements a validate_step function; which items are actually used is for each individual metric to fetch inside validate_batch_loop;
@@ -72,6 +72,14 @@ class TorchDriver(Driver):
                 p.grad.requires_grad_(False)
                 p.grad.zero_()
+    def backward(self, loss):
+        self.grad_scaler.scale(loss).backward()
+    def step(self):
+        for optimizer in self.optimizers:
+            self.grad_scaler.step(optimizer)
+            self.grad_scaler.update()
     @staticmethod
     def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         if is_train:
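With `backward`/`step` hoisted from the DDP and single-device drivers into `TorchDriver`, every torch driver now shares one GradScaler flow. A standalone PyTorch sketch of that flow for a single optimizer (model, data, and hyperparameters are made up):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # behaves as a no-op when disabled

x, y = torch.randn(8, 4), torch.randn(8, 1)
with torch.cuda.amp.autocast(enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()  # backward(): backprop through the scaled loss
scaler.step(optimizer)         # step(): unscales grads, skips the step on inf/nan
scaler.update()                # adjust the scale factor for the next iteration
optimizer.zero_grad()
```

Note that the moved `step` calls `grad_scaler.update()` once per optimizer inside the loop; PyTorch's documented multi-optimizer recipe calls `scaler.step()` for each optimizer and `scaler.update()` only once afterwards.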
@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback):
     def on_after_backward(self, trainer):
         print("on_after_backward")
-    def on_before_optimizer_step(self, trainer, optimizers):
-        print("on_before_optimizer_step")
+    def on_before_optimizers_step(self, trainer, optimizers):
+        print("on_before_optimizers_step")
+    def on_after_optimizers_step(self, trainer, optimizers):
+        print("on_after_optimizers_step")
     def on_before_zero_grad(self, trainer, optimizers):
         print("on_before_zero_grad")
+    def on_after_zero_grad(self, trainer, optimizers):
+        print("on_after_zero_grad")
     def on_validate_begin(self, trainer):
         print("on_validate_begin")