Browse Source

paddle 不再为非fp16的driver加载fp16状态

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
06551964a3
3 changed files with 13 additions and 15 deletions
  1. +9
    -11
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  2. +2
    -2
      tests/core/drivers/paddle_driver/test_fleet.py
  3. +2
    -2
      tests/core/drivers/paddle_driver/test_single_device.py

+ 9
- 11
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -4,6 +4,7 @@ from typing import Union, Optional, Dict, Any
from pathlib import Path from pathlib import Path
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
from jittor import grad


import numpy as np import numpy as np


@@ -317,17 +318,14 @@ class PaddleDriver(Driver):
logger.debug("Load model...") logger.debug("Load model...")


# 3. 加载fp16的状态; # 3. 加载fp16的状态;
if "grad_scaler_state_dict" in states:
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
if isinstance(self.grad_scaler, DummyGradScaler):
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
self.grad_scaler = _grad_scaler()
self.fp16 = True
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")
grad_scaler_state_dict = states.pop("grad_scaler_state_dict", None)
if self.fp16:
if grad_scaler_state_dict:
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
else:
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")


# 4. 恢复 sampler 的状态; # 4. 恢复 sampler 的状态;
dataloader_args = self.get_dataloader_args(dataloader) dataloader_args = self.get_dataloader_args(dataloader)


+ 2
- 2
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -661,7 +661,7 @@ class TestSaveLoad:


# 3. 检查 fp16 是否被加载 # 3. 检查 fp16 是否被加载
if fp16: if fp16:
assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)


# 4. 检查 model 的参数是否正确 # 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx # 5. 检查 batch_idx
@@ -771,7 +771,7 @@ class TestSaveLoad:
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
# 3. 检查 fp16 是否被加载 # 3. 检查 fp16 是否被加载
if fp16: if fp16:
assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)
assert not isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)


# 4. 检查 model 的参数是否正确 # 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx # 5. 检查 batch_idx


+ 2
- 2
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -632,7 +632,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):


# 3. 检查 fp16 是否被加载 # 3. 检查 fp16 是否被加载
if fp16: if fp16:
assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler)




# 4. 检查 model 的参数是否正确 # 4. 检查 model 的参数是否正确
@@ -720,7 +720,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):


# 3. 检查 fp16 是否被加载 # 3. 检查 fp16 是否被加载
if fp16: if fp16:
assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
assert not isinstance(driver2.grad_scaler, paddle.amp.GradScaler)


# 4. 检查 model 的参数是否正确 # 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx # 5. 检查 batch_idx


Loading…
Cancel
Save