@@ -259,6 +259,8 @@ class TorchDriver(Driver):
             grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
             if not isinstance(self.grad_scaler, DummyGradScaler):
                 self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+                self.auto_cast = torch.cuda.amp.autocast
+                self.fp16 = True
                 logger.debug("Load grad_scaler state dict...")
         elif not isinstance(self.grad_scaler, DummyGradScaler):
             logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
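The two added lines make a resumed fp16 run re-enable autocast and the `fp16` flag rather than only restoring the scaler's state. For context, here is a minimal sketch of the grad-scaler save/load round trip this hunk relies on, assuming a plain `torch.cuda.amp.GradScaler` (the `DummyGradScaler` guard above appears to be the driver's stand-in when fp16 is off; all names besides `GradScaler` here are illustrative):

```python
# Illustrative sketch, not the driver's actual checkpoint code.
import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# On save: the scaler's loss-scale state (current scale, growth tracker, ...)
# is serialized alongside the rest of the checkpoint.
states = {'grad_scaler_state_dict': scaler.state_dict()}

# On load: pop the entry and restore it, so AMP resumes with the scale it had
# already reached instead of re-warming from the default initial scale.
resumed = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
resumed.load_state_dict(states.pop('grad_scaler_state_dict'))
```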