
Added TorchFSDPDriver; adjusted some implementation details in ddp; removed the rank_zero_call decorator from topk_saver

tags/v1.0.0alpha
YWMditto 2 years ago · parent commit 0506fc2fcb
10 changed files with 756 additions and 20 deletions
1. +2 -2 fastNLP/core/callbacks/topk_saver.py
2. +1 -1 fastNLP/core/drivers/choose_driver.py
3. +12 -6 fastNLP/core/drivers/torch_driver/ddp.py
4. +11 -2 fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
5. +1 -1 fastNLP/core/drivers/torch_driver/torch_driver.py
6. +341 -0 fastNLP/core/drivers/torch_driver/torch_fsdp.py
7. +1 -0 fastNLP/envs/imports.py
8. +1 -1 tests/core/callbacks/test_checkpoint_callback_torch.py
9. +7 -7 tests/core/controllers/test_trainer_w_evaluator_torch.py
10. +379 -0 tests/core/drivers/torch_driver/test_fsdp.py
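A minimal usage sketch of the new driver, pieced together from the test added in this commit (the model, optimizer, dataloader, and metric values are placeholders taken from tests/core/drivers/torch_driver/test_fsdp.py):

    from fastNLP.core.controllers.trainer import Trainer

    # "torch_fsdp" selects the new TorchFSDPDriver; device must name at least one gpu
    trainer = Trainer(
        model=model,
        driver="torch_fsdp",
        device=[0, 1],
        optimizers=optimizer,
        train_dataloader=train_dataloader,
        evaluate_dataloaders=val_dataloader,
        metrics={"acc": Accuracy()},
        n_epochs=3,
    )
    trainer.run()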

+2 -2 fastNLP/core/callbacks/topk_saver.py

@@ -51,7 +51,6 @@ class Saver:

self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])

@rank_zero_call
def save(self, trainer, folder_name):
"""
The function that performs the save; the data is saved under::
@@ -66,6 +65,7 @@ class Saver:
"""
folder = self.timestamp_path.joinpath(folder_name)
folder.mkdir(parents=True, exist_ok=True)

save_fn = getattr(trainer, self.save_fn_name)
save_fn(
folder=folder,
@@ -217,7 +217,7 @@ class TopkSaver(ResultsMonitor, Saver):
self.topk_queue = TopkQueue(topk)
self.save_evaluate_results = save_evaluate_results

@rank_zero_call
# note: ``@rank_zero_call`` is removed here to support torch_fsdp;
def save_topk(self, trainer, results: Dict) -> Optional[str]:
"""
Decides whether to save based on whether ``results`` satisfies the topk settings; if a save happens, the folder saved to is returned. If it returns ``None``, the topk condition was not met this time


+1 -1 fastNLP/core/drivers/choose_driver.py

@@ -30,7 +30,7 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
else:
raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.")

if driver in {"torch", "fairscale", "deepspeed"}:
if driver in {"torch", "fairscale", "deepspeed", "torch_fsdp"}:
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
return initialize_torch_driver(driver, device, model, **kwargs)
elif driver in {"jittor"}:
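A rough sketch of the call site this dispatch now serves (hypothetical variable names; `choose_driver` is the function shown above):

    from fastNLP.core.drivers.choose_driver import choose_driver

    # "torch_fsdp" is now routed to initialize_torch_driver together with the
    # other torch-family driver strings:
    driver = choose_driver(model, "torch_fsdp", device=[0, 1])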


+12 -6 fastNLP/core/drivers/torch_driver/ddp.py

@@ -309,9 +309,9 @@ class TorchDDPDriver(TorchDriver):
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
self.global_rank = 0

self._ddp_kwargs = self._torch_kwargs.get("ddp_kwargs", {})
check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__)
if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None:
self._fsdp_kwargs = self._torch_kwargs.get("ddp_kwargs", {})
check_user_specific_params(self._fsdp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__)
if len(self.model._buffers) != 0 and self._fsdp_kwargs.get("broadcast_buffers", None) is None:
logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set "
"'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set"
" to 'False' to avoid redundant data communication between different processes.")
@@ -381,8 +381,6 @@ class TorchDDPDriver(TorchDriver):
self.global_rank = dist.get_rank()

if not self.outside_ddp:
torch.cuda.set_device(self.model_device)
self.model.to(self.model_device)
self.configure_ddp()

self.barrier()
@@ -400,11 +398,13 @@ class TorchDDPDriver(TorchDriver):
self._pids = self.tensor_to_numeric(self._pids)

def configure_ddp(self):
torch.cuda.set_device(self.model_device)
self.model.to(self.model_device)
if not isinstance(self.model, DistributedDataParallel):
self.model = DistributedDataParallel(
# note that self.model_device here is of `torch.device` type, hence self.model_device.index;
_DDPWrappingModel(self.model), device_ids=[self.model_device.index],
**self._ddp_kwargs
**self._fsdp_kwargs
)
self._has_ddpwrapped = True

@@ -505,6 +505,12 @@ class TorchDDPDriver(TorchDriver):
raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
return fn, None
elif fn in {"train_step", "evaluate_step"}:
return model, model.forward
else:
raise RuntimeError(f"There is no `{fn}` method in your model.")


+11 -2 fastNLP/core/drivers/torch_driver/initialize_torch_driver.py

@@ -9,6 +9,7 @@ from .single_device import TorchSingleDriver
from .ddp import TorchDDPDriver
from .fairscale import FairScaleDriver
from .deepspeed import DeepSpeedDriver
from .torch_fsdp import TorchFSDPDriver
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH
from pkg_resources import parse_version
@@ -45,7 +46,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
is_pull_by_torch_run=True, **kwargs)

if driver not in {"torch", "fairscale", "deepspeed"}:
if driver not in {"torch", "fairscale", "deepspeed", "torch_fsdp"}:
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].")

_could_use_device_num = torch.cuda.device_count()
@@ -95,4 +96,12 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
return DeepSpeedDriver(model, [device], **kwargs)
else:
return DeepSpeedDriver(model, device, **kwargs)
elif driver == "torch_fsdp":
if not isinstance(device, List):
if device.type == 'cpu':
raise ValueError("You are using `torch_fsdp` driver, but your chosen `device` is 'cpu'.")
logger.warning_once("Notice you are using `torch_fsdp`, but the `device` is only one gpu.")
return TorchFSDPDriver(model, [device], **kwargs)
else:
return TorchFSDPDriver(model, device, **kwargs)
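A hedged sketch of how this new branch normalizes `device`, assuming `device` has already been converted to `torch.device` by the earlier part of this function:

    import torch

    # a single non-cpu device is wrapped into a one-element list (with a warning):
    driver = initialize_torch_driver("torch_fsdp", torch.device("cuda:0"), model)

    # a list of devices is passed through unchanged:
    driver = initialize_torch_driver(
        "torch_fsdp", [torch.device("cuda:0"), torch.device("cuda:1")], model)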

+1 -1 fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -27,7 +27,7 @@ from .utils import optimizer_state_to_device
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler


+341 -0 fastNLP/core/drivers/torch_driver/torch_fsdp.py

@@ -0,0 +1,341 @@



from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp import FullyShardedDataParallel, StateDictType, FullStateDictConfig, OptimStateKeyType

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from typing import Optional, Union, List, Dict, Mapping
from pathlib import Path

from .ddp import TorchDDPDriver
from fastNLP.core.drivers.torch_driver.utils import (
_DDPWrappingModel,
)

from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, \
FASTNLP_GLOBAL_RANK, rank_zero_call
from fastNLP.core.drivers.torch_driver.utils import DummyGradScaler
from fastNLP.core.log import logger
from fastNLP.core.utils import check_user_specific_params
from .utils import optimizer_state_to_device


"""
References:
1. https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/
2. https://pytorch.org/docs/stable/fsdp.html?highlight=fsdp
3. https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
4. https://engineering.fb.com/2021/07/15/open-source/fsdp/
"""

class TorchFSDPDriver(TorchDDPDriver):
r"""
Driver for pytorch's own implementation of fully sharded data parallel; see this document for more details:
https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict;

.. note::

``TorchFSDPDriver`` behaves largely the same as ``TorchDDPDriver``; if you are not familiar with ``TorchDDPDriver``,
you may want to read :class:`~fastNLP.core.drivers.TorchDDPDriver` first;

.. warning::

``TorchFSDPDriver`` does not support checkpoint-based resumption yet, but it does support saving and loading models;

"""

def __init__(
self,
model,
parallel_device: Optional[Union[List["torch.device"], "torch.device"]],
is_pull_by_torch_run: bool = False,
fp16: bool = False,
torch_kwargs: Dict = None,
**kwargs
):

# after adding more logic here, be careful about where this super call is placed;
super(TorchDDPDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs)

if isinstance(model, torch.nn.DataParallel):
raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchDDPDriver`, it should be "
f"`torch.nn.Module` or `torch.nn.parallel.DistributedDataParallel` type.")

# if the user initialized DDP themselves outside, it must have been launched via python -m torch.distributed.launch;
self.is_pull_by_torch_run = is_pull_by_torch_run
self.parallel_device = parallel_device
if not is_pull_by_torch_run and parallel_device is None:
raise ValueError(
"Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused "
"when your value of parameter `device` is `None` in your `Trainer` instance.")

# note that the logic in initialize_torch_driver is: if is_pull_by_torch_run, parallel_device is set directly to the gpu of the current process;
if is_pull_by_torch_run:
self.model_device = parallel_device
else:
# our model_device is always a torch.device, never a list;
self.model_device = parallel_device[self.local_rank]

# if the user initialized FSDP themselves outside;
self.outside_ddp = False
if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
"fastnlp_torch_launch_not_ddp" not in os.environ:
# if the user initialized the distributed setup themselves outside, we require the model passed in to be one already wrapped by FullyShardedDataParallel;
if not isinstance(model, FullyShardedDataParallel):
raise RuntimeError(
"It is not allowed to input a normal model instead of `FullyShardedDataParallel` when"
"you initialize the ddp process out of our control.")
if isinstance(model, DistributedDataParallel):
logger.warning("You are using `TorchFSDPDriver`, but you have initialized your model as "
"`DistributedDataParallel`, which will make the `FullyShardedDataParallel` not work "
"as expected. You could just delete `DistributedDataParallel` wrap operation.")

self.outside_ddp = True
# the user can only wrap the model with DistributedDataParallel after moving it to the target device, so if the user
# initialized the distributed setup outside, we simply set model_device to None here;
self.model_device = None

# when the user initializes DDP outside by themselves, we set model_device to None; in that case the user can move the data to the intended device via `data_device`;
self._data_device = kwargs.get("data_device", None)
if isinstance(self._data_device, int):
if self._data_device < 0:
raise ValueError("Parameter `data_device` can not be smaller than 0.")
_could_use_device_num = torch.cuda.device_count()
if self._data_device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
self._data_device = torch.device(f"cuda:{self._data_device}")
elif isinstance(self._data_device, str):
self._data_device = torch.device(self._data_device)
elif self._data_device is not None and not isinstance(self._data_device, torch.device):
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")

self._master_port = None
# world_size is the total number of gpus across all processes;
self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
self.global_rank = 0

self._fsdp_kwargs = self._torch_kwargs.get("fsdp_kwargs", {})
self._save_on_rank0 = self._fsdp_kwargs.pop("save_on_rank0", False)
self._load_on_rank0 = self._fsdp_kwargs.pop("load_on_rank0", False)

if self._save_on_rank0 != self._load_on_rank0:
logger.warning(f"Notice the behavior between ``save`` and ``load`` is not matched, you choose "
f"{'save on rank0' if self._save_on_rank0 else 'save on each rank'}, but "
f"{'load on rank0' if self._load_on_rank0 else 'load on each rank'}!")

check_user_specific_params(self._fsdp_kwargs, FullyShardedDataParallel.__init__, FullyShardedDataParallel.__name__)
if "cpu_offload" in self._fsdp_kwargs and kwargs["accumulation_steps"] != 1:
logger.warning("It is not supported ``accumulation_steps`` when using ``cpu_offload`` in "
"``FullyShardedDataParallel``.")

self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
os.makedirs(name=self.output_from_new_proc, exist_ok=True)
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)

self._has_setup = False  # this flag exists because the evaluator also runs setup, which is obviously neither needed nor desirable there;
self._has_ddpwrapped = False  # whether the model passed in has been wrapped by _DDPWrappingModel;

def configure_ddp(self):
torch.cuda.set_device(self.model_device)
if not isinstance(self.model, FullyShardedDataParallel):
self.model = FullyShardedDataParallel(
# note that self.model_device here is of `torch.device` type, hence self.model_device.index;
_DDPWrappingModel(self.model), device_id=self.model_device.index,
**self._fsdp_kwargs
)

# the model must be wrapped by FullyShardedDataParallel before the optimizers take its parameters, so the optimizers have to be re-initialized here;
for i in range(len(self.optimizers)):
self.optimizers[i] = type(self.optimizers[i])(self.model.parameters(), **self.optimizers[i].defaults)
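# e.g. an SGD constructed on the raw model before wrapping would still reference the
# old, unsharded parameters, which is why each optimizer is rebuilt here on
# self.model.parameters() with its original defaults;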

self._has_ddpwrapped = True

def unwrap_model(self):
"""
Note that this function must not be removed, since it is called at specific points, e.g. by ddp in get_model_call_fn;
using it to inspect the structure of the original model is fine;
but using it to access the actual parameters of the original model is not, because under FullyShardedDataParallel the model
is sharded into multiple pieces and the model on each gpu is only a part of the whole.
"""
_module = self.model.module.module
if isinstance(_module, _DDPWrappingModel):
return _module.model
else:
return _module

def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
filepath = Path(filepath)
prefix = filepath.parent
filename = filepath.name
_filename = filename.split('.')
filename, suffix = _filename[0], '.'.join(_filename[1:])
if only_state_dict:
if self._save_on_rank0:
full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
state_dict = self.model.state_dict()
rank_zero_call(torch.save)(state_dict, filepath)
else:
# append a 'rank{x}' field to the filename to distinguish per-rank saving from the mode where everything is gathered and saved on rank0;
_filename = filename.split('_')
filename = _filename[0] + f"_rank{int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}_" + _filename[1]
filepath = prefix.joinpath(filename + "." + suffix)
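# a worked example of the renaming above (the concrete filename is only an assumption):
# 'fastnlp_model.pkl.tar' -> filename 'fastnlp_model', suffix 'pkl.tar'
# -> saved as 'fastnlp_rank0_model.pkl.tar' on rank 0, 'fastnlp_rank1_model.pkl.tar' on rank 1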
with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
state_dict = self.model.state_dict()
torch.save(state_dict, filepath)
else:
raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")

def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
if only_state_dict is False:
raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")
filepath = Path(filepath)
prefix = filepath.parent
filename = filepath.name
_filename = filename.split('.')
filename, suffix = _filename[0], '.'.join(_filename[1:])

if not self._load_on_rank0:
_filename = filename.split('_')
filename = _filename[0] + f"_rank{int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}_" + _filename[1]
filepath = prefix.joinpath(filename + "." + suffix)
states = torch.load(filepath)
else:
states = torch.load(filepath, map_location="cpu")

if isinstance(states, dict) and only_state_dict is False:
logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use "
f"`only_state_dict=True`")
elif not isinstance(states, dict) and only_state_dict is True:
logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use "
f"`only_state_dict=False`")
if not isinstance(states, Mapping):
states = states.state_dict()

if self._load_on_rank0:
with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
self.model.load_state_dict(states)
else:
with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
self.model.load_state_dict(states)

def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
raise RuntimeError("``TorchFSDPDriver`` does not support ``save_checkpoint`` function for now, there is some "
"technical issues that needs to solve. You can implement your own breakpoint retraining "
"by rewriting this function. The important thing is how to save and load the optimizers' state dict, "
"you can see ``https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict``.")

def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
raise RuntimeError("``TorchFSDPDriver`` does not support ``load_checkpoint`` function for now, there is some "
"technical issues that needs to solve. You can implement your own breakpoint retraining "
"by rewriting this function. The important thing is how to save and load the optimizers' state dict, "
"you can see ``https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict``.")

# todo: the double-underscore methods below are not supported yet;
# this is because the pytorch 1.12 fsdp interfaces for saving and loading optimizer state dicts are rather
# unwieldy and cannot be reconciled with the fastNLP framework for now;
def __get_optimizer_state(self):
optimizers_state_dict = {}
for i in range(len(self.optimizers)):
# note that every rank other than rank0 gets an empty dict here, so the actual save must be guaranteed to happen on rank0 only;
optimizer_state = FullyShardedDataParallel.full_optim_state_dict(self.model, self.optimizers[i])
if self._save_on_rank0:
with FullyShardedDataParallel.summon_full_params(self.model):
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
unwrapped_model = self.model.module.module
optimizer_state = FullyShardedDataParallel.rekey_optim_state_dict(
optimizer_state, OptimStateKeyType.PARAM_ID, unwrapped_model)
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的;
return optimizers_state_dict

# kept as a separate method because for fsdp every process needs to run this function, so it cannot be wrapped in rank_zero_call;
def __save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
if not only_state_dict:
raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")

# 1. sampler state;
num_consumed_batches = states.pop('num_consumed_batches')
states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)

# 2. save the model state;
if should_save_model:
if not os.path.exists(folder):
os.mkdir(folder)
model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
self.save_model(model_path, only_state_dict=True)

# 3. save the optimizers' state;
states["optimizers_state_dict"] = self.__get_optimizer_state()
logger.debug("Save optimizer state dict.")

# 4. save the fp16 state
if not isinstance(self.grad_scaler, DummyGradScaler):
grad_scaler_state_dict = self.grad_scaler.state_dict()
states['grad_scaler_state_dict'] = grad_scaler_state_dict

# make sure that only rank0 performs the actual save;
rank_zero_call(torch.save)(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))

def __load_optimizer_state(self, states):
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \
f"checkpoint it is:{len(states)}"

with FullyShardedDataParallel.summon_full_params(self.model):
unwrapped_model = self.model.module.module

for i in range(len(self.optimizers)):
optimizer_state = states[f'optimizer{i}']
if self._load_on_rank0:
optimizer_state = FullyShardedDataParallel.rekey_optim_state_dict(optimizer_state, OptimStateKeyType.PARAM_NAME, unwrapped_model)
optimizer_state = FullyShardedDataParallel.shard_full_optim_state_dict(optimizer_state, unwrapped_model)
optimizer: torch.optim.Optimizer = type(self.optimizers[i])(unwrapped_model.parameters(), **self.optimizers[i].defaults)
optimizer.load_state_dict(optimizer_state)
self.optimizers[i] = optimizer

logger.debug("Load optimizer state dict.")

def __load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
if not only_state_dict:
raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")

states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

# 1. load the optimizers' state;
optimizers_state_dict = states.pop("optimizers_state_dict")
self.__load_optimizer_state(optimizers_state_dict)

# 2. load the model state;
if should_load_model:
self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)

# 3. load the fp16 state
if "grad_scaler_state_dict" in states:
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
if not isinstance(self.grad_scaler, DummyGradScaler):
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")

# 4. restore the sampler state;
sampler_states = states.pop('sampler_states')
states_ret = self.load_sampler_state(dataloader, sampler_states)
states.update(states_ret)

return states
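The save/load behavior above is driven by ``torch_kwargs``. A hedged sketch of passing the two driver-specific switches, based on the line left commented out in the new test file (all other Trainer arguments here are placeholders):

    trainer = Trainer(
        model=model,
        driver="torch_fsdp",
        device=[0, 1],
        optimizers=optimizer,
        train_dataloader=train_dataloader,
        # gather the full state dict to rank0 when saving, and expect a full state
        # dict when loading; both default to False (per-rank local shards):
        torch_kwargs={"fsdp_kwargs": {"save_on_rank0": True, "load_on_rank0": True}},
    )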


+1 -0 fastNLP/envs/imports.py

@@ -26,3 +26,4 @@ _NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'torch' in need_impo
_NEED_IMPORT_ONEFLOW = _module_available("oneflow") and 'oneflow' in need_import

_TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0")
_TORCH_GREATER_EQUAL_1_12 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.12.0")

+1 -1 tests/core/callbacks/test_checkpoint_callback_torch.py

@@ -75,7 +75,7 @@ def model_and_optimizers(request):


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context(timeout=100)
def test_model_checkpoint_callback_1(
model_and_optimizers: TrainerParameters,


+7 -7 tests/core/controllers/test_trainer_w_evaluator_torch.py

@@ -103,8 +103,8 @@ def model_and_optimizers(request):

# test the ordinary case;
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1),
("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4),
("torch", [4, 5])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
@magic_argv_env_context
def test_trainer_torch_with_evaluator(
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator(


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("accumulation_steps", [1, 3])
@magic_argv_env_context
@@ -250,7 +250,7 @@ def test_trainer_on(


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 0)]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_specific_params_1(
model_and_optimizers: TrainerParameters,
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1(


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_specific_params_2(
model_and_optimizers: TrainerParameters,
@@ -331,7 +331,7 @@ def test_trainer_specific_params_2(
assert trainer.driver.wo_auto_param_call is True
assert trainer.driver.output_from_new_proc == "all"

_ddp_kwargs = trainer.driver._ddp_kwargs
_ddp_kwargs = trainer.driver._fsdp_kwargs
assert _ddp_kwargs.get("broadcast_buffers") is True
assert _ddp_kwargs.get("find_unused_parameters") is True

@@ -340,7 +340,7 @@ def test_trainer_specific_params_2(


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context
def test_trainer_w_evaluator_overfit_torch(


+379 -0 tests/core/drivers/torch_driver/test_fsdp.py

@@ -0,0 +1,379 @@
import os
from dataclasses import dataclass
from typing import Any
from pathlib import Path
import re

import pytest
from fastNLP.core.controllers.trainer import Trainer
from torchmetrics import Accuracy
from fastNLP.core.callbacks import CheckpointCallback
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.utils import magic_argv_env_context
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader


@dataclass
class ArgMaxDatasetConfig:
num_labels: int = 10
feature_dimension: int = 10
data_num: int = 50
seed: int = 0

batch_size: int = 2
shuffle: bool = True


@dataclass
class TrainerParameters:
model: Any = None
optimizers: Any = None
train_dataloader: Any = None
evaluate_dataloaders: Any = None
input_mapping: Any = None
output_mapping: Any = None
metrics: Any = None


@pytest.fixture(scope="module", params=[0], autouse=True)
def model_and_optimizers(request):
trainer_params = TrainerParameters()

trainer_params.model = TorchNormalModel_Classification_1(
num_labels=ArgMaxDatasetConfig.num_labels,
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
dataset = TorchArgMaxDataset(
feature_dimension=ArgMaxDatasetConfig.feature_dimension,
data_num=ArgMaxDatasetConfig.data_num,
seed=ArgMaxDatasetConfig.seed
)
_dataloader = DataLoader(
dataset=dataset,
batch_size=ArgMaxDatasetConfig.batch_size,
shuffle=True
)
trainer_params.train_dataloader = _dataloader
trainer_params.evaluate_dataloaders = _dataloader
trainer_params.metrics = {"acc": Accuracy()}

return trainer_params

@pytest.mark.torch
@magic_argv_env_context
def test_trainer_torch_without_evaluator(
model_and_optimizers: TrainerParameters,
n_epochs=3,
):
callbacks = [RecordLossCallback(loss_threshold=0.5)]
trainer = Trainer(
model=model_and_optimizers.model,
driver="torch_fsdp",
device=[4, 5],
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=3,
callbacks=callbacks,
output_from_new_proc="all"

)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [4, 5])])
@magic_argv_env_context(timeout=100)
def test_model_checkpoint_callback_1(
model_and_optimizers: TrainerParameters,
driver,
device
):
for version in [0]:
# the model has to be re-initialized at the start of every loop iteration, because fsdp frees the model on the current card, which would otherwise leave this reference pointing at an empty model;
model_and_optimizers.model = TorchNormalModel_Classification_1(
num_labels=ArgMaxDatasetConfig.num_labels,
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)
try:
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

if version == 0:
callbacks = [
CheckpointCallback(folder=path, every_n_epochs=1, every_n_batches=123, last=False, on_exceptions=None, topk=0,
monitor=None, only_state_dict=True, save_object='model')
]
elif version == 1:
callbacks = [
CheckpointCallback(folder=path, every_n_epochs=3, every_n_batches=None, last=True, on_exceptions=None, topk=2,
monitor="acc", only_state_dict=True, save_object='model')
]

trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=10,
callbacks=callbacks,
output_from_new_proc="all",
# torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True}}
)

trainer.run()
print("Finish train")
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
# check that the number of saved model files is correct;
if version == 0:

if not isinstance(device, list):
assert "model-epoch_10" in all_saved_model_paths
assert "model-epoch_4-batch_123" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["model-epoch_10"]
step_save_path = all_saved_model_paths["model-epoch_4-batch_123"]

assert len(all_saved_model_paths) == 12
# file names differ under ddp, because with the same data ddp finishes in fewer steps;
else:
assert "model-epoch_6" in all_saved_model_paths
assert "model-epoch_9-batch_123" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["model-epoch_6"]
step_save_path = all_saved_model_paths["model-epoch_9-batch_123"]

assert len(all_saved_model_paths) == 11
all_state_dicts = [epoch_save_path]#, step_save_path]

elif version == 1:

pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")

if not isinstance(device, list):
assert "model-epoch_9" in all_saved_model_paths
assert "model-last" in all_saved_model_paths
all_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
if len(each_folder_name) != 0:
all_topk_folders.append(each_folder_name[0])
assert len(all_topk_folders) == 2

epoch_save_path = all_saved_model_paths["model-epoch_9"]
last_save_path = all_saved_model_paths["model-last"]
topk_save_path = all_saved_model_paths[all_topk_folders[0]]

assert len(all_saved_model_paths) == 6
# file names differ under ddp, because with the same data ddp finishes in fewer steps;
else:
assert "model-epoch_9" in all_saved_model_paths
assert "model-last" in all_saved_model_paths

all_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
if len(each_folder_name) != 0:
all_topk_folders.append(each_folder_name[0])
assert len(all_topk_folders) == 2

epoch_save_path = all_saved_model_paths["model-epoch_9"]
last_save_path = all_saved_model_paths["model-last"]
topk_save_path = all_saved_model_paths[all_topk_folders[0]]

assert len(all_saved_model_paths) == 6

all_state_dicts = [epoch_save_path, last_save_path, topk_save_path]

for folder in all_state_dicts:
model_and_optimizers.model = TorchNormalModel_Classification_1(
num_labels=ArgMaxDatasetConfig.num_labels,
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)

trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=20,
output_from_new_proc="all",

)
trainer.load_model(folder, only_state_dict=True)

trainer.run()
trainer.driver.barrier()
finally:
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()





@pytest.mark.skip("现在 fsdp 还不支持断点重训;")
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context(timeout=100)
def test_trainer_checkpoint_callback_1(
model_and_optimizers: TrainerParameters,
driver,
device
):
for version in [0, 1]:
model_and_optimizers.model = TorchNormalModel_Classification_1(
num_labels=ArgMaxDatasetConfig.num_labels,
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)
try:
path = Path.cwd().joinpath(f"test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

if version == 0:
callbacks = [
CheckpointCallback(folder=path, every_n_epochs=7, every_n_batches=123, last=False, on_exceptions=None, topk=0,
monitor=None, only_state_dict=True, save_object='trainer')
]
elif version == 1:
callbacks = [
CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=True, on_exceptions=None,
topk=2, monitor="acc", only_state_dict=True, save_object='trainer')
]

trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=10,
callbacks=callbacks,
output_from_new_proc="all"
)

trainer.run()

all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
# check that the number of saved model files is correct;
if version == 0:

if not isinstance(device, list):
assert "trainer-epoch_7" in all_saved_model_paths
assert "trainer-epoch_4-batch_123" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"]

assert len(all_saved_model_paths) == 3
# file names differ under ddp, because with the same data ddp finishes in fewer steps;
else:
assert "trainer-epoch_7" in all_saved_model_paths
assert "trainer-epoch_9-batch_123" in all_saved_model_paths

epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"]

assert len(all_saved_model_paths) == 2
all_state_dicts = [epoch_save_path, step_save_path]

elif version == 1:

pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")

# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
if not isinstance(device, list):
assert "trainer-last" in all_saved_model_paths
all_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
if len(each_folder_name) != 0:
all_topk_folders.append(each_folder_name[0])
assert len(all_topk_folders) == 2

last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[all_topk_folders[0]]

assert len(all_saved_model_paths) == 3
# file names differ under ddp, because with the same data ddp finishes in fewer steps;
else:
assert "trainer-last" in all_saved_model_paths

all_topk_folders = []
for each_folder_name in all_saved_model_paths:
each_folder_name = pattern.findall(each_folder_name)
if len(each_folder_name) != 0:
all_topk_folders.append(each_folder_name[0])
assert len(all_topk_folders) == 2

last_save_path = all_saved_model_paths["trainer-last"]
topk_save_path = all_saved_model_paths[all_topk_folders[0]]

assert len(all_saved_model_paths) == 3

all_state_dicts = [last_save_path, topk_save_path]

for folder in all_state_dicts:
model_and_optimizers.model = TorchNormalModel_Classification_1(
num_labels=ArgMaxDatasetConfig.num_labels,
feature_dimension=ArgMaxDatasetConfig.feature_dimension
)

trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=13,
output_from_new_proc="all"
)
trainer.load_checkpoint(folder, only_state_dict=True)

trainer.run()
trainer.driver.barrier()

finally:
rank_zero_rm(path)

if dist.is_initialized():
dist.destroy_process_group()
