|
|
@@ -41,10 +41,9 @@ class TorchWarmupCallback(Callback): |
|
|
|
return max((progress - 1.) / (self.warmup - 1.), 0.) |
|
|
|
|
|
|
|
def on_train_begin(self, trainer): |
|
|
|
self.t_steps = trainer.n_batches |
|
|
|
if self.warmup >1: |
|
|
|
self.warmup = self.warmup / self.t_steps |
|
|
|
self.t_steps = max(2, self.t_steps) # 不能小于2 |
|
|
|
self.warmup = self.warmup / trainer.n_batches |
|
|
|
self.t_steps = max(2, trainer.n_batches) # 不能小于2 |
|
|
|
# 防止 t_steps 不能整除 accumulation_steps |
|
|
|
self.t_steps = math.ceil(self.t_steps/trainer.accumulation_steps) * trainer.accumulation_steps |
|
|
|
# 获取param_group的初始learning rate |
|
|
|