
DeepSpeed save/load functionality

tags/v1.0.0alpha
x54-729 committed 3 years ago
parent commit 7023ea550c
4 changed files with 106 additions and 75 deletions
  1. fastNLP/core/drivers/torch_driver/deepspeed.py (+13, -9)
  2. fastNLP/core/drivers/torch_driver/utils.py (+24, -1)
  3. tests/core/drivers/torch_driver/test_deepspeed.py (+67, -64)
  4. tests/pytest.ini (+2, -1)

fastNLP/core/drivers/torch_driver/deepspeed.py (+13, -9)

@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Union, Dict, List
 from .torch_driver import TorchDriver
 from .ddp import TorchDDPDriver
-from .utils import _create_default_config, _DDPWrappingModel
+from .utils import _create_default_config, _DeepSpeedWrappingModel
 from fastNLP.core.utils import nullcontext
 from fastNLP.core.log import logger
 from fastNLP.envs import(
@@ -14,6 +14,7 @@ from fastNLP.envs import(
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED

 if _NEED_IMPORT_TORCH:
+    import pytorch_lightning
     import torch
     import torch.distributed as dist
@@ -35,8 +36,8 @@ class DeepSpeedDriver(TorchDDPDriver):
         strategy= "deepspeed",
         **kwargs
     ):
-        assert _NEED_IMPORT_DEEPSPEED, "deepspeed is not imported."
-        assert not dist.is_initialized(), "DeepSpeedDriver does not support initialize distributed by user."
+        assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported."
+        # assert not dist.is_initialized(), "DeepSpeedDriver does not support initialize distributed by user."
         TorchDriver.__init__(self, model=model, fp16=False, **kwargs)
         self.fp16 = fp16

@@ -88,7 +89,7 @@ class DeepSpeedDriver(TorchDDPDriver):
         # get the batch_size in order to set the train_micro_batch_size_per_gpu parameter
         train_dl = kwargs.get("train_dataloader", None)
         if train_dl is not None:
-            self.train_micro_batch_size = self.get_dataloader_args(train_dl)
+            self.train_micro_batch_size = self.get_dataloader_args(train_dl).batch_size
         else:
             logger.warn("No `train_dataloader` found, and we will set `train_micro_batch_size_per_gpu`"
                         "to 1 for deepspeed configuration.")
@@ -166,7 +167,7 @@ class DeepSpeedDriver(TorchDDPDriver):
         # set up deepspeed
         if not isinstance(self.model, deepspeed.DeepSpeedEngine):
-            model=_DDPWrappingModel(self.model)
+            model=_DeepSpeedWrappingModel(self.model, self.fp16)
             model_parameters = filter(lambda p: p.requires_grad, model.parameters())
             self.model, ds_optimizer, _, _ = deepspeed.initialize(
                 model=model,
@@ -279,7 +280,7 @@ class DeepSpeedDriver(TorchDDPDriver):
         :return:
         """
         # the deepspeed engine requires save_checkpoint to be called on every rank, so the rank_zero_call decorator was removed
-        if self.zero_stage_3:
+        if self.stage_3:
             logger.rank_zero_warning(
                 "When saving the DeepSpeed Stage 3 checkpoint, "
                 "each worker will save a shard of the checkpoint within a directory. "
@@ -310,7 +311,8 @@ class DeepSpeedDriver(TorchDDPDriver):
     def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
         # the deepspeed engine requires save_checkpoint to be called on every rank, so the rank_zero_call decorator was removed
        # 1. save the sampler state
-        sampler_state_dict = self.get_sampler_state_dict()
+        num_consumed_batches = states.pop('num_consumed_batches')
+        states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)

         # 2. save the model state;
         if not should_save_model:
@@ -318,7 +320,7 @@ class DeepSpeedDriver(TorchDDPDriver):
                         "so we will still save the model for you.")

         self.model.save_checkpoint(Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME),
-                                   client_state=sampler_state_dict)
+                                   client_state=states)

     def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
         # 1. load the model state;
@@ -330,7 +332,9 @@ class DeepSpeedDriver(TorchDDPDriver):
             raise RuntimeError(f"Failed to load checkpoint from path: {str(folder)}")

         # 2. restore the sampler state
-        states = self.load_sampler_state_dict(states)
+        sampler_states = states.pop('sampler_states')
+        states_ret = self.load_sampler_state(dataloader, sampler_states)
+        states.update(states_ret)

         return states

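The save/load pairing above round-trips everything through DeepSpeed's client_state: save_checkpoint now folds the sampler state into the states dict before handing it to the engine, and load_checkpoint pops it back out. A minimal sketch of that flow, assuming an already-initialized deepspeed.DeepSpeedEngine named engine; the sampler-state dict here is only a placeholder for what fastNLP's get_sampler_state / load_sampler_state produce and consume:

    # Hedged sketch, not fastNLP code: the client_state round-trip used above.
    # `engine` is assumed to be an initialized deepspeed.DeepSpeedEngine.
    states = {"epoch": 1, "num_consumed_batches": 4}

    # save: fold the sampler state into `states`; every rank must call
    # save_checkpoint, which is why the rank_zero_call decorator was dropped.
    num_consumed_batches = states.pop("num_consumed_batches")
    states["sampler_states"] = {"num_consumed_samples": num_consumed_batches * 4}  # placeholder
    engine.save_checkpoint("checkpoint_dir", client_state=states)

    # load: load_checkpoint returns (load_path, client_state); the sampler
    # state is popped back out and the remaining keys are returned as-is.
    _, states = engine.load_checkpoint("checkpoint_dir")
    sampler_states = states.pop("sampler_states")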
fastNLP/core/drivers/torch_driver/utils.py (+24, -1)

@@ -15,7 +15,7 @@ from fastNLP.envs import (
     FASTNLP_GLOBAL_SEED,
 )
 from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
-from fastNLP.core.utils import auto_param_call
+from fastNLP.core.utils import auto_param_call, apply_to_collection
 from fastNLP.core.log import logger

 if _NEED_IMPORT_TORCH:
@@ -107,6 +107,29 @@ class _DDPWrappingModel(Module):
         else:
             return fn(batch)

+
+class _DeepSpeedWrappingModel(_DDPWrappingModel):
+    """
+    Inherits from ``_DDPWrappingModel``; the difference is that float data is converted to float16 before forward is run.
+    """
+
+    def __init__(self, model: Module, fp16):
+        super(_DeepSpeedWrappingModel, self).__init__(model)
+        self.fp16 = fp16
+
+    def forward(self, batch, **kwargs):
+        if self.fp16:
+            batch = self._move_float_tensors_to_half(batch)
+
+        return super().forward(batch, **kwargs)
+
+    @staticmethod
+    def batch_to(data):
+        return data.half()
+
+    def _move_float_tensors_to_half(self, batch: Any):
+        batch = apply_to_collection(batch, (torch.FloatTensor, torch.cuda.FloatTensor), function=self.batch_to)
+        return batch
+
+
 class DummyGradScaler:
     """

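_DeepSpeedWrappingModel leans on apply_to_collection to walk nested batch containers and halve only the float tensors. A self-contained sketch of that behavior; the recursive helper below is a simplified stand-in for fastNLP.core.utils.apply_to_collection, not its actual implementation:

    import torch

    def apply_to_collection(data, dtype, function):
        # convert matching leaves, recursing through dicts, lists and tuples
        if isinstance(data, dtype):
            return function(data)
        if isinstance(data, dict):
            return {k: apply_to_collection(v, dtype, function) for k, v in data.items()}
        if isinstance(data, (list, tuple)):
            return type(data)(apply_to_collection(v, dtype, function) for v in data)
        return data

    batch = {"x": torch.rand(2, 3), "y": torch.tensor([0, 1])}
    half = apply_to_collection(batch, torch.FloatTensor, lambda t: t.half())
    assert half["x"].dtype == torch.float16  # float input converted
    assert half["y"].dtype == torch.int64    # integer labels left untouched

This is why enabling fp16 only affects float tensors: integer inputs such as token ids or labels pass through unchanged.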
tests/core/drivers/torch_driver/test_deepspeed.py (+67, -64)

@@ -1,33 +1,30 @@
 import os
-from pathlib import Path

 import pytest
+from pathlib import Path

 from fastNLP.core.drivers.torch_driver.deepspeed import DeepSpeedDriver
 from fastNLP.core.samplers import (
     RandomSampler,
-    UnrepeatedSampler,
     BucketedBatchSampler,
     UnrepeatedRandomSampler,
-    UnrepeatedSequentialSampler,
 )
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
-from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
+from tests.helpers.datasets.torch_data import TorchNormalXYDataset
 from tests.helpers.utils import magic_argv_env_context
 from fastNLP.envs.distributed import rank_zero_rm
 from fastNLP import logger

 from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED

 if _NEED_IMPORT_TORCH:
     import torch
     import torch.distributed as dist
-    from torch.utils.data import DataLoader, BatchSampler
+    from torch.utils.data import DataLoader
 if _NEED_IMPORT_DEEPSPEED:
     import deepspeed

-def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all"):
+def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_proc="all", train_dataloader=None):
     torch_model = TorchNormalModel_Classification_1(labels, features)
     torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
     device = [torch.device(i) for i in device]
@@ -35,7 +32,8 @@ def generate_driver(labels, features, device=[0,1], fp16=False, output_from_new_
         model=torch_model,
         parallel_device=device,
         fp16=fp16,
-        output_from_new_proc=output_from_new_proc
+        output_from_new_proc=output_from_new_proc,
+        train_dataloader=train_dataloader
     )
     driver.set_optimizers(torch_opt)
     driver.setup()
@@ -77,33 +75,33 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=

 ############################################################################
 #
-# Test some functions of TorchDDPDriver
+# Test some functions of TorchDeepSpeedDriver
 #
 ############################################################################

-@pytest.mark.torch
-@magic_argv_env_context
-def test_multi_drivers():
-    """
-    Test the case where multiple TorchDDPDriver instances are used.
-    """
-    generate_driver(10, 10)
-    generate_driver(20, 10)
+# @pytest.mark.deepspeed
+# @magic_argv_env_context
+# def test_multi_drivers():
+#     """
+#     Test the case where multiple TorchDeepSpeedDriver instances are used.
+#     """
+#     generate_driver(10, 10)
+#     generate_driver(20, 10)

-    with pytest.raises(RuntimeError):
-        # different device settings should raise an error
-        generate_driver(20, 3, device=[0,1,2])
-        assert False
-    dist.barrier()
+#     with pytest.raises(RuntimeError):
+#         # different device settings should raise an error
+#         generate_driver(20, 3, device=[0,1,2])
+#         assert False
+#     dist.barrier()

-    if dist.is_initialized():
-        dist.destroy_process_group()
+#     if dist.is_initialized():
+#         dist.destroy_process_group()

 @magic_argv_env_context
 def test_multi_optimizers():
     torch_model = TorchNormalModel_Classification_1(10, 10)
     torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
-    device = [torch.device(i) for i in device]
+    device = [torch.device(i) for i in [0, 1]]
     driver = DeepSpeedDriver(
         model=torch_model,
         parallel_device=device,
@@ -112,57 +110,59 @@ def test_multi_optimizers():
     with pytest.raises(ValueError):
         driver.setup()

-    if dist.is_initialized():
-        dist.destroy_process_group()
+    # if dist.is_initialized():
+    #     dist.destroy_process_group()

-@pytest.mark.torch
+@pytest.mark.deepspeed
 class TestDeepSpeedDriverFunction:
     """
     Test class for some simple TorchDeepSpeedDriver functions; it mostly checks that they run and that there are no import errors
     """
+    @classmethod
+    def setup_class(cls):
+        cls.driver = generate_driver(10, 10)

     @magic_argv_env_context
     def test_simple_functions(self):
         """
         Simple tests of several functions
         """
-        driver = generate_driver(10, 10)

         """
         Test the move_data_to_device function. It only calls torch_move_data_to_device, whose test cases are in
         tests/core/utils/test_torch_utils.py, so they are not repeated here.
         """
-        driver.move_data_to_device(torch.rand((32, 64)))
+        self.driver.move_data_to_device(torch.rand((32, 64)))
         dist.barrier()

         """
         Test the is_distributed function
         """
-        assert driver.is_distributed() == True
+        assert self.driver.is_distributed() == True
         dist.barrier()

         """
         Test the get_no_sync_context function
         """
-        res = driver.get_model_no_sync_context()
+        res = self.driver.get_model_no_sync_context()
         dist.barrier()

         """
         Test the is_global_zero function
         """
-        driver.is_global_zero()
+        self.driver.is_global_zero()
         dist.barrier()

         """
         Test the unwrap_model function
         """
-        driver.unwrap_model()
+        self.driver.unwrap_model()
         dist.barrier()

         """
         Test the get_local_rank function
         """
-        driver.get_local_rank()
+        self.driver.get_local_rank()
         dist.barrier()

         """
@@ -170,9 +170,9 @@ class TestDeepSpeedDriverFunction:
         Detailed tests are completed in test_dist_utils.py
         """
         obj = {
-            "rank": driver.global_rank
+            "rank": self.driver.global_rank
         }
-        obj_list = driver.all_gather(obj, group=None)
+        obj_list = self.driver.all_gather(obj, group=None)
         for i, res in enumerate(obj_list):
             assert res["rank"] == i

@@ -180,28 +180,32 @@ class TestDeepSpeedDriverFunction:
         Test the broadcast_object function
         Detailed tests are completed in test_dist_utils.py
         """
-        if driver.global_rank == 0:
+        if self.driver.global_rank == 0:
             obj = {
-                "rank": driver.global_rank
+                "rank": self.driver.global_rank
             }
         else:
             obj = None
-        res = driver.broadcast_object(obj, src=0)
+        res = self.driver.broadcast_object(obj, src=0)
         assert res["rank"] == 0

-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()

 ############################################################################
 #
 # Test save and load related functionality
 #
 ############################################################################
-@pytest.mark.torch
+@pytest.mark.deepspeed
 class TestSaveLoad:
     """
     Test the behavior of save and load related functions in the multi-GPU case
     """
+    @classmethod
+    def setup_class(cls):
+        # an error is raised if the setup is not done here
+        cls.driver = generate_driver(10, 10, device=[0,1])

     def setup_method(self):
         self.dataset = TorchNormalXYDataset(100)
@@ -216,7 +220,8 @@ class TestSaveLoad:
             path = "model"

             dataloader = DataLoader(self.dataset, batch_size=2)
-            driver1, driver2 = generate_driver(20, 1), generate_driver(20, 1)
+            driver1, driver2 = generate_driver(20, 1, train_dataloader=dataloader), \
+                generate_driver(20, 1, train_dataloader=dataloader)

             driver1.save_model(path, only_state_dict)
@@ -244,8 +249,8 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)

-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -260,8 +265,6 @@ class TestSaveLoad:
             path = "model.ckp"
             num_replicas = len(device)

-            driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16), \
-                generate_driver(20, 1, device=device, fp16=False)
             dataloader = dataloader_with_bucketedbatchsampler(
                 self.dataset,
                 length=[10 for i in range(len(self.dataset))],
@@ -270,11 +273,13 @@ class TestSaveLoad:
                 drop_last=False
             )
             dataloader.batch_sampler.set_distributed(
-                num_replicas=driver1.world_size,
-                rank=driver1.global_rank,
-                pad=True
+                num_replicas=int(os.getenv("WORLD_SIZE", "1")),
+                rank=int(os.getenv("RANK", "0")),
+                pad=True,
             )
             num_consumed_batches = 4
+            driver1, driver2 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader), \
+                generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader)

             already_seen_x_set = set()
             already_seen_y_set = set()
@@ -323,10 +328,6 @@ class TestSaveLoad:
             assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
             assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas

-            # 3. check that fp16 was loaded
-            if fp16:
-                assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)
-
             # 4. check that the model parameters are correct
             # 5. check batch_idx
             start_batch = load_states.pop('batch_idx_in_epoch')
@@ -338,6 +339,7 @@ class TestSaveLoad:

                 left_x_batches.update(batch["x"].reshape(-1, ).tolist())
                 left_y_batches.update(batch["y"].reshape(-1, ).tolist())
+                batch = driver1.move_data_to_device(batch)
                 res1 = driver1.model(
                     batch,
                     fastnlp_fn=driver1.model.module.model.evaluate_step,
@@ -361,8 +363,8 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)

-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -378,16 +380,16 @@ class TestSaveLoad:

         num_replicas = len(device)

-        driver1 = generate_driver(20, 1, device=device, fp16=fp16)
-        driver2 = generate_driver(20, 1, device=device, fp16=False)
-
         dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
         dataloader.batch_sampler.sampler.set_distributed(
-            num_replicas=driver1.world_size,
-            rank=driver1.global_rank,
+            num_replicas=int(os.getenv("WORLD_SIZE", "1")),
+            rank=int(os.getenv("RANK", "0")),
             pad=True
         )
         num_consumed_batches = 4
+        driver1 = generate_driver(20, 1, device=device, fp16=fp16, train_dataloader=dataloader)
+        driver2 = generate_driver(20, 1, device=device, fp16=False, train_dataloader=dataloader)

         already_seen_x_set = set()
         already_seen_y_set = set()
@@ -448,6 +450,7 @@ class TestSaveLoad:

                 left_x_batches.update(batch["x"].reshape(-1, ).tolist())
                 left_y_batches.update(batch["y"].reshape(-1, ).tolist())
+                batch = driver1.move_data_to_device(batch)
                 res1 = driver1.model(
                     batch,
                     fastnlp_fn=driver1.model.module.model.evaluate_step,
@@ -471,5 +474,5 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)

-        if dist.is_initialized():
-            dist.destroy_process_group()
+        # if dist.is_initialized():
+        #     dist.destroy_process_group()

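A note on the reordering in these tests: the drivers are now constructed after the dataloader and receive it via train_dataloader, because DeepSpeedDriver reads the batch size off that loader to fill in train_micro_batch_size_per_gpu (see the get_dataloader_args(train_dl).batch_size hunk above). A minimal illustration of that derivation, with TensorDataset standing in for the test datasets:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.rand(100, 1))
    dataloader = DataLoader(dataset, batch_size=2)

    # corresponds to self.get_dataloader_args(train_dl).batch_size in the driver;
    # the value feeds the deepspeed config key train_micro_batch_size_per_gpu
    ds_config = {"train_micro_batch_size_per_gpu": dataloader.batch_size}
    assert ds_config["train_micro_batch_size_per_gpu"] == 2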
tests/pytest.ini (+2, -1)

@@ -5,4 +5,5 @@ markers =
     paddledist
     jittor
     torchpaddle
-    torchjittor
+    torchjittor
+    deepspeed

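With the marker registered, the new tests can be selected on their own via pytest's marker filter, e.g. `pytest -m deepspeed tests/core/drivers/torch_driver/test_deepspeed.py` (this assumes a machine with at least two GPUs and deepspeed installed, since generate_driver defaults to device=[0,1]).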