@@ -184,7 +184,7 @@ class Callback:
         """
         pass

-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
         """
         Called right before the optimizers perform their update step. This hook is not necessarily triggered after
         every forward pass; whether it actually fires depends on accumulation_steps.
@@ -194,6 +194,16 @@ class Callback:
         """
         pass

+    def on_after_optimizers_step(self, trainer, optimizers):
+        """
+        Called right after the optimizers perform their update step. This hook is not necessarily triggered after
+        every forward pass; whether it actually fires depends on accumulation_steps.
+
+        :param trainer:
+        :param optimizers: the optimizers; the value passed in when the Trainer was initialized.
+        :return:
+        """
+        pass
+
     def on_before_zero_grad(self, trainer, optimizers):
         """
         Called right before the model's gradients are zeroed. This hook is not necessarily triggered after every
         forward pass; whether it actually fires depends on accumulation_steps.
@@ -204,6 +214,16 @@ class Callback:
         """
         pass

+    def on_after_zero_grad(self, trainer, optimizers):
+        """
+        Called right after the model's gradients are zeroed. This hook is not necessarily triggered after every
+        forward pass; whether it actually fires depends on accumulation_steps.
+
+        :param trainer:
+        :param optimizers: the optimizers; the value passed in when the Trainer was initialized.
+        :return:
+        """
+        pass
+
     def on_validate_begin(self, trainer):
         """
         Called when validation is about to start. If the validation frequency is driven by a step count or a custom
         condition, this hook is invoked after on_train_batch_end
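
Taken together, these hooks give user code a symmetric before/after view of each optimization-related action. A minimal sketch of how a subclass might use the new hooks (the GradNormLogger class and its logging behavior are illustrative assumptions, not part of this patch; `optimizers` is assumed to be a list of torch optimizers):

    class GradNormLogger(Callback):
        # Logs the total L2 gradient norm right before every effective
        # optimizer update (i.e. once per accumulation_steps forward passes).
        def on_before_optimizers_step(self, trainer, optimizers):
            total_sq = 0.0
            for optimizer in optimizers:
                for group in optimizer.param_groups:
                    for p in group["params"]:
                        if p.grad is not None:
                            total_sq += float(p.grad.detach().norm(2)) ** 2
            print(f"grad norm before step: {total_sq ** 0.5:.4f}")

        def on_after_zero_grad(self, trainer, optimizers):
            # At this point gradients are guaranteed to have been cleared.
            print("gradients zeroed")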
@@ -92,8 +92,10 @@ class Events(EventEnum):
     ON_LOAD_CHECKPOINT = "on_load_checkpoint"
     ON_BEFORE_BACKWARD = "on_before_backward"
     ON_AFTER_BACKWARD = "on_after_backward"
-    ON_BEFORE_OPTIMIZER_STEP = "on_before_optimizer_step"
+    ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step"
+    ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step"
     ON_BEFORE_ZERO_GRAD = "on_before_zero_grad"
+    ON_AFTER_ZERO_GRAD = "on_after_zero_grad"
     ON_VALIDATE_BEGIN = "on_validate_begin"
     ON_VALIDATE_END = "on_validate_end"
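
Each enum value is spelled exactly like the corresponding Callback method name, which is what makes string-based dispatch possible. A minimal sketch of that pattern (a simplified stand-in for the real dispatch machinery, assuming a plain list of callback objects):

    def trigger(callbacks, event, trainer, *args):
        # Events.ON_AFTER_OPTIMIZERS_STEP.value == "on_after_optimizers_step",
        # so the enum value doubles as the attribute name to look up.
        for callback in callbacks:
            getattr(callback, event.value)(trainer, *args)

    # trigger(callbacks, Events.ON_AFTER_OPTIMIZERS_STEP, trainer, optimizers)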
@@ -278,13 +278,21 @@ class CallbackManager:
         pass

     @_transfer
-    def on_before_optimizer_step(self, trainer, optimizers):
+    def on_before_optimizers_step(self, trainer, optimizers):
+        pass
+
+    @_transfer
+    def on_after_optimizers_step(self, trainer, optimizers):
         pass

     @_transfer
     def on_before_zero_grad(self, trainer, optimizers):
         pass

+    @_transfer
+    def on_after_zero_grad(self, trainer, optimizers):
+        pass
+
     @_transfer
     def on_validate_begin(self, trainer):
         pass
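
All the method bodies here are `pass` because `_transfer` replaces each method with one that forwards the event to every registered callback, so renaming a method renames the dispatched event as well. A rough sketch of such a decorator (assuming the manager keeps its callbacks in a `self.class_callbacks`-style list; the real fastNLP implementation may differ in details):

    def _transfer(fn):
        # Replace the empty method with one that fans the call out to the
        # same-named method on every registered callback.
        def wrapper(manager, trainer, *args, **kwargs):
            for callback in manager.class_callbacks:
                getattr(callback, fn.__name__)(trainer, *args, **kwargs)
        return wrapper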
@@ -137,6 +137,7 @@ class Trainer(TrainerEventTrigger):
         else:
             self.driver_name = driver.__class__.__name__
         self.device = device
+        self.optimizers = optimizers
         self.fp16 = fp16
         self.input_mapping = input_mapping
         self.output_mapping = output_mapping
@@ -440,9 +441,11 @@ class Trainer(TrainerEventTrigger):
         2. Purpose of this function
             This function checks whether a user-customized batch_step_fn / TrainBatchLoop still invokes the callback
             functions correctly. More precisely, when the user has actually customized any of
-            ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad") /
+            ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+            "on_after_zero_grad") /
             ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
-            "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+            "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step", "on_before_zero_grad",
+            "on_after_zero_grad")
             as callback_fns while also customizing batch_step_fn / TrainBatchLoop, they may have forgotten to call
             these callback functions inside their own batch_step_fn; this function detects exactly that behavior;
@@ -452,10 +455,12 @@ class Trainer(TrainerEventTrigger):
            'batch_step_fn'; when False, it checks 'TrainBatchLoop';
         """
         if check_mode:
-            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+            callbacks = ("on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+                         "on_before_zero_grad", "on_after_zero_grad")
         else:
             callbacks = ("on_fetch_data_begin", "on_fetch_data_end", "on_train_batch_begin", "on_train_batch_end",
-                         "on_before_backward", "on_after_backward", "on_before_optimizer_step", "on_before_zero_grad")
+                         "on_before_backward", "on_after_backward", "on_before_optimizers_step", "on_after_optimizers_step",
+                         "on_before_zero_grad", "on_after_zero_grad")
         _not_called_callback_fns = []
         for each_callback_fn in callbacks:
             if each_callback_fn in self.callback_manager.callback_fns:
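
The detection itself amounts to a fired-or-not check: collect the expected hook names, run one batch through the user's custom loop, and report any hook that never fired. A condensed, hypothetical sketch of that idea (the `fired` dict is illustrative; each hook would be wrapped so that invoking it flips its entry to True):

    # After running a single batch through the user's custom loop:
    fired = {name: False for name in callbacks}  # flipped to True inside each hook
    _not_called_callback_fns = [n for n, ok in fired.items() if not ok]
    if _not_called_callback_fns:
        print(f"custom batch_step_fn never calls: {_not_called_callback_fns}")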
@@ -699,13 +704,15 @@ class Trainer(TrainerEventTrigger):
     def zero_grad(self):
         if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
-            self.on_before_zero_grad(self.driver.optimizers)
+            self.on_before_zero_grad(self.optimizers)
             self.driver.zero_grad(self.set_grad_to_none)
+            self.on_after_zero_grad(self.optimizers)

     def step(self):
         if (self.global_forward_batches + 1) % self.accumulation_steps == 0:
-            self.on_before_optimizer_step(self.driver.optimizers)
+            self.on_before_optimizers_step(self.optimizers)
             self.driver.step()
+            self.on_after_optimizers_step(self.optimizers)

     def move_data_to_device(self, batch):
         return self.driver.move_data_to_device(batch)
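
Both methods gate on `(global_forward_batches + 1) % accumulation_steps == 0`, so all four optimizer/zero-grad hooks fire only on accumulation boundaries. A tiny self-contained check of that arithmetic (with accumulation_steps = 4, the hooks fire on 0-indexed batches 3 and 7):

    accumulation_steps = 4
    for global_forward_batches in range(8):
        if (global_forward_batches + 1) % accumulation_steps == 0:
            print(f"batch {global_forward_batches}: step + zero_grad, hooks fire")
        else:
            print(f"batch {global_forward_batches}: gradients keep accumulating")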
@@ -817,3 +824,5 @@ class Trainer(TrainerEventTrigger):
@@ -68,12 +68,18 @@ class TrainerEventTrigger:
     def on_after_backward(self):
         self.callback_manager.on_after_backward(self)

-    def on_before_optimizer_step(self, optimizers):
-        self.callback_manager.on_before_optimizer_step(self, optimizers)
+    def on_before_optimizers_step(self, optimizers):
+        self.callback_manager.on_before_optimizers_step(self, optimizers)
+
+    def on_after_optimizers_step(self, optimizers):
+        self.callback_manager.on_after_optimizers_step(self, optimizers)

     def on_before_zero_grad(self, optimizers):
         self.callback_manager.on_before_zero_grad(self, optimizers)

+    def on_after_zero_grad(self, optimizers):
+        self.callback_manager.on_after_zero_grad(self, optimizers)
+
     def on_validate_begin(self):
         self.callback_manager.on_validate_begin(self)
@@ -530,14 +530,6 @@ class TorchDDPDriver(TorchDriver):
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")

-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-        self.grad_scaler.update()
-
     def is_global_zero(self):
         return self.global_rank == 0
@@ -107,14 +107,6 @@ class TorchSingleDriver(TorchDriver):
         else:
             return self._train_step(batch)

-    def backward(self, loss):
-        self.grad_scaler.scale(loss).backward()
-
-    def step(self):
-        for optimizer in self.optimizers:
-            self.grad_scaler.step(optimizer)
-        self.grad_scaler.update()
-
     def validate_step(self, batch) -> Dict:
         # Our Tester's logic is to hand every metric to the tester, which then controls each metric's update and
         # compute; so whether or not the user implements a validate_step function, it should always return a dict,
         # and which entries are actually consumed is up to each individual metric inside validate_batch_loop;
@@ -72,6 +72,14 @@ class TorchDriver(Driver):
                     p.grad.requires_grad_(False)
                     p.grad.zero_()

+    def backward(self, loss):
+        self.grad_scaler.scale(loss).backward()
+
+    def step(self):
+        for optimizer in self.optimizers:
+            self.grad_scaler.step(optimizer)
+        self.grad_scaler.update()
+
     @staticmethod
     def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         if is_train:
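
Hoisting backward/step into TorchDriver removes the duplicated copies in the single-device and DDP drivers, which now inherit this single implementation. The grad_scaler calls follow the standard torch.cuda.amp pattern, and GradScaler degrades to plain FP32 behavior when constructed with enabled=False, so one code path serves both fp16 and fp32 training. A self-contained illustration in plain PyTorch (independent of fastNLP):

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # With enabled=False, scale/step/update are effectively no-ops around
    # the usual FP32 calls, so this also runs on CPU.
    scaler = torch.cuda.amp.GradScaler(enabled=False)

    loss = model(torch.randn(8, 4)).sum()
    scaler.scale(loss).backward()   # scale the loss, then backprop
    scaler.step(optimizer)          # unscale grads, then optimizer.step()
    scaler.update()                 # adjust the scale factor once per iteration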
@@ -101,12 +101,18 @@ class RecordTrainerEventTriggerCallback(Callback):
     def on_after_backward(self, trainer):
         print("on_after_backward")

-    def on_before_optimizer_step(self, trainer, optimizers):
-        print("on_before_optimizer_step")
+    def on_before_optimizers_step(self, trainer, optimizers):
+        print("on_before_optimizers_step")
+
+    def on_after_optimizers_step(self, trainer, optimizers):
+        print("on_after_optimizers_step")

     def on_before_zero_grad(self, trainer, optimizers):
         print("on_before_zero_grad")

+    def on_after_zero_grad(self, trainer, optimizers):
+        print("on_after_zero_grad")
+
     def on_validate_begin(self, trainer):
         print("on_validate_begin")