Browse Source

paddle 不再为非fp16的driver加载fp16状态

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
06551964a3
3 changed files with 13 additions and 15 deletions
  1. +9
    -11
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  2. +2
    -2
      tests/core/drivers/paddle_driver/test_fleet.py
  3. +2
    -2
      tests/core/drivers/paddle_driver/test_single_device.py

+ 9
- 11
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -4,6 +4,7 @@ from typing import Union, Optional, Dict, Any
from pathlib import Path
from functools import partial
from dataclasses import dataclass
+from jittor import grad

import numpy as np

@@ -317,17 +318,14 @@ class PaddleDriver(Driver):
logger.debug("Load model...")

 # 3. 加载fp16的状态;
-if "grad_scaler_state_dict" in states:
-    grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
-    if isinstance(self.grad_scaler, DummyGradScaler):
-        self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
-        self.grad_scaler = _grad_scaler()
-        self.fp16 = True
-    self.grad_scaler.load_state_dict(grad_scaler_state_dict)
-    logger.debug("Load grad_scaler state dict...")
-elif not isinstance(self.grad_scaler, DummyGradScaler):
-    logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
-                             f"the training process may be unstable.")
+grad_scaler_state_dict = states.pop("grad_scaler_state_dict", None)
+if self.fp16:
+    if grad_scaler_state_dict:
+        self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+        logger.debug("Load grad_scaler state dict...")
+    else:
+        logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
+                                 f"the training process may be unstable.")

# 4. 恢复 sampler 的状态;
dataloader_args = self.get_dataloader_args(dataloader)


+ 2
- 2
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -661,7 +661,7 @@ class TestSaveLoad:

# 3. 检查 fp16 是否被加载
if fp16:
-    assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
+    assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)

# 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx
@@ -771,7 +771,7 @@ class TestSaveLoad:
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
# 3. 检查 fp16 是否被加载
if fp16:
-    assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
+    assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)

# 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx


+ 2
- 2
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -632,7 +632,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):

# 3. 检查 fp16 是否被加载
if fp16:
-    assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
+    assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler)


# 4. 检查 model 的参数是否正确
@@ -720,7 +720,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):

# 3. 检查 fp16 是否被加载
if fp16:
-    assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
+    assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler)

# 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx


Loading…
Cancel
Save