diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
index 606bec03..39fed874 100644
--- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py
+++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py
@@ -317,17 +317,14 @@ class PaddleDriver(Driver):
         logger.debug("Load model...")
 
         # 3. 加载fp16的状态;
-        if "grad_scaler_state_dict" in states:
-            grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
-            if isinstance(self.grad_scaler, DummyGradScaler):
-                self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
-                self.grad_scaler = _grad_scaler()
-                self.fp16 = True
-            self.grad_scaler.load_state_dict(grad_scaler_state_dict)
-            logger.debug("Load grad_scaler state dict...")
-        elif not isinstance(self.grad_scaler, DummyGradScaler):
-            logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
-                                     f"the training process may be unstable.")
+        grad_scaler_state_dict = states.pop("grad_scaler_state_dict", None)
+        if self.fp16:
+            if grad_scaler_state_dict:
+                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+                logger.debug("Load grad_scaler state dict...")
+            else:
+                logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
+                                         f"the training process may be unstable.")
 
         # 4. 恢复 sampler 的状态;
         dataloader_args = self.get_dataloader_args(dataloader)
diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py
index ad680dcb..5f90ed12 100644
--- a/tests/core/drivers/paddle_driver/test_fleet.py
+++ b/tests/core/drivers/paddle_driver/test_fleet.py
@@ -661,7 +661,7 @@ class TestSaveLoad:
 
             # 3. 检查 fp16 是否被加载
             if fp16:
-                assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
+                assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
 
             # 4. 检查 model 的参数是否正确
             # 5. 检查 batch_idx
@@ -771,7 +771,7 @@ class TestSaveLoad:
                 assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
             # 3. 检查 fp16 是否被加载
             if fp16:
-                assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
+                assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
 
             # 4. 检查 model 的参数是否正确
             # 5. 检查 batch_idx
diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py
index 67ea1b42..9b7a8560 100644
--- a/tests/core/drivers/paddle_driver/test_single_device.py
+++ b/tests/core/drivers/paddle_driver/test_single_device.py
@@ -632,7 +632,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
 
         # 3. 检查 fp16 是否被加载
         if fp16:
-            assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
+            assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
 
         # 4. 检查 model 的参数是否正确
 
@@ -720,7 +720,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
 
         # 3. 检查 fp16 是否被加载
        if fp16:
-            assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
+            assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
 
         # 4. 检查 model 的参数是否正确
         # 5. 检查 batch_idx
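For reference, a minimal standalone sketch of the restore rule the patch introduces: the grad-scaler state is popped from the checkpoint unconditionally, but it is only applied when the *resuming* run itself was created with fp16=True; a checkpoint saved with fp16 no longer switches fp16 on implicitly, which is why the tests now assert the scaler is not a `paddle.amp.GradScaler` when the second driver is built without fp16. The names `restore_grad_scaler`, the module-level `logger`, and the `DummyGradScaler` stub below are illustrative stand-ins, not fastNLP's actual API; only the control flow mirrors the patched `PaddleDriver` code.

```python
import logging

logger = logging.getLogger("fp16_restore_sketch")


class DummyGradScaler:
    """Stand-in for the placeholder scaler used when fp16 is disabled."""

    def load_state_dict(self, state_dict):
        # Never called in this sketch: the dummy scaler carries no state.
        raise RuntimeError("DummyGradScaler holds no state to restore")


def restore_grad_scaler(driver, states: dict, folder: str) -> None:
    """Illustrative mirror of the control flow added by the patch above."""
    # Pop the entry unconditionally so it never leaks into later state handling,
    # even when the current run does not use fp16.
    grad_scaler_state_dict = states.pop("grad_scaler_state_dict", None)
    if driver.fp16:
        if grad_scaler_state_dict:
            # fp16 run resumed from an fp16 checkpoint: restore the scaler state.
            driver.grad_scaler.load_state_dict(grad_scaler_state_dict)
        else:
            # fp16 run resumed from a non-fp16 checkpoint: warn, keep the fresh scaler.
            logger.warning(f"Checkpoint {folder} is not trained with fp16=True; "
                           f"resuming with fp16=True may be unstable.")
    # If the checkpoint holds scaler state but driver.fp16 is False, the state is
    # silently discarded instead of turning fp16 on, as the old code did.
```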