@@ -51,7 +51,6 @@ class Saver:
        self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME])
    @rank_zero_call
    def save(self, trainer, folder_name):
        """
        Performs the actual save; the data is stored under::
@@ -66,6 +65,7 @@ class Saver:
        """
        folder = self.timestamp_path.joinpath(folder_name)
        folder.mkdir(parents=True, exist_ok=True)
        save_fn = getattr(trainer, self.save_fn_name)
        save_fn(
            folder=folder,
@@ -217,7 +217,7 @@ class TopkSaver(ResultsMonitor, Saver):
        self.topk_queue = TopkQueue(topk)
        self.save_evaluate_results = save_evaluate_results
-    @rank_zero_call
+    # Note: ``@rank_zero_call`` is removed here to support torch_fsdp;
    def save_topk(self, trainer, results: Dict) -> Optional[str]:
        """
        Decides whether to save according to whether ``results`` meets the top-k settings; if a save happens, the folder it
        was saved to is returned. A return value of ``None`` means the top-k condition was not met this time.
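A minimal consumption sketch (editor's illustration; ``saver`` is an assumed ``TopkSaver`` instance and ``results`` an evaluation results dict):

    folder = saver.save_topk(trainer, results)
    if folder is not None:
        logger.info(f"A new top-k checkpoint was saved to {folder}")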
@@ -30,7 +30,7 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int,
    else:
        raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.")
-    if driver in {"torch", "fairscale", "deepspeed"}:
+    if driver in {"torch", "fairscale", "deepspeed", "torch_fsdp"}:
        from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
        return initialize_torch_driver(driver, device, model, **kwargs)
    elif driver in {"jittor"}:
@@ -309,9 +309,9 @@ class TorchDDPDriver(TorchDriver):
        self.world_size = None  # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
        self.global_rank = 0
-        self._ddp_kwargs = self._torch_kwargs.get("ddp_kwargs", {})
-        check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__)
-        if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None:
+        self._fsdp_kwargs = self._torch_kwargs.get("ddp_kwargs", {})
+        check_user_specific_params(self._fsdp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__)
+        if len(self.model._buffers) != 0 and self._fsdp_kwargs.get("broadcast_buffers", None) is None:
            logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you did not set "
                        "'broadcast_buffers' in your trainer. In most situations this parameter can be set"
                        " to 'False' to avoid redundant data communication between different processes.")
@@ -381,8 +381,6 @@ class TorchDDPDriver(TorchDriver):
        self.global_rank = dist.get_rank()
        if not self.outside_ddp:
-            torch.cuda.set_device(self.model_device)
-            self.model.to(self.model_device)
            self.configure_ddp()
        self.barrier()
@@ -400,11 +398,13 @@ class TorchDDPDriver(TorchDriver):
        self._pids = self.tensor_to_numeric(self._pids)
    def configure_ddp(self):
+        torch.cuda.set_device(self.model_device)
+        self.model.to(self.model_device)
        if not isinstance(self.model, DistributedDataParallel):
            self.model = DistributedDataParallel(
+                # note that self.model_device here is of `torch.device` type, hence self.model_device.index;
                _DDPWrappingModel(self.model), device_ids=[self.model_device.index],
-                **self._ddp_kwargs
+                **self._fsdp_kwargs
            )
            self._has_ddpwrapped = True
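For reference, a hedged sketch of how these kwargs are supplied from user code (the ``torch_kwargs``/``ddp_kwargs`` keys appear in this diff; the remaining ``Trainer`` arguments are elided and assumed):

    trainer = Trainer(
        model=model,
        driver="torch",
        device=[0, 1],
        torch_kwargs={"ddp_kwargs": {"broadcast_buffers": False, "find_unused_parameters": True}},
        # ... other Trainer arguments ...
    )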
@@ -505,6 +505,6 @@ class TorchDDPDriver(TorchDriver):
                raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.")
            return fn, None
        elif fn in {"train_step", "evaluate_step"}:
            return model, model.forward
        else:
            raise RuntimeError(f"There is no `{fn}` method in your model.")
@@ -9,6 +9,7 @@ from .single_device import TorchSingleDriver
from .ddp import TorchDDPDriver
from .fairscale import FairScaleDriver
from .deepspeed import DeepSpeedDriver
+from .torch_fsdp import TorchFSDPDriver
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_BACKEND_LAUNCH
from pkg_resources import parse_version
@@ -45,7 +46,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
        return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
                              is_pull_by_torch_run=True, **kwargs)
-    if driver not in {"torch", "fairscale", "deepspeed"}:
+    if driver not in {"torch", "fairscale", "deepspeed", "torch_fsdp"}:
        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', 'deepspeed', 'torch_fsdp'].")
    _could_use_device_num = torch.cuda.device_count()
@@ -95,4 +96,12 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
            logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one GPU.")
            return DeepSpeedDriver(model, [device], **kwargs)
        else:
            return DeepSpeedDriver(model, device, **kwargs)
+    elif driver == "torch_fsdp":
+        if not isinstance(device, List):
+            if device.type == 'cpu':
+                raise ValueError("You are using the `torch_fsdp` driver, but your chosen `device` is 'cpu'.")
+            logger.warning_once("Notice you are using `torch_fsdp`, but the `device` is only one GPU.")
+            return TorchFSDPDriver(model, [device], **kwargs)
+        else:
+            return TorchFSDPDriver(model, device, **kwargs)
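A minimal selection sketch for the new branch, adapted from the tests later in this diff (``model``, ``optimizer`` and ``train_dataloader`` are assumed to exist):

    from fastNLP.core.controllers.trainer import Trainer

    trainer = Trainer(
        model=model,
        driver="torch_fsdp",   # routed through initialize_torch_driver to TorchFSDPDriver
        device=[0, 1],
        optimizers=optimizer,
        train_dataloader=train_dataloader,
        n_epochs=3,
    )
    trainer.run()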
@@ -27,7 +27,7 @@ from .utils import optimizer_state_to_device
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler
from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
@@ -0,0 +1,341 @@
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12
if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp import FullyShardedDataParallel, StateDictType, FullStateDictConfig, OptimStateKeyType
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from typing import Optional, Union, List, Dict, Mapping
from pathlib import Path
from .ddp import TorchDDPDriver
from fastNLP.core.drivers.torch_driver.utils import (
    _DDPWrappingModel,
)
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, \
    FASTNLP_GLOBAL_RANK, rank_zero_call
from fastNLP.core.drivers.torch_driver.utils import DummyGradScaler
from fastNLP.core.log import logger
from fastNLP.core.utils import check_user_specific_params
from .utils import optimizer_state_to_device
"""
References:
1. https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/
2. https://pytorch.org/docs/stable/fsdp.html?highlight=fsdp
3. https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
4. https://engineering.fb.com/2021/07/15/open-source/fsdp/
"""
class TorchFSDPDriver(TorchDDPDriver):
    r"""
    Implements PyTorch's own fully sharded data parallel; please read this document for more details:
    https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict;

    .. note::

        ``TorchFSDPDriver`` behaves largely the same as ``TorchDDPDriver``; if you are not familiar with ``TorchDDPDriver``,
        you may first read :class:`~fastNLP.core.drivers.TorchDDPDriver`;

    .. warning::

        ``TorchFSDPDriver`` does not support checkpoint-based resumed training yet, but it does support saving and loading models;
    """
    def __init__(
            self,
            model,
            parallel_device: Optional[Union[List["torch.device"], "torch.device"]],
            is_pull_by_torch_run: bool = False,
            fp16: bool = False,
            torch_kwargs: Dict = None,
            **kwargs
    ):
        # as more logic gets added, pay attention to where this super() call is placed;
        super(TorchDDPDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs)
        if isinstance(model, torch.nn.DataParallel):
            raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchFSDPDriver`, it should be "
                             f"`torch.nn.Module` or `torch.nn.parallel.DistributedDataParallel` type.")
        # if the user initialized DDP outside by themselves, it must have been launched via python -m torch.distributed.launch;
        self.is_pull_by_torch_run = is_pull_by_torch_run
        self.parallel_device = parallel_device
        if not is_pull_by_torch_run and parallel_device is None:
            raise ValueError(
                "Parameter `parallel_device` can not be None when using `TorchFSDPDriver`. This error is caused "
                "when your value of parameter `device` is `None` in your `Trainer` instance.")
        # note that the logic in initialize_torch_driver is: if is_pull_by_torch_run, parallel_device is set directly to the GPU of the current process;
        if is_pull_by_torch_run:
            self.model_device = parallel_device
        else:
            # our model_device is always a torch.device, never a list;
            self.model_device = parallel_device[self.local_rank]
        # if the user initialized FSDP outside by themselves;
        self.outside_ddp = False
        if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
                "fastnlp_torch_launch_not_ddp" not in os.environ:
            # if the user initialized the process group outside, we require the model passed in to already be wrapped by FullyShardedDataParallel;
            if not isinstance(model, FullyShardedDataParallel):
                raise RuntimeError(
                    "It is not allowed to input a normal model instead of `FullyShardedDataParallel` when "
                    "you initialize the ddp process out of our control.")
            if isinstance(model, DistributedDataParallel):
                logger.warning("You are using `TorchFSDPDriver`, but you have initialized your model as "
                               "`DistributedDataParallel`, which will make the `FullyShardedDataParallel` not work "
                               "as expected. You could just delete the `DistributedDataParallel` wrap operation.")
            self.outside_ddp = True
            # the model can only be wrapped by DistributedDataParallel after the user has moved it to the target device, so if the
            # user initialized DDP outside, we simply set model_device to None in TorchFSDPDriver;
            self.model_device = None
        # when the user initializes DDP outside, model_device is set to None; the user can then move the data to the desired device via `data_device`;
        self._data_device = kwargs.get("data_device", None)
        if isinstance(self._data_device, int):
            if self._data_device < 0:
                raise ValueError("Parameter `data_device` can not be smaller than 0.")
            _could_use_device_num = torch.cuda.device_count()
            if self._data_device >= _could_use_device_num:
                raise ValueError("The gpu device that parameter `device` specifies does not exist.")
            self._data_device = torch.device(f"cuda:{self._data_device}")
        elif isinstance(self._data_device, str):
            self._data_device = torch.device(self._data_device)
        elif self._data_device is not None and not isinstance(self._data_device, torch.device):
            raise ValueError("Parameter `data_device` has a wrong type, please check our documentation for the right usage.")
        self._master_port = None
        # world_size is the total number of GPUs across all processes;
        self.world_size = None  # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device)
        self.global_rank = 0
        self._fsdp_kwargs = self._torch_kwargs.get("fsdp_kwargs", {})
        self._save_on_rank0 = self._fsdp_kwargs.pop("save_on_rank0", False)
        self._load_on_rank0 = self._fsdp_kwargs.pop("load_on_rank0", False)
        if self._save_on_rank0 != self._load_on_rank0:
            logger.warning(f"Notice the behaviors of ``save`` and ``load`` do not match: you chose "
                           f"{'save on rank0' if self._save_on_rank0 else 'save on each rank'}, but "
                           f"{'load on rank0' if self._load_on_rank0 else 'load on each rank'}!")
        check_user_specific_params(self._fsdp_kwargs, FullyShardedDataParallel.__init__, FullyShardedDataParallel.__name__)
        if "cpu_offload" in self._fsdp_kwargs and kwargs["accumulation_steps"] != 1:
            logger.warning("``accumulation_steps`` is not supported when using ``cpu_offload`` in "
                           "``FullyShardedDataParallel``.")
        self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
        assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type."
        if self.output_from_new_proc not in {"all", "ignore", "only_error"}:
            os.makedirs(name=self.output_from_new_proc, exist_ok=True)
            self.output_from_new_proc = os.path.abspath(self.output_from_new_proc)
        self._has_setup = False  # this flag exists because the evaluator also runs setup, which is obviously neither needed nor desirable there;
        self._has_ddpwrapped = False  # whether the model passed in has already been wrapped (by ``_DDPWrappingModel``);
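        # Illustrative sketch (editor's addition): these flags arrive via the Trainer's ``torch_kwargs``,
        # e.g. torch_kwargs={"fsdp_kwargs": {"save_on_rank0": True, "load_on_rank0": True}};
        # keeping the two flags consistent avoids the mismatch warning above.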
    def configure_ddp(self):
        torch.cuda.set_device(self.model_device)
        if not isinstance(self.model, FullyShardedDataParallel):
            self.model = FullyShardedDataParallel(
                # note that self.model_device here is of `torch.device` type, hence self.model_device.index;
                _DDPWrappingModel(self.model), device_id=self.model_device.index,
                **self._fsdp_kwargs
            )
            # the model must be wrapped by FullyShardedDataParallel before its parameters are handed to the optimizer,
            # so the optimizers have to be re-initialized here;
            for i in range(len(self.optimizers)):
                self.optimizers[i] = type(self.optimizers[i])(self.model.parameters(), **self.optimizers[i].defaults)
            self._has_ddpwrapped = True
    def unwrap_model(self):
        """
        Note that this function must not be deleted, because it is called at specific points, e.g. by ddp in get_model_call_fn;
        using this function to inspect the structure of the original model is fine,
        but using it to access the original model's actual parameters is not, because under FullyShardedDataParallel the model
        is split into multiple shards, and the model on each GPU is only a part of the whole model.
        """
        _module = self.model.module.module
        if isinstance(_module, _DDPWrappingModel):
            return _module.model
        else:
            return _module
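    # Illustrative sketch (editor's addition) of the wrapper layers unwrap_model peels off, assuming
    # the PyTorch 1.12 FSDP layout this code targets:
    #   self.model               -> FullyShardedDataParallel
    #   self.model.module        -> FSDP's internal flatten-parameter wrapper
    #   self.model.module.module -> _DDPWrappingModel (its ``.model`` is the user's original module)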
    def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs):
        filepath = Path(filepath)
        prefix = filepath.parent
        filename = filepath.name
        _filename = filename.split('.')
        filename, suffix = _filename[0], '.'.join(_filename[1:])
        if only_state_dict:
            if self._save_on_rank0:
                full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
                    state_dict = self.model.state_dict()
                rank_zero_call(torch.save)(state_dict, filepath)
            else:
                # add a 'rank{k}' tag to the filename to distinguish this mode from gathering everything to rank0 for saving;
                _filename = filename.split('_')
                filename = _filename[0] + f"_rank{int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}_" + _filename[1]
                filepath = prefix.joinpath(filename + "." + suffix)
                with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
                    state_dict = self.model.state_dict()
                torch.save(state_dict, filepath)
        else:
            raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")
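    # Editor's note on the two modes above: with ``save_on_rank0=True`` the full, unsharded state dict is
    # gathered (offloaded to CPU) and written once by rank0; otherwise every rank writes only its local
    # shard into a file tagged with its global rank, which generally requires loading under a matching
    # sharding layout (see ``load_model`` below).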
    def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs):
        if only_state_dict is False:
            raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")
        filepath = Path(filepath)
        prefix = filepath.parent
        filename = filepath.name
        _filename = filename.split('.')
        filename, suffix = _filename[0], '.'.join(_filename[1:])
        if not self._load_on_rank0:
            _filename = filename.split('_')
            filename = _filename[0] + f"_rank{int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}_" + _filename[1]
            filepath = prefix.joinpath(filename + "." + suffix)
            states = torch.load(filepath)
        else:
            states = torch.load(filepath, map_location="cpu")
        if isinstance(states, dict) and only_state_dict is False:
            logger.rank_zero_warning(f"It seems that {filepath} only contains a state dict, you may need to use "
                                     f"`only_state_dict=True`")
        elif not isinstance(states, dict) and only_state_dict is True:
            logger.rank_zero_warning(f"It seems that {filepath} is not a state dict, you may need to use "
                                     f"`only_state_dict=False`")
        if not isinstance(states, Mapping):
            states = states.state_dict()
        if self._load_on_rank0:
            with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.FULL_STATE_DICT):
                self.model.load_state_dict(states)
        else:
            with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
                self.model.load_state_dict(states)
    def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
        raise RuntimeError("``TorchFSDPDriver`` does not support the ``save_checkpoint`` function for now, as there are "
                           "technical issues that still need to be solved. You can implement your own resumed training "
                           "by rewriting this function. The important point is how to save and load the optimizers' state dict; "
                           "see ``https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict``.")

    def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
        raise RuntimeError("``TorchFSDPDriver`` does not support the ``load_checkpoint`` function for now, as there are "
                           "technical issues that still need to be solved. You can implement your own resumed training "
                           "by rewriting this function. The important point is how to save and load the optimizers' state dict; "
                           "see ``https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict``.")

    # TODO: the functions below, prefixed with __, are not supported yet;
    # this is because the PyTorch 1.12 FSDP interfaces for saving and loading the optimizer state dict are too unwieldy
    # to be reconciled with the fastNLP framework;
    def __get_optimizer_state(self):
        optimizers_state_dict = {}
        for i in range(len(self.optimizers)):
            # note that the other ranks get an empty dict here, so the actual saving must be done by rank0 only;
            optimizer_state = FullyShardedDataParallel.full_optim_state_dict(self.model, self.optimizers[i])
            if self._save_on_rank0:
                with FullyShardedDataParallel.summon_full_params(self.model):
                    if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
                        unwrapped_model = self.model.module.module
                        optimizer_state = FullyShardedDataParallel.rekey_optim_state_dict(
                            optimizer_state, OptimStateKeyType.PARAM_ID, unwrapped_model)
            if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
                optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu"))
            optimizers_state_dict[f"optimizer{i}"] = optimizer_state  # note deepcopy is not used here; testing showed it is unnecessary;
        return optimizers_state_dict
    # kept as a separate function because under fsdp every process needs to run it, so it cannot be wrapped in rank_zero_call;
    def __save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
        if not only_state_dict:
            raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")
        # 1. the sampler state;
        num_consumed_batches = states.pop('num_consumed_batches')
        states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches)
        # 2. save the model state;
        if should_save_model:
            if not os.path.exists(folder):
                os.mkdir(folder)
            model_path = folder.joinpath(FASTNLP_MODEL_FILENAME)
            self.save_model(model_path, only_state_dict=True)
        # 3. save the optimizers' state;
        states["optimizers_state_dict"] = self.get_optimizer_state()
        logger.debug("Save optimizer state dict.")
        # 4. save the fp16 state;
        if not isinstance(self.grad_scaler, DummyGradScaler):
            grad_scaler_state_dict = self.grad_scaler.state_dict()
            states['grad_scaler_state_dict'] = grad_scaler_state_dict
        # make sure only rank0 performs the actual save;
        rank_zero_call(torch.save)(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))
    def __load_optimizer_state(self, states):
        assert len(states) == len(self.optimizers), f"The number of optimizers is: {len(self.optimizers)}, while in " \
                                                    f"the checkpoint it is: {len(states)}"
        with FullyShardedDataParallel.summon_full_params(self.model):
            unwrapped_model = self.model.module.module
            for i in range(len(self.optimizers)):
                optimizer_state = states[f'optimizer{i}']
                if self._load_on_rank0:
                    optimizer_state = FullyShardedDataParallel.rekey_optim_state_dict(optimizer_state, OptimStateKeyType.PARAM_NAME, unwrapped_model)
                optimizer_state = FullyShardedDataParallel.shard_full_optim_state_dict(optimizer_state, unwrapped_model)
                optimizer: torch.optim.Optimizer = type(self.optimizers[i])(unwrapped_model.parameters(), **self.optimizers[i].defaults)
                optimizer.load_state_dict(optimizer_state)
                self.optimizers[i] = optimizer
        logger.debug("Load optimizer state dict.")
    def __load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
        if not only_state_dict:
            raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.")
        states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
        # 1. load the optimizers' state;
        optimizers_state_dict = states.pop("optimizers_state_dict")
        self.load_optimizer_state(optimizers_state_dict)
        # 2. load the model state;
        if should_load_model:
            self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict)
        # 3. load the fp16 state;
        if "grad_scaler_state_dict" in states:
            grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
            if not isinstance(self.grad_scaler, DummyGradScaler):
                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
                logger.debug("Load grad_scaler state dict...")
        elif not isinstance(self.grad_scaler, DummyGradScaler):
            logger.rank_zero_warning(f"Checkpoint {folder} was not trained with fp16=True, while you are resuming to an "
                                     f"fp16=True training; the training process may be unstable.")
        # 4. restore the sampler state;
        sampler_states = states.pop('sampler_states')
        states_ret = self.load_sampler_state(dataloader, sampler_states)
        states.update(states_ret)
        return states
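A hedged end-to-end sketch of exercising this driver, assembled from the tests later in this diff (``model``, ``optimizer`` and ``dataloader`` are assumed to exist; the save path is hypothetical, and ``Trainer.save_model`` is assumed to mirror the ``load_model`` call used in the tests):

    from fastNLP.core.controllers.trainer import Trainer

    trainer = Trainer(
        model=model,
        driver="torch_fsdp",
        device=[0, 1],
        optimizers=optimizer,
        train_dataloader=dataloader,
        n_epochs=3,
        # gather the full state dict to rank0 for both saving and loading:
        torch_kwargs={"fsdp_kwargs": {"save_on_rank0": True, "load_on_rank0": True}},
    )
    trainer.run()
    trainer.save_model("fsdp_model.pkl", only_state_dict=True)  # hypothetical filepath
    trainer.load_model("fsdp_model.pkl", only_state_dict=True)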
@@ -26,3 +26,4 @@ _NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'torch' in need_impo
_NEED_IMPORT_ONEFLOW = _module_available("oneflow") and 'oneflow' in need_import
_TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0")
+_TORCH_GREATER_EQUAL_1_12 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.12.0")
@@ -75,7 +75,7 @@ def model_and_optimizers(request):
@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [4, 5])])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context(timeout=100)
def test_model_checkpoint_callback_1(
    model_and_optimizers: TrainerParameters,
@@ -103,8 +103,8 @@ def model_and_optimizers(request):
# test the ordinary case;
@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1),
-                                           ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4),
+                                           ("torch", [4, 5])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
@magic_argv_env_context
def test_trainer_torch_with_evaluator(
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator(
@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)])  # ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)])  # ("torch", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("accumulation_steps", [1, 3])
@magic_argv_env_context
@@ -250,7 +250,7 @@ def test_trainer_on(
@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 0)])  # ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)])  # ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context
def test_trainer_specific_params_1(
    model_and_optimizers: TrainerParameters,
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1(
@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [4, 5])])  # ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context
def test_trainer_specific_params_2(
    model_and_optimizers: TrainerParameters,
@@ -331,7 +331,7 @@ def test_trainer_specific_params_2(
    assert trainer.driver.wo_auto_param_call is True
    assert trainer.driver.output_from_new_proc == "all"
-    _ddp_kwargs = trainer.driver._ddp_kwargs
+    _ddp_kwargs = trainer.driver._fsdp_kwargs
    assert _ddp_kwargs.get("broadcast_buffers") is True
    assert _ddp_kwargs.get("find_unused_parameters") is True
@@ -340,7 +340,7 @@ def test_trainer_specific_params_2(
@pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])])  # ("torch", [0, 1]), ("torch", 1)
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context
def test_trainer_w_evaluator_overfit_torch(
@@ -0,0 +1,379 @@
import os
from dataclasses import dataclass
from typing import Any
from pathlib import Path
import re

import pytest

from fastNLP.core.controllers.trainer import Trainer
from torchmetrics import Accuracy
from fastNLP.core.callbacks import CheckpointCallback
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.utils import magic_argv_env_context
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm

if _NEED_IMPORT_TORCH:
    import torch.distributed as dist
    from torch.optim import SGD
    from torch.utils.data import DataLoader


@dataclass
class ArgMaxDatasetConfig:
    num_labels: int = 10
    feature_dimension: int = 10
    data_num: int = 50
    seed: int = 0
    batch_size: int = 2
    shuffle: bool = True


@dataclass
class TrainerParameters:
    model: Any = None
    optimizers: Any = None
    train_dataloader: Any = None
    evaluate_dataloaders: Any = None
    input_mapping: Any = None
    output_mapping: Any = None
    metrics: Any = None


@pytest.fixture(scope="module", params=[0], autouse=True)
def model_and_optimizers(request):
    trainer_params = TrainerParameters()
    trainer_params.model = TorchNormalModel_Classification_1(
        num_labels=ArgMaxDatasetConfig.num_labels,
        feature_dimension=ArgMaxDatasetConfig.feature_dimension
    )
    trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
    dataset = TorchArgMaxDataset(
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
        data_num=ArgMaxDatasetConfig.data_num,
        seed=ArgMaxDatasetConfig.seed
    )
    _dataloader = DataLoader(
        dataset=dataset,
        batch_size=ArgMaxDatasetConfig.batch_size,
        shuffle=True
    )
    trainer_params.train_dataloader = _dataloader
    trainer_params.evaluate_dataloaders = _dataloader
    trainer_params.metrics = {"acc": Accuracy()}
    return trainer_params


@pytest.mark.torch
@magic_argv_env_context
def test_trainer_torch_without_evaluator(
        model_and_optimizers: TrainerParameters,
        n_epochs=3,
):
    callbacks = [RecordLossCallback(loss_threshold=0.5)]
    trainer = Trainer(
        model=model_and_optimizers.model,
        driver="torch_fsdp",
        device=[4, 5],
        optimizers=model_and_optimizers.optimizers,
        train_dataloader=model_and_optimizers.train_dataloader,
        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
        input_mapping=model_and_optimizers.input_mapping,
        output_mapping=model_and_optimizers.output_mapping,
        metrics=model_and_optimizers.metrics,
        n_epochs=3,
        callbacks=callbacks,
        output_from_new_proc="all"
    )
    trainer.run()
    if dist.is_initialized():
        dist.destroy_process_group()
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [4, 5])])
@magic_argv_env_context(timeout=100)
def test_model_checkpoint_callback_1(
        model_and_optimizers: TrainerParameters,
        driver,
        device
):
    for version in [0]:
        # the model has to be re-initialized at the start of every loop, because fsdp deletes the model on the current
        # card, so this reference would otherwise end up pointing at an empty model;
        model_and_optimizers.model = TorchNormalModel_Classification_1(
            num_labels=ArgMaxDatasetConfig.num_labels,
            feature_dimension=ArgMaxDatasetConfig.feature_dimension
        )
        try:
            path = Path.cwd().joinpath("test_model_checkpoint")
            path.mkdir(exist_ok=True, parents=True)

            if version == 0:
                callbacks = [
                    CheckpointCallback(folder=path, every_n_epochs=1, every_n_batches=123, last=False, on_exceptions=None, topk=0,
                                       monitor=None, only_state_dict=True, save_object='model')
                ]
            elif version == 1:
                callbacks = [
                    CheckpointCallback(folder=path, every_n_epochs=3, every_n_batches=None, last=True, on_exceptions=None, topk=2,
                                       monitor="acc", only_state_dict=True, save_object='model')
                ]

            trainer = Trainer(
                model=model_and_optimizers.model,
                driver=driver,
                device=device,
                optimizers=model_and_optimizers.optimizers,
                train_dataloader=model_and_optimizers.train_dataloader,
                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                input_mapping=model_and_optimizers.input_mapping,
                output_mapping=model_and_optimizers.output_mapping,
                metrics=model_and_optimizers.metrics,
                n_epochs=10,
                callbacks=callbacks,
                output_from_new_proc="all",
                # torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True}}
            )
            trainer.run()
            print("Finish train")
            all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
            # check that the number of saved model files is correct;
            if version == 0:
                if not isinstance(device, list):
                    assert "model-epoch_10" in all_saved_model_paths
                    assert "model-epoch_4-batch_123" in all_saved_model_paths
                    epoch_save_path = all_saved_model_paths["model-epoch_10"]
                    step_save_path = all_saved_model_paths["model-epoch_4-batch_123"]
                    assert len(all_saved_model_paths) == 12
                # the file names differ under ddp, because ddp finishes the same data in fewer steps;
                else:
                    assert "model-epoch_6" in all_saved_model_paths
                    assert "model-epoch_9-batch_123" in all_saved_model_paths
                    epoch_save_path = all_saved_model_paths["model-epoch_6"]
                    step_save_path = all_saved_model_paths["model-epoch_9-batch_123"]
                    assert len(all_saved_model_paths) == 11
                all_state_dicts = [epoch_save_path]  # , step_save_path]
            elif version == 1:
                pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
                if not isinstance(device, list):
                    assert "model-epoch_9" in all_saved_model_paths
                    assert "model-last" in all_saved_model_paths
                    all_topk_folders = []
                    for each_folder_name in all_saved_model_paths:
                        each_folder_name = pattern.findall(each_folder_name)
                        if len(each_folder_name) != 0:
                            all_topk_folders.append(each_folder_name[0])
                    assert len(all_topk_folders) == 2
                    epoch_save_path = all_saved_model_paths["model-epoch_9"]
                    last_save_path = all_saved_model_paths["model-last"]
                    topk_save_path = all_saved_model_paths[all_topk_folders[0]]
                    assert len(all_saved_model_paths) == 6
                # the file names differ under ddp, because ddp finishes the same data in fewer steps;
                else:
                    assert "model-epoch_9" in all_saved_model_paths
                    assert "model-last" in all_saved_model_paths
                    all_topk_folders = []
                    for each_folder_name in all_saved_model_paths:
                        each_folder_name = pattern.findall(each_folder_name)
                        if len(each_folder_name) != 0:
                            all_topk_folders.append(each_folder_name[0])
                    assert len(all_topk_folders) == 2
                    epoch_save_path = all_saved_model_paths["model-epoch_9"]
                    last_save_path = all_saved_model_paths["model-last"]
                    topk_save_path = all_saved_model_paths[all_topk_folders[0]]
                    assert len(all_saved_model_paths) == 6
                all_state_dicts = [epoch_save_path, last_save_path, topk_save_path]

            for folder in all_state_dicts:
                model_and_optimizers.model = TorchNormalModel_Classification_1(
                    num_labels=ArgMaxDatasetConfig.num_labels,
                    feature_dimension=ArgMaxDatasetConfig.feature_dimension
                )
                trainer = Trainer(
                    model=model_and_optimizers.model,
                    driver=driver,
                    device=device,
                    optimizers=model_and_optimizers.optimizers,
                    train_dataloader=model_and_optimizers.train_dataloader,
                    evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                    input_mapping=model_and_optimizers.input_mapping,
                    output_mapping=model_and_optimizers.output_mapping,
                    metrics=model_and_optimizers.metrics,
                    n_epochs=20,
                    output_from_new_proc="all",
                )
                trainer.load_model(folder, only_state_dict=True)
                trainer.run()
                trainer.driver.barrier()
        finally:
            rank_zero_rm(path)

    if dist.is_initialized():
        dist.destroy_process_group()
@pytest.mark.skip("fsdp does not support checkpoint-based resumed training yet;")
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
@magic_argv_env_context(timeout=100)
def test_trainer_checkpoint_callback_1(
        model_and_optimizers: TrainerParameters,
        driver,
        device
):
    for version in [0, 1]:
        model_and_optimizers.model = TorchNormalModel_Classification_1(
            num_labels=ArgMaxDatasetConfig.num_labels,
            feature_dimension=ArgMaxDatasetConfig.feature_dimension
        )
        try:
            path = Path.cwd().joinpath("test_model_checkpoint")
            path.mkdir(exist_ok=True, parents=True)

            if version == 0:
                callbacks = [
                    CheckpointCallback(folder=path, every_n_epochs=7, every_n_batches=123, last=False, on_exceptions=None, topk=0,
                                       monitor=None, only_state_dict=True, save_object='trainer')
                ]
            elif version == 1:
                callbacks = [
                    CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=True, on_exceptions=None,
                                       topk=2, monitor="acc", only_state_dict=True, save_object='trainer')
                ]

            trainer = Trainer(
                model=model_and_optimizers.model,
                driver=driver,
                device=device,
                optimizers=model_and_optimizers.optimizers,
                train_dataloader=model_and_optimizers.train_dataloader,
                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                input_mapping=model_and_optimizers.input_mapping,
                output_mapping=model_and_optimizers.output_mapping,
                metrics=model_and_optimizers.metrics,
                n_epochs=10,
                callbacks=callbacks,
                output_from_new_proc="all"
            )
            trainer.run()
            all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
            # check that the number of saved checkpoint files is correct;
            if version == 0:
                if not isinstance(device, list):
                    assert "trainer-epoch_7" in all_saved_model_paths
                    assert "trainer-epoch_4-batch_123" in all_saved_model_paths
                    epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
                    step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"]
                    assert len(all_saved_model_paths) == 3
                # the file names differ under ddp, because ddp finishes the same data in fewer steps;
                else:
                    assert "trainer-epoch_7" in all_saved_model_paths
                    assert "trainer-epoch_9-batch_123" in all_saved_model_paths
                    epoch_save_path = all_saved_model_paths["trainer-epoch_7"]
                    step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"]
                    assert len(all_saved_model_paths) == 2
                all_state_dicts = [epoch_save_path, step_save_path]
            elif version == 1:
                pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*")
                # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
                if not isinstance(device, list):
                    assert "trainer-last" in all_saved_model_paths
                    all_topk_folders = []
                    for each_folder_name in all_saved_model_paths:
                        each_folder_name = pattern.findall(each_folder_name)
                        if len(each_folder_name) != 0:
                            all_topk_folders.append(each_folder_name[0])
                    assert len(all_topk_folders) == 2
                    last_save_path = all_saved_model_paths["trainer-last"]
                    topk_save_path = all_saved_model_paths[all_topk_folders[0]]
                    assert len(all_saved_model_paths) == 3
                # the file names differ under ddp, because ddp finishes the same data in fewer steps;
                else:
                    assert "trainer-last" in all_saved_model_paths
                    all_topk_folders = []
                    for each_folder_name in all_saved_model_paths:
                        each_folder_name = pattern.findall(each_folder_name)
                        if len(each_folder_name) != 0:
                            all_topk_folders.append(each_folder_name[0])
                    assert len(all_topk_folders) == 2
                    last_save_path = all_saved_model_paths["trainer-last"]
                    topk_save_path = all_saved_model_paths[all_topk_folders[0]]
                    assert len(all_saved_model_paths) == 3
                all_state_dicts = [last_save_path, topk_save_path]

            for folder in all_state_dicts:
                model_and_optimizers.model = TorchNormalModel_Classification_1(
                    num_labels=ArgMaxDatasetConfig.num_labels,
                    feature_dimension=ArgMaxDatasetConfig.feature_dimension
                )
                trainer = Trainer(
                    model=model_and_optimizers.model,
                    driver=driver,
                    device=device,
                    optimizers=model_and_optimizers.optimizers,
                    train_dataloader=model_and_optimizers.train_dataloader,
                    evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                    input_mapping=model_and_optimizers.input_mapping,
                    output_mapping=model_and_optimizers.output_mapping,
                    metrics=model_and_optimizers.metrics,
                    n_epochs=13,
                    output_from_new_proc="all"
                )
                trainer.load_checkpoint(folder, only_state_dict=True)
                trainer.run()
                trainer.driver.barrier()
        finally:
            rank_zero_rm(path)

    if dist.is_initialized():
        dist.destroy_process_group()