From 06551964a35a73096b450c412f3d1b760b784817 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 20 May 2022 07:27:48 +0000 Subject: [PATCH] =?UTF-8?q?paddle=20=E4=B8=8D=E5=86=8D=E4=B8=BA=E9=9D=9Efp?= =?UTF-8?q?16=E7=9A=84driver=E5=8A=A0=E8=BD=BDfp16=E7=8A=B6=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/paddle_driver/paddle_driver.py | 20 +++++++++---------- .../core/drivers/paddle_driver/test_fleet.py | 4 ++-- .../paddle_driver/test_single_device.py | 4 ++-- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 606bec03..39fed874 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -4,6 +4,7 @@ from typing import Union, Optional, Dict, Any from pathlib import Path from functools import partial from dataclasses import dataclass +from jittor import grad import numpy as np @@ -317,17 +318,14 @@ class PaddleDriver(Driver): logger.debug("Load model...") # 3. 加载fp16的状态; - if "grad_scaler_state_dict" in states: - grad_scaler_state_dict = states.pop("grad_scaler_state_dict") - if isinstance(self.grad_scaler, DummyGradScaler): - self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False) - self.grad_scaler = _grad_scaler() - self.fp16 = True - self.grad_scaler.load_state_dict(grad_scaler_state_dict) - logger.debug("Load grad_scaler state dict...") - elif not isinstance(self.grad_scaler, DummyGradScaler): - logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " - f"the training process may be unstable.") + grad_scaler_state_dict = states.pop("grad_scaler_state_dict", None) + if self.fp16: + if grad_scaler_state_dict: + self.grad_scaler.load_state_dict(grad_scaler_state_dict) + logger.debug("Load grad_scaler state dict...") + else: + logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " + f"the training process may be unstable.") # 4. 恢复 sampler 的状态; dataloader_args = self.get_dataloader_args(dataloader) diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index ad680dcb..5f90ed12 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -661,7 +661,7 @@ class TestSaveLoad: # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) + assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx @@ -771,7 +771,7 @@ class TestSaveLoad: assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) + assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 67ea1b42..9b7a8560 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -632,7 +632,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler) # 4. 检查 model 的参数是否正确 @@ -720,7 +720,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler) + assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 检查 batch_idx