| @@ -54,18 +54,9 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| if model_save_fn is not None: | if model_save_fn is not None: | ||||
| assert save_folder is not None, "When passing `model_save_fn`, `save_folder` must be provided." | assert save_folder is not None, "When passing `model_save_fn`, `save_folder` must be provided." | ||||
| if save_folder is not None: | |||||
| if save_folder: | |||||
| if os.path.exists(save_folder): | if os.path.exists(save_folder): | ||||
| assert os.path.isdir(save_folder), f"`save_folder` must be a directory." | |||||
| else: | |||||
| os.makedirs(save_folder, exist_ok=True) | |||||
| save_folder = os.path.join(save_folder, os.environ.get(FASTNLP_LAUNCH_TIME)) | |||||
| self.real_save_folder = os.path.join(save_folder, 'best_so_far') | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
| os.makedirs(self.real_save_folder, exist_ok=True) | |||||
| else: # 创建出一个 stringio | |||||
| self.real_save_folder = None | |||||
| self.buffer = BytesIO() | |||||
| assert os.path.isdir(save_folder), f"`save_folder={save_folder}` must be a directory." | |||||
| self.save_folder = save_folder | self.save_folder = save_folder | ||||
| self.only_state_dict = only_state_dict | self.only_state_dict = only_state_dict | ||||
| @@ -73,21 +64,37 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| self.model_load_fn = model_load_fn | self.model_load_fn = model_load_fn | ||||
| self.delete_after_after = delete_after_train | self.delete_after_after = delete_after_train | ||||
| def on_after_trainer_initialized(self, trainer, driver): | |||||
| if self.save_folder is not None and driver.is_distributed() and int(os.environ.get(FASTNLP_BACKEND_LAUNCH, 0))==1: | |||||
| # 如果需要保存,但是又是不是 fastNLP 拉起的, 需要同步一下 folder | |||||
| try: | |||||
| self.real_save_folder = driver.broadcast_object(self.real_save_folder, src=0, group=None) | |||||
| logger.debug(f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") | |||||
| except NotImplementedError: | |||||
| raise RuntimeError(f"Currently {driver.__class__.__name__} does not support using `save_folder` to " | |||||
| f"save best model when launch using module.") | |||||
| def prepare_save_folder(self, trainer): | |||||
| if not hasattr(self, 'real_save_folder'): | |||||
| if self.save_folder is not None: | |||||
| if not os.path.exists(self.save_folder): | |||||
| os.makedirs(self.save_folder, exist_ok=True) | |||||
| self.save_folder = os.path.join(self.save_folder, os.environ.get(FASTNLP_LAUNCH_TIME)) | |||||
| self.real_save_folder = os.path.join(self.save_folder, 'best_so_far') | |||||
| if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
| os.makedirs(self.real_save_folder, exist_ok=True) | |||||
| if self.save_folder is not None and trainer.driver.is_distributed() and int( | |||||
| os.environ.get(FASTNLP_BACKEND_LAUNCH, 0)) == 1: | |||||
| trainer.driver.barrier() | |||||
| try: | |||||
| self.real_save_folder = trainer.driver.broadcast_object(self.real_save_folder, src=0, group=None) | |||||
| logger.debug( | |||||
| f"Synchronize best model save folder: {self.real_save_folder} for LoadBestModelCallback.") | |||||
| except NotImplementedError: | |||||
| raise RuntimeError( | |||||
| f"Currently {trainer.driver.__class__.__name__} does not support using `save_folder` to " | |||||
| f"save best model when launch using module.") | |||||
else:  # 创建出一个 BytesIO,将最佳权重缓存在内存中
| self.real_save_folder = None | |||||
| self.buffer = BytesIO() | |||||
| def on_after_trainer_initialized(self, trainer, driver): | |||||
| super().on_after_trainer_initialized(trainer, driver) | super().on_after_trainer_initialized(trainer, driver) | ||||
| self.encounter_exception = False | self.encounter_exception = False | ||||
| def on_evaluate_end(self, trainer, results): | def on_evaluate_end(self, trainer, results): | ||||
| if self.is_better_results(results, keep_if_better=True): | if self.is_better_results(results, keep_if_better=True): | ||||
| self.prepare_save_folder(trainer) | |||||
| if self.real_save_folder: | if self.real_save_folder: | ||||
| trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
| model_save_fn=self.model_save_fn) | model_save_fn=self.model_save_fn) | ||||
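For reference, a minimal usage sketch of the callback this hunk refactors: `save_folder` handling now happens lazily in `prepare_save_folder`, so the directories below are only created once a better evaluation result is seen. The import path, the `monitor` value and the `callbacks=[...]` wiring are assumptions based on the surrounding fastNLP API, not taken from this diff.

from fastNLP.core.callbacks import LoadBestModelCallback  # import path assumed

load_best_cb = LoadBestModelCallback(
    monitor='acc',            # evaluation key to track (placeholder name)
    save_folder='outputs',    # best weights land in outputs/<FASTNLP_LAUNCH_TIME>/best_so_far
    only_state_dict=True,
    delete_after_train=True,  # remove the folder after the best model is loaded back
)
# Passed to the Trainer via callbacks=[load_best_cb]; without a save_folder the best
# weights are instead kept in the in-memory BytesIO buffer created above.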
| @@ -103,8 +110,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | ||||
| model_load_fn=self.model_load_fn) | model_load_fn=self.model_load_fn) | ||||
| else: | else: | ||||
| logger.info( | |||||
| f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") | |||||
| logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") | |||||
| self.buffer.seek(0) | self.buffer.seek(0) | ||||
| trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | ||||
| if self.delete_after_after: | if self.delete_after_after: | ||||
| @@ -119,7 +125,7 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
| self.encounter_exception = True | self.encounter_exception = True | ||||
| def _delete_folder(self): | def _delete_folder(self): | ||||
| if self.real_save_folder: | |||||
| if getattr(self, 'real_save_folder', None): | |||||
| logger.info(f"Deleting {self.real_save_folder}...") | logger.info(f"Deleting {self.real_save_folder}...") | ||||
| shutil.rmtree(self.real_save_folder, ignore_errors=True) | shutil.rmtree(self.real_save_folder, ignore_errors=True) | ||||
| try: | try: | ||||
| @@ -3,7 +3,11 @@ __all__ = [ | |||||
| ] | ] | ||||
| from typing import Union, List | from typing import Union, List | ||||
| from ..callback import Callback | from ..callback import Callback | ||||
| from ...drivers.torch_driver.fairscale import FairScaleDriver | |||||
| from ...drivers.torch_driver import TorchDriver | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE | |||||
| if _NEED_IMPORT_FAIRSCALE: | |||||
| from fairscale.nn import FullyShardedDataParallel | |||||
| class TorchGradClipCallback(Callback): | class TorchGradClipCallback(Callback): | ||||
| r""" | r""" | ||||
| @@ -35,15 +39,20 @@ class TorchGradClipCallback(Callback): | |||||
| else: | else: | ||||
| self.parameters = None | self.parameters = None | ||||
| self.clip_value = clip_value | self.clip_value = clip_value | ||||
| self.clip_type = clip_type | |||||
| def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
| assert 'torch' in driver.__class__.__name__.lower(), f"Callback:{self.__class__.__name__} only supports torch " \ | |||||
| assert isinstance(driver, TorchDriver), f"Callback:{self.__class__.__name__} only supports torch " \ | |||||
| f"related drivers for now." | f"related drivers for now." | ||||
| parameters = [] | parameters = [] | ||||
| for optimizer in trainer.driver.optimizers: | for optimizer in trainer.driver.optimizers: | ||||
| for param_group in optimizer.param_groups: | for param_group in optimizer.param_groups: | ||||
| parameters.extend(param_group['params']) | parameters.extend(param_group['params']) | ||||
| self.parameters = parameters | self.parameters = parameters | ||||
| if isinstance(trainer.driver, FairScaleDriver): | |||||
| if isinstance(trainer.driver.model, FullyShardedDataParallel) and self.clip_type == 'norm': | |||||
| self.clip_fun = trainer.driver.model.clip_grad_norm_ | |||||
| assert len(self.parameters), "There is no parameters need to be clipped." | assert len(self.parameters), "There is no parameters need to be clipped." | ||||
| def on_before_optimizers_step(self, trainer, optimizers): | def on_before_optimizers_step(self, trainer, optimizers): | ||||
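A usage sketch for the callback extended here; the branch above swaps in `model.clip_grad_norm_` so that norm clipping also works when a `FairScaleDriver` wraps the model in `FullyShardedDataParallel`. The import path and the `callbacks=[...]` wiring are assumptions.

from fastNLP.core.callbacks import TorchGradClipCallback  # import path assumed

# Clip the global gradient norm to 1.0 before every optimizer step; with FSDP the
# clipping is delegated to FullyShardedDataParallel.clip_grad_norm_ as set up above.
clip_cb = TorchGradClipCallback(clip_value=1.0, clip_type='norm')
# trainer = Trainer(..., callbacks=[clip_cb])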
| @@ -58,7 +58,7 @@ class TrainBatchLoop(Loop): | |||||
| trainer.on_train_batch_end() | trainer.on_train_batch_end() | ||||
| except BaseException as e: | except BaseException as e: | ||||
| if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)): | if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)): | ||||
| logger.error(f"Exception happens when running on samples: {indices}") | |||||
| logger.error(f"Exception happens when training on samples: {indices}") | |||||
| raise e | raise e | ||||
| trainer.step_evaluate() | trainer.step_evaluate() | ||||
| trainer.batch_idx_in_epoch = 0 | trainer.batch_idx_in_epoch = 0 | ||||
| @@ -267,7 +267,8 @@ class Trainer(TrainerEventTrigger): | |||||
| * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 | * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 | ||||
| {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; | {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; | ||||
| * set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | * set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | ||||
| * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
| * non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
* gradscaler_kwargs -- 用于 fp16=True 时,提供给 ``torch.cuda.amp.GradScaler`` 的参数(用法示例见下)。
| * *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数: | * *paddle_kwargs* -- 用于在指定 ``driver`` 为 'paddle' 时设定具体 driver 实例的一些参数: | ||||
| * fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: | * fleet_kwargs -- 用于在使用 ``PaddleFleetDriver`` 时指定 ``DataParallel`` 和 ``fleet`` 初始化时的参数,包括: | ||||
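Taken together, the renamed `non_blocking` key and the new `gradscaler_kwargs` key documented above would be supplied roughly as below; this is a sketch, the values are placeholders, and only keys named in this hunk are used.

torch_kwargs = {
    'non_blocking': True,                           # forwarded to Tensor.to(..., non_blocking=...) when moving batches
    'set_grad_to_none': True,                       # now consumed by TorchDriver.zero_grad instead of the Trainer
    'gradscaler_kwargs': {'init_scale': 2. ** 14},  # handed to torch.cuda.amp.GradScaler when fp16=True
    'ddp_kwargs': {'find_unused_parameters': True},
}
# e.g. Trainer(..., driver='torch', fp16=True, torch_kwargs=torch_kwargs)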
| @@ -494,9 +495,6 @@ class Trainer(TrainerEventTrigger): | |||||
| self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, | ||||
| reproducible=self.callback_manager._need_reproducible_sampler) | reproducible=self.callback_manager._need_reproducible_sampler) | ||||
| _torch_kwargs = kwargs.get("torch_kwargs", {}) | |||||
| self.set_grad_to_none = _torch_kwargs.get("set_grad_to_none", True) | |||||
| self.evaluate_batch_step_fn = evaluate_batch_step_fn | self.evaluate_batch_step_fn = evaluate_batch_step_fn | ||||
| self.kwargs = kwargs | self.kwargs = kwargs | ||||
| @@ -596,7 +594,7 @@ class Trainer(TrainerEventTrigger): | |||||
| try: | try: | ||||
| self.on_train_begin() | self.on_train_begin() | ||||
| self.driver.barrier() | self.driver.barrier() | ||||
| self.driver.zero_grad(self.set_grad_to_none) | |||||
| self.driver.zero_grad() | |||||
| while self.cur_epoch_idx < self.n_epochs: | while self.cur_epoch_idx < self.n_epochs: | ||||
| # 这个是防止在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save | # 这个是防止在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save | ||||
| self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | ||||
| @@ -1236,7 +1234,7 @@ class Trainer(TrainerEventTrigger): | |||||
| """ | """ | ||||
| if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | if (self.global_forward_batches + 1) % self.accumulation_steps == 0: | ||||
| self.on_before_zero_grad(self.optimizers) | self.on_before_zero_grad(self.optimizers) | ||||
| self.driver.zero_grad(self.set_grad_to_none) | |||||
| self.driver.zero_grad() | |||||
| self.on_after_zero_grad(self.optimizers) | self.on_after_zero_grad(self.optimizers) | ||||
| def step(self): | def step(self): | ||||
| @@ -198,12 +198,11 @@ class Driver(ABC): | |||||
| raise NotImplementedError("Each specific driver should implemented its own `step` function.") | raise NotImplementedError("Each specific driver should implemented its own `step` function.") | ||||
| @abstractmethod | @abstractmethod | ||||
| def zero_grad(self, set_to_none: bool = False): | |||||
| def zero_grad(self): | |||||
| r""" | r""" | ||||
| 实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | 实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | ||||
| 注意梯度累积不需要在这里实现,trainer 已经在内部实现了梯度累积; | 注意梯度累积不需要在这里实现,trainer 已经在内部实现了梯度累积; | ||||
| :param set_to_none: 用来判断是否需要将梯度直接置为 None; | |||||
| """ | """ | ||||
| raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.") | raise NotImplementedError("Each specific driver should implemented its own `zero_grad` function.") | ||||
| @@ -46,7 +46,7 @@ class JittorSingleDriver(JittorDriver): | |||||
| for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
| optimizer.backward(loss) | optimizer.backward(loss) | ||||
| def zero_grad(self, set_to_none=False): | |||||
| def zero_grad(self): | |||||
| for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
| optimizer.zero_grad() | optimizer.zero_grad() | ||||
| @@ -199,7 +199,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| paddle_kwargs = kwargs.get("paddle_kwargs", {}) | paddle_kwargs = kwargs.get("paddle_kwargs", {}) | ||||
| self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {}) | self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {}) | ||||
| check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | |||||
| check_user_specific_params(self._fleet_kwargs, DataParallel.__init__, DataParallel.__name__) | |||||
| # fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 | # fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 | ||||
| self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | ||||
| self.is_collective = self._fleet_kwargs.pop("is_collective", True) | self.is_collective = self._fleet_kwargs.pop("is_collective", True) | ||||
| @@ -83,12 +83,11 @@ class PaddleDriver(Driver): | |||||
| # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | ||||
| self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | ||||
| def zero_grad(self, set_to_none: bool = False): | |||||
| def zero_grad(self): | |||||
| r""" | r""" | ||||
| 实现深度学习中的梯度的置零操作,应当直接通过优化器 ``optimizers`` 来将梯度置零; | 实现深度学习中的梯度的置零操作,应当直接通过优化器 ``optimizers`` 来将梯度置零; | ||||
| 注意梯度累积不需要在这里实现,:class:`~fastNLP.core.Trainer` 已经在内部实现了梯度累积; | 注意梯度累积不需要在这里实现,:class:`~fastNLP.core.Trainer` 已经在内部实现了梯度累积; | ||||
| :param set_to_none: 用来判断是否需要将梯度直接置为 ``None``;在 **PaddlePaddle** 中这个参数无效。 | |||||
| """ | """ | ||||
| for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
| optimizer.clear_grad() | optimizer.clear_grad() | ||||
| @@ -304,11 +304,11 @@ class TorchDDPDriver(TorchDriver): | |||||
| self.global_rank = 0 | self.global_rank = 0 | ||||
| self._ddp_kwargs = self._torch_kwargs.get("ddp_kwargs", {}) | self._ddp_kwargs = self._torch_kwargs.get("ddp_kwargs", {}) | ||||
| check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__) | |||||
| check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__) | |||||
| if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: | if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: | ||||
| logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set " | logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set " | ||||
| "'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set" | "'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set" | ||||
| " to 'False' to avoid redundant data translation between different processes.") | |||||
| " to 'False' to avoid redundant data communication between different processes.") | |||||
| self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | ||||
| assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." | assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." | ||||
| @@ -471,7 +471,7 @@ class TorchDDPDriver(TorchDriver): | |||||
| self._global_rank = rank | self._global_rank = rank | ||||
| @property | @property | ||||
| def local_rank(self) -> int: | |||||
| def local_rank(self) -> int: # 这个不会受到 all_rank_call_context 的影响 | |||||
| return int(os.environ.get("LOCAL_RANK", 0)) | return int(os.environ.get("LOCAL_RANK", 0)) | ||||
| @property | @property | ||||
| @@ -0,0 +1,307 @@ | |||||
| __all__ = [ | |||||
| 'FairScaleDriver' | |||||
| ] | |||||
| from typing import List, Sequence, Union, Dict, Mapping | |||||
| from pathlib import Path | |||||
| import os | |||||
| import functools | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE | |||||
| if _NEED_IMPORT_FAIRSCALE: | |||||
| import torch | |||||
| import torch.distributed as dist | |||||
| from fairscale.optim import OSS | |||||
| from fairscale.nn import ShardedDataParallel | |||||
| from fairscale.nn import FullyShardedDataParallel | |||||
| from fairscale.optim.grad_scaler import ShardedGradScaler | |||||
| from torch.nn.parallel import DistributedDataParallel | |||||
| from fairscale.nn.wrap import auto_wrap, enable_wrap, default_auto_wrap_policy | |||||
| from ...log import logger | |||||
| from .utils import reset_seed, _DDPWrappingModel | |||||
| from .ddp import TorchDDPDriver | |||||
| from .torch_driver import TorchDriver | |||||
| from .utils import _build_fp16_env | |||||
| from ....envs.distributed import all_rank_call_context | |||||
| from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||||
| from .utils import optimizer_state_to_device | |||||
| class FairScaleDriver(TorchDDPDriver): | |||||
| def __init__( | |||||
| self, | |||||
| model, | |||||
| parallel_device: Union[List["torch.device"], "torch.device"], | |||||
| is_pull_by_torch_run = False, | |||||
| fp16: bool = False, | |||||
| **kwargs | |||||
| ): | |||||
| assert _NEED_IMPORT_FAIRSCALE, "fairscale is not imported." | |||||
assert not dist.is_initialized(), "FairScaleDriver does not support a distributed environment that has already been initialized by the user."
| self._fairscale_kwargs = kwargs.get('fairscale_kwargs', {}) | |||||
| self.fs_type = self._fairscale_kwargs.get('fs_type', 'sdp') # ddp, sdp, fsdp | |||||
| if self.fs_type == 'fsdp': | |||||
| self._fairscale_kwargs['set_grad_to_none'] = self._fairscale_kwargs.get('set_grad_to_none', True) | |||||
| # 将最顶上的进行初始化 | |||||
| kwargs.pop('torch_kwargs', None) | |||||
| TorchDriver.__init__(self, model=model, fp16=False, torch_kwargs=self._fairscale_kwargs, **kwargs) | |||||
| self.is_pull_by_torch_run = is_pull_by_torch_run | |||||
| assert self.fs_type in ['ddp', 'sdp', 'fsdp'] | |||||
| self._oss_kwargs = self._fairscale_kwargs.get('oss_kwargs', {}) # 仅在 ddp 和 sdp 下有使用到 | |||||
| self._sdp_kwargs = self._fairscale_kwargs.get('sdp_kwargs', {}) | |||||
| self._fdsp_kwargs = self._fairscale_kwargs.get('fsdp_kwargs', {}) | |||||
| self._ddp_kwargs = self._fairscale_kwargs.get('ddp_kwargs', {}) | |||||
| if self.fs_type == 'ddp' or fp16 is False: | |||||
| self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||||
| self.grad_scaler = _grad_scaler(**self._fairscale_kwargs.get('gradscaler_kwargs', {})) | |||||
| else: | |||||
| self.auto_cast, self.grad_scaler = torch.cuda.amp.autocast, \ | |||||
| ShardedGradScaler(**self._fairscale_kwargs.get('gradscaler_kwargs', {})) | |||||
| self.parallel_device = parallel_device | |||||
| if is_pull_by_torch_run: | |||||
| self.model_device = parallel_device | |||||
| else: | |||||
| self.model_device = parallel_device[self.local_rank] | |||||
| self.outside_ddp = False # 不允许在外部初始化 | |||||
| self._data_device = kwargs.get("data_device", None) | |||||
| if isinstance(self._data_device, int): | |||||
| if self._data_device < 0: | |||||
| raise ValueError("Parameter `data_device` can not be smaller than 0.") | |||||
| _could_use_device_num = torch.cuda.device_count() | |||||
| if self._data_device >= _could_use_device_num: | |||||
| raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||||
| self._data_device = torch.device(f"cuda:{self._data_device}") | |||||
| elif isinstance(self._data_device, str): | |||||
| self._data_device = torch.device(self._data_device) | |||||
| elif self._data_device is not None and not isinstance(self._data_device, torch.device): | |||||
| raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||||
| self._master_port = None | |||||
| # world_size 表示的就是全局的显卡的数量; | |||||
| self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) | |||||
| self.global_rank = 0 | |||||
| if self.fs_type == 'ddp': | |||||
| if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: | |||||
| logger.info("Notice your model has buffers and you are using `FairScaleDriver`, but you do not set " | |||||
| "'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set" | |||||
| " to 'False' to avoid redundant data communication between different processes.") | |||||
| self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | |||||
| assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." | |||||
| if self.output_from_new_proc not in {"all", "ignore", "only_error"}: | |||||
| os.makedirs(self.output_from_new_proc, exist_ok=True) | |||||
| self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | |||||
| self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | |||||
| self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | |||||
| def setup(self): | |||||
| r""" | |||||
| 准备分布式环境,该函数主要做以下两件事情: | |||||
| 1. 开启多进程,每个 gpu 设备对应单独的一个进程; | |||||
2. 每个进程将模型迁移到自己对应的 ``gpu`` 设备上;然后根据 ``fs_type`` 使用 ``DistributedDataParallel`` / ``ShardedDataParallel`` / ``FullyShardedDataParallel`` 包裹模型;
| """ | |||||
| if self._has_setup: | |||||
| return | |||||
| self._has_setup = True | |||||
| if self.is_pull_by_torch_run: | |||||
| # dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用; | |||||
| self.world_size = int(os.environ.get("WORLD_SIZE")) | |||||
| self.global_rank = int(os.environ.get("RANK")) | |||||
| reset_seed() | |||||
| logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") | |||||
| if not dist.is_initialized(): | |||||
| dist.init_process_group( | |||||
| backend="nccl", rank=self.global_rank, world_size=self.world_size | |||||
| ) | |||||
| os.environ["fastnlp_torch_launch_not_ddp"] = "yes" | |||||
| else: | |||||
| if not dist.is_initialized(): | |||||
| # 这里主要的问题在于要区分 rank0 和其它 rank 的情况; | |||||
| self.world_size = len(self.parallel_device) | |||||
| self.open_subprocess() | |||||
| self.global_rank = self.local_rank # rank 一定是通过环境变量去获取的; | |||||
| reset_seed() | |||||
| dist.init_process_group( | |||||
| backend="nccl", rank=self.global_rank, world_size=self.world_size | |||||
| ) | |||||
| # 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 TorchDDPDriver; | |||||
| else: | |||||
| # 如果 `dist.is_initialized() == True`,那么说明 TorchDDPDriver 在之前已经初始化并且已经 setup 过一次,那么我们需要保证现在 | |||||
| # 使用的(即之后的)TorchDDPDriver 的设置和第一个 TorchDDPDriver 是完全一样的; | |||||
| pre_num_processes = int(os.environ[FASTNLP_DISTRIBUTED_CHECK]) | |||||
| if pre_num_processes != len(self.parallel_device): | |||||
| raise RuntimeError( | |||||
| "Notice you are using `TorchDDPDriver` after one instantiated `TorchDDPDriver`, it is not" | |||||
| "allowed that your second `TorchDDPDriver` has a new setting of parameters " | |||||
| "`num_nodes` and `num_processes`.") | |||||
| self.world_size = dist.get_world_size() | |||||
| self.global_rank = dist.get_rank() | |||||
| torch.cuda.set_device(self.model_device) | |||||
| if self.fs_type != 'fsdp': | |||||
| self.model.to(self.model_device) | |||||
| self.configure_ddp() | |||||
| self.barrier() | |||||
| # 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; | |||||
| self._pids = [torch.tensor(0, dtype=torch.int).to(self.data_device) for _ in range(dist.get_world_size())] | |||||
| dist.all_gather(self._pids, torch.tensor(os.getpid(), dtype=torch.int).to(self.data_device)) | |||||
| local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None | |||||
| if local_world_size is None: | |||||
| local_world_size = torch.tensor(int(os.environ.get("LOCAL_RANK")), dtype=torch.int).to(self.data_device) | |||||
| dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX) | |||||
| local_world_size = local_world_size.tolist() + 1 | |||||
| node_rank = self.global_rank // local_world_size | |||||
| self._pids = self._pids[node_rank * local_world_size: (node_rank + 1) * local_world_size] | |||||
| self._pids = self.tensor_to_numeric(self._pids) | |||||
| def configure_ddp(self): | |||||
| model = _DDPWrappingModel(self.model) | |||||
| if self.fs_type == 'ddp': | |||||
| self.model = DistributedDataParallel( | |||||
| # 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index; | |||||
| model, device_ids=[self.model_device.index], | |||||
| **self._ddp_kwargs | |||||
| ) | |||||
| elif self.fs_type == 'sdp': | |||||
| sdp_kwargs = self._sdp_kwargs | |||||
| sdp_kwargs = {**sdp_kwargs, 'module': model} | |||||
| sdp_kwargs['reduce_fp16'] = sdp_kwargs.get('reduce_fp16', self.fp16) | |||||
| oss_lst = [] | |||||
| for optimizer in self.optimizers: | |||||
| oss = OSS(optimizer.param_groups, optim=type(optimizer), **optimizer.defaults) | |||||
| oss_lst.append(oss) | |||||
| sdp_kwargs['sharded_optimizer'] = oss_lst | |||||
| sdp_kwargs['warn_on_trainable_params_changed'] = sdp_kwargs.get('warn_on_trainable_params_changed', False) | |||||
| self.model = ShardedDataParallel(**sdp_kwargs) | |||||
| self.optimizers = oss_lst | |||||
| else: | |||||
| assert len(self.optimizers) == 1, "When fs_type='fsdp', only one optimizer is allowed." | |||||
| optimizer = self.optimizers[0] | |||||
| assert len(optimizer.param_groups) == 1, "Cannot assign parameter specific optimizer parameter for 'fsdp'." | |||||
| fsdp_kwargs = self._fdsp_kwargs | |||||
| fsdp_kwargs['mixed_precision'] = self.fp16 | |||||
| fsdp_kwargs['state_dict_on_rank_0_only'] = fsdp_kwargs.get('state_dict_on_rank_0_only', True) | |||||
| fsdp_kwargs['state_dict_device'] = fsdp_kwargs.get('state_dict_device', torch.device('cpu')) | |||||
| fsdp_kwargs['compute_device'] = fsdp_kwargs.get('compute_device', self.model_device) | |||||
| optimizer = self.optimizers[0] | |||||
| # wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=1e6) | |||||
| # with enable_wrap(wrapper_cls=FullyShardedDataParallel, auto_wrap_policy=wrap_policy, | |||||
| # **fsdp_kwargs): | |||||
| # model = auto_wrap(model) | |||||
| fsdp_kwargs = {**fsdp_kwargs, 'module': model} | |||||
| self.model = None # 释放掉 | |||||
| self.model = FullyShardedDataParallel(**fsdp_kwargs).to(self.model_device) | |||||
self.optimizers = [type(optimizer)(self.model.parameters(), **optimizer.defaults)]  # 保持 list 形式,与 get_optimizer_state 等处对 self.optimizers 的遍历一致
| self._has_ddpwrapped = True | |||||
| def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): | |||||
| """ | |||||
将当前 driver 的模型保存到 filepath 指定的文件中。
:param filepath: 保存模型的文件路径;
| :param only_state_dict: 是否只保存权重; | |||||
| :return: | |||||
| """ | |||||
| if self.fs_type in ('ddp', 'sdp'): | |||||
| model = self.model.module.model | |||||
| if only_state_dict: | |||||
| if self.fs_type != 'fsdp': | |||||
| if self.local_rank == 0: | |||||
| states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
| else: | |||||
| # 所有 rank 都需要调用 | |||||
| states = self.model.state_dict() | |||||
| if self.local_rank == 0: | |||||
| states = {key[len('model.'):]:value for key, value in states.items()} # 这里需要去掉那个 _wrap 的 key | |||||
| if self.local_rank == 0: # | |||||
| torch.save(states, filepath) | |||||
| elif self.fs_type == 'fsdp': | |||||
| raise RuntimeError("When fs_type='fsdp', only `only_state_dict=True` is allowed.") | |||||
| else: | |||||
| if self.local_rank == 0: | |||||
| torch.save(model, filepath) | |||||
| def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||||
| """ | |||||
| 从 folder 中加载权重并赋值到当前 driver 的模型上。 | |||||
| :param filepath: 加载权重或模型的路径 | |||||
:param only_state_dict: 加载的内容是否只是权重;
| :param kwargs: | |||||
| :return: | |||||
| """ | |||||
| states = torch.load(filepath, map_location='cpu') | |||||
| if isinstance(states, dict) and only_state_dict is False: | |||||
| logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use " | |||||
| f"`only_state_dict=True`") | |||||
| elif not isinstance(states, dict) and only_state_dict is True: | |||||
| logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use " | |||||
| f"`only_state_dict=False`") | |||||
| if not isinstance(states, Mapping): | |||||
| states = states.state_dict() | |||||
| if self.fs_type in ('ddp', 'sdp'): | |||||
| model = self.model.module.model | |||||
| else: | |||||
| model = self.model | |||||
states = {f'model.{k}': v for k, v in states.items()}
| model.load_state_dict(states) | |||||
| def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||||
| if self.fs_type == 'fsdp': | |||||
| if should_save_model is False: | |||||
| logger.warning("When save model using fs_type='fsdp', please make sure use " | |||||
| "`with trainer.driver.model.summon_full_params():` context to gather all parameters.") | |||||
| with all_rank_call_context(): | |||||
| super().save_checkpoint(folder=folder, states=states, dataloader=dataloader, only_state_dict=only_state_dict, | |||||
| should_save_model=should_save_model, **kwargs) | |||||
| else: | |||||
| super().save_checkpoint(folder=folder, states=states, dataloader=dataloader, | |||||
| only_state_dict=only_state_dict, should_save_model=should_save_model, **kwargs) | |||||
| def get_optimizer_state(self): | |||||
| optimizers_state_dict = {} | |||||
| for i in range(len(self.optimizers)): | |||||
| optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
| if self.fs_type == 'fsdp': | |||||
| optimizer_state = self.model.gather_full_optim_state_dict(optimizer) | |||||
elif self.fs_type == 'sdp':
optimizer.consolidate_state_dict(recipient_rank=0)
# OSS 只在 recipient_rank 上物化完整的 state,其它 rank 用空 dict 占位即可
optimizer_state = optimizer.state_dict() if self.local_rank == 0 else {}
| else: | |||||
| optimizer_state = optimizer.state_dict() | |||||
| if self.local_rank == 0: | |||||
| optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||||
| optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
| return optimizers_state_dict | |||||
| def load_optimizer_state(self, states): | |||||
| assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
| f"checkpoint it is:{len(states)}" | |||||
| for i in range(len(self.optimizers)): | |||||
| optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
| state = states[f'optimizer{i}'] | |||||
| if self.fs_type == 'fsdp': | |||||
| state = self.model.get_shard_from_optim_state_dict(state) | |||||
| optimizer.load_state_dict(state) | |||||
| logger.debug("Load optimizer state dict.") | |||||
| def unwrap_model(self): | |||||
| r""" | |||||
| :return: 返回原本的模型,例如没有被 ``DataParallel`` 包裹; | |||||
| """ | |||||
| return self.model.module.model | |||||
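To illustrate the new driver's surface, a hypothetical direct construction is sketched below; in practice the Trainer builds it when `driver='fairscale'` is selected (see `initialize_torch_driver` further down). It assumes fairscale is installed and the listed CUDA devices exist; the argument names follow the `__init__` shown above.

import torch
from torch import nn
from fastNLP.core.drivers.torch_driver import FairScaleDriver  # exported in this PR's __init__.py

model = nn.Linear(16, 4)  # placeholder model
driver = FairScaleDriver(
    model,
    parallel_device=[torch.device('cuda:0'), torch.device('cuda:1')],
    fp16=True,
    fairscale_kwargs={
        'fs_type': 'fsdp',                                  # one of 'ddp' / 'sdp' / 'fsdp'
        'fsdp_kwargs': {'state_dict_on_rank_0_only': True},
        'gradscaler_kwargs': {},                            # forwarded to ShardedGradScaler when fp16=True
    },
)
# driver.setup() then opens the process group and wraps the model according to fs_type.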
| @@ -1,63 +0,0 @@ | |||||
| from typing import List | |||||
| from fastNLP.envs.imports import _NEED_IMPORT_FAIRSCALE | |||||
| if _NEED_IMPORT_FAIRSCALE: | |||||
| import torch | |||||
| from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel | |||||
| from fairscale.optim import OSS | |||||
| __all__ = [ | |||||
| 'ShardedDriver' | |||||
| ] | |||||
| from .ddp import TorchDDPDriver | |||||
| # todo 注意 fairscale 现在几乎所有的功能都没有实现; | |||||
| # TODO:预跑前后对模型和 optimizers 的支持; | |||||
| # TODO:fairscale 的 fp16 额外的处理; | |||||
| class ShardedDriver(TorchDDPDriver): | |||||
| _REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M | |||||
| def __init__( | |||||
| self, | |||||
| model, | |||||
| parallel_device: List["torch.device"], | |||||
| num_nodes: int = 1, | |||||
| fp16: bool = False, | |||||
| **kwargs | |||||
| ): | |||||
| super(ShardedDriver, self).__init__( | |||||
| model=model, | |||||
| parallel_device=parallel_device, | |||||
| num_nodes=num_nodes, | |||||
| fp16=fp16, | |||||
| **kwargs | |||||
| ) | |||||
| def configure_ddp(self): | |||||
| if "reduce_buffer_size" not in self._ddp_kwargs: | |||||
| # For multi-node training, enabling bucketing will improve performance. | |||||
| self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0 | |||||
| self.optimizers = self._wrap_optimizers(self.optimizers) | |||||
| self.model = ShardedDataParallel(self.model, sharded_optimizer=self.optimizers, **self._ddp_kwargs) | |||||
| def _wrap_optimizers(self, optimizers) -> List["OSS"]: | |||||
| # TODO:之后得去研究一下 pytorch lightning 为什么这样写,我们是不是也需要这样写; | |||||
| # if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING: | |||||
| # return optimizers | |||||
| return self._reinit_optimizers_with_oss(optimizers) | |||||
| def _reinit_optimizers_with_oss(self, optimizers) -> List["OSS"]: | |||||
| for x, optimizer in enumerate(optimizers): | |||||
| if not isinstance(optimizer, OSS): | |||||
| optim_class = type(optimizer) | |||||
| zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) | |||||
| # TODO:具体细节见 pytorch lightning 的这一函数,主要的点在于加入 fp16 相关的一些东西; | |||||
| optimizers[x] = zero_optimizer | |||||
| del optimizer | |||||
| return optimizers | |||||
| @@ -7,11 +7,14 @@ if _NEED_IMPORT_TORCH: | |||||
| from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
| from .single_device import TorchSingleDriver | from .single_device import TorchSingleDriver | ||||
| from .ddp import TorchDDPDriver | from .ddp import TorchDDPDriver | ||||
| from .fairscale import FairScaleDriver | |||||
| from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
| from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | from fastNLP.envs import FASTNLP_BACKEND_LAUNCH | ||||
| from pkg_resources import parse_version | |||||
| __all__ = [] | __all__ = [] | ||||
| def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.device", int, List[int]]], | def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.device", int, List[int]]], | ||||
| model: "torch.nn.Module", **kwargs) -> TorchDriver: | model: "torch.nn.Module", **kwargs) -> TorchDriver: | ||||
| r""" | r""" | ||||
| @@ -23,13 +26,20 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
| :return: 返回一个 :class:`~fastNLP.core.TorchSingleDriver` 或 :class:`~fastNLP.core.TorchDDPDriver` 实例; | :return: 返回一个 :class:`~fastNLP.core.TorchSingleDriver` 或 :class:`~fastNLP.core.TorchDDPDriver` 实例; | ||||
| """ | """ | ||||
| if parse_version(torch.__version__) < parse_version('1.6'): | |||||
| raise RuntimeError(f"Pytorch(current version:{torch.__version__}) need to be older than 1.6.") | |||||
| # world_size 和 rank | # world_size 和 rank | ||||
| if FASTNLP_BACKEND_LAUNCH in os.environ: | if FASTNLP_BACKEND_LAUNCH in os.environ: | ||||
| if device is not None: | if device is not None: | ||||
| logger.rank_zero_warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | logger.rank_zero_warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | ||||
| "up your script. And we will directly get the local device via " | "up your script. And we will directly get the local device via " | ||||
| "`os.environ['LOCAL_RANK']`.", once=True) | "`os.environ['LOCAL_RANK']`.", once=True) | ||||
| return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | |||||
| if driver == 'fairscale': | |||||
| return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | |||||
| is_pull_by_torch_run=True, **kwargs) | |||||
| else: | |||||
| return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | |||||
| is_pull_by_torch_run=True, **kwargs) | |||||
| if driver not in {"torch", "fairscale"}: | if driver not in {"torch", "fairscale"}: | ||||
| raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") | raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") | ||||
| @@ -67,13 +77,10 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
| else: | else: | ||||
| return TorchDDPDriver(model, device, **kwargs) | return TorchDDPDriver(model, device, **kwargs) | ||||
| elif driver == "fairscale": | elif driver == "fairscale": | ||||
| raise NotImplementedError("`fairscale` is not support right now.") | |||||
| # if not isinstance(device, List): | |||||
| # if device.type == 'cpu': | |||||
| # raise ValueError("You are using `fairscale` driver, but your chosen `device` is 'cpu'.") | |||||
| # log.info("Notice you are using `fairscale` driver, but your chosen `device` is only one gpu, we will" | |||||
| # "still use `fairscale` for you, but if you mean using `TorchSingleDriver`, you should " | |||||
| # "choose `torch` driver.") | |||||
| # return ShardedDriver(model, [device], **kwargs) | |||||
| # else: | |||||
| # return ShardedDriver(model, device, **kwargs) | |||||
| if not isinstance(device, List): | |||||
| if device.type == 'cpu': | |||||
| raise ValueError("You are using `fairscale` driver, but your chosen `device` is 'cpu'.") | |||||
| logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") | |||||
| return FairScaleDriver(model, [device], **kwargs) | |||||
| else: | |||||
| return FairScaleDriver(model, device, **kwargs) | |||||
| @@ -1,7 +1,6 @@ | |||||
| import os | import os | ||||
| from typing import Union, Dict, Optional, Callable | from typing import Union, Dict, Optional, Callable | ||||
| from functools import partial | from functools import partial | ||||
| from pkg_resources import parse_version | |||||
| import numpy as np | import numpy as np | ||||
| import random | import random | ||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| @@ -52,23 +51,23 @@ class TorchDriver(Driver): | |||||
| super(TorchDriver, self).__init__(model) | super(TorchDriver, self).__init__(model) | ||||
| """ 进行 fp16 的设置 """ | """ 进行 fp16 的设置 """ | ||||
| self._torch_kwargs = kwargs.get("torch_kwargs", {}) | |||||
| # 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里; | # 因为 ddp 和 single_device 的混合精度训练的设置是一样的,因此可以统一抽象到这里; | ||||
| self.fp16 = fp16 | self.fp16 = fp16 | ||||
| if parse_version(torch.__version__) < parse_version('1.6'): | |||||
| raise RuntimeError(f"Pytorch({torch.__version__}) need to be older than 1.6.") | |||||
| self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||||
| self.grad_scaler = _grad_scaler() | |||||
| self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not self.fp16) | |||||
| self.grad_scaler = _grad_scaler(**self._torch_kwargs.get('gradscaler_kwargs', {})) | |||||
self.set_grad_to_none = self._torch_kwargs.get('set_grad_to_none', True)  # 默认 True,与原先 Trainer 中的默认行为一致
| self._torch_kwargs = kwargs.get("torch_kwargs", {}) | |||||
| # 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | # 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | ||||
| self.non_blocking = self._torch_kwargs.get("torch_non_blocking", True) | |||||
| self.non_blocking = self._torch_kwargs.get("non_blocking", True) | |||||
| # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | # 用来设置是否关闭 auto_param_call 中的参数匹配问题; | ||||
| self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | ||||
| def zero_grad(self, set_to_none: bool = False): | |||||
| def zero_grad(self): | |||||
| for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
| self._clear_grad(optimizer, set_to_none) | |||||
| self._clear_grad(optimizer, self.set_grad_to_none) | |||||
| def _clear_grad(self, optimizer, set_to_none): | def _clear_grad(self, optimizer, set_to_none): | ||||
| param_groups = optimizer.param_groups | param_groups = optimizer.param_groups | ||||
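Since `set_to_none` is no longer threaded through `zero_grad`'s signature but read from `torch_kwargs`, it is worth recalling what the flag actually changes; the snippet below is plain PyTorch and independent of fastNLP.

import torch

p = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([p], lr=0.1)

p.sum().backward()
opt.zero_grad(set_to_none=False)
print(p.grad)   # tensor([0., 0., 0.]) -- grad tensor kept and zero-filled

p.sum().backward()
opt.zero_grad(set_to_none=True)
print(p.grad)   # None -- grad tensor released; the next backward() re-allocates it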
| @@ -178,7 +177,7 @@ class TorchDriver(Driver): | |||||
| else: | else: | ||||
| torch.save(model, filepath) | torch.save(model, filepath) | ||||
| def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||||
| def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): | |||||
| """ | """ | ||||
| 从 folder 中加载权重并赋值到当前 driver 的模型上。 | 从 folder 中加载权重并赋值到当前 driver 的模型上。 | ||||
| @@ -195,10 +194,9 @@ class TorchDriver(Driver): | |||||
| elif not isinstance(res, dict) and only_state_dict is True: | elif not isinstance(res, dict) and only_state_dict is True: | ||||
| logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use " | logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use " | ||||
| f"`only_state_dict=False`") | f"`only_state_dict=False`") | ||||
| if only_state_dict: | |||||
| model.load_state_dict(res) | |||||
| else: | |||||
| model.load_state_dict(res.state_dict()) | |||||
| if not isinstance(res, dict): | |||||
| res = res.state_dict() | |||||
| model.load_state_dict(res) | |||||
| @rank_zero_call | @rank_zero_call | ||||
| def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | ||||
| @@ -246,25 +244,13 @@ class TorchDriver(Driver): | |||||
| # 2. 保存模型的状态; | # 2. 保存模型的状态; | ||||
| if should_save_model: | if should_save_model: | ||||
| model = self.unwrap_model() | |||||
| if not os.path.exists(folder): | if not os.path.exists(folder): | ||||
| os.mkdir(folder) | os.mkdir(folder) | ||||
| if only_state_dict: | |||||
| model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||||
| # 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; | |||||
| torch.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME)) | |||||
| logger.debug("Save model state dict") | |||||
| else: | |||||
| torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME)) | |||||
| logger.debug("Save model") | |||||
| model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) | |||||
| self.save_model(model_path, only_state_dict=only_state_dict) | |||||
| # 3. 保存 optimizers 的状态; | # 3. 保存 optimizers 的状态; | ||||
| optimizers_state_dict = {} | |||||
| for i in range(len(self.optimizers)): | |||||
| optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
| optimizer_state = optimizer.state_dict() | |||||
| optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||||
| optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
| optimizers_state_dict = self.get_optimizer_state() | |||||
| # 4. 保存fp16的状态 | # 4. 保存fp16的状态 | ||||
| if not isinstance(self.grad_scaler, DummyGradScaler): | if not isinstance(self.grad_scaler, DummyGradScaler): | ||||
| @@ -275,38 +261,42 @@ class TorchDriver(Driver): | |||||
| states["optimizers_state_dict"] = optimizers_state_dict | states["optimizers_state_dict"] = optimizers_state_dict | ||||
| torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
| def get_optimizer_state(self): | |||||
| optimizers_state_dict = {} | |||||
| for i in range(len(self.optimizers)): | |||||
| optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
| optimizer_state = optimizer.state_dict() | |||||
| optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) | |||||
| optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||||
| return optimizers_state_dict | |||||
| def load_optimizer_state(self, states): | |||||
| assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ | |||||
| f"checkpoint it is:{len(states)}" | |||||
| for i in range(len(self.optimizers)): | |||||
| optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
| optimizer.load_state_dict(states[f"optimizer{i}"]) | |||||
| logger.debug("Load optimizer state dict.") | |||||
| def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | ||||
| states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | ||||
| # 1. 加载 optimizers 的状态; | # 1. 加载 optimizers 的状态; | ||||
| optimizers_state_dict = states.pop("optimizers_state_dict") | optimizers_state_dict = states.pop("optimizers_state_dict") | ||||
| for i in range(len(self.optimizers)): | |||||
| optimizer: torch.optim.Optimizer = self.optimizers[i] | |||||
| optimizer.load_state_dict(optimizers_state_dict[f"optimizer{i}"]) | |||||
| logger.debug("Load optimizer state dict.") | |||||
| self.load_optimizer_state(optimizers_state_dict) | |||||
| # 2. 加载模型状态; | # 2. 加载模型状态; | ||||
| if should_load_model: | if should_load_model: | ||||
| model = self.unwrap_model() | |||||
| res = torch.load(folder.joinpath(FASTNLP_MODEL_FILENAME), map_location='cpu') | |||||
| if only_state_dict: | |||||
| model.load_state_dict(res) | |||||
| logger.debug("Load model state dict...") | |||||
| else: | |||||
| model.load_state_dict(res.state_dict()) | |||||
| logger.debug("Load model...") | |||||
| self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) | |||||
| # 3. 加载fp16的状态 | # 3. 加载fp16的状态 | ||||
| if "grad_scaler_state_dict" in states: | if "grad_scaler_state_dict" in states: | ||||
| grad_scaler_state_dict = states.pop("grad_scaler_state_dict") | grad_scaler_state_dict = states.pop("grad_scaler_state_dict") | ||||
| if isinstance(self.grad_scaler, DummyGradScaler): | |||||
| self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False) | |||||
| self.grad_scaler = _grad_scaler() | |||||
| self.fp16 = True | |||||
| self.grad_scaler.load_state_dict(grad_scaler_state_dict) | |||||
| logger.debug("Load grad_scaler state dict...") | |||||
| if not isinstance(self.grad_scaler, DummyGradScaler): | |||||
| self.grad_scaler.load_state_dict(grad_scaler_state_dict) | |||||
| logger.debug("Load grad_scaler state dict...") | |||||
| elif not isinstance(self.grad_scaler, DummyGradScaler): | elif not isinstance(self.grad_scaler, DummyGradScaler): | ||||
| logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||||
| logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " | |||||
| f"the training process may be unstable.") | f"the training process may be unstable.") | ||||
| # 4. 恢复 sampler 的状态; | # 4. 恢复 sampler 的状态; | ||||
| @@ -153,7 +153,7 @@ class ClassifyFPreRecMetric(Metric): | |||||
| f"size:{pred.shape}, target should have size: {pred.shape} or " | f"size:{pred.shape}, target should have size: {pred.shape} or " | ||||
| f"{pred.shape[:-1]}, got {target.shape}.") | f"{pred.shape[:-1]}, got {target.shape}.") | ||||
| target_idxes = set(target.reshape(-1).tolist()) | |||||
| target_idxes = set(target.reshape(-1).tolist()+pred.reshape(-1).tolist()) | |||||
| for target_idx in target_idxes: | for target_idx in target_idxes: | ||||
| self._tp[target_idx] += ((pred == target_idx) * (target == target_idx) * masks).sum().item() | self._tp[target_idx] += ((pred == target_idx) * (target == target_idx) * masks).sum().item() | ||||
| self._fp[target_idx] += ((pred == target_idx) * (target != target_idx) * masks).sum().item() | self._fp[target_idx] += ((pred == target_idx) * (target != target_idx) * masks).sum().item() | ||||
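The union above presumably exists so that a class which is predicted but never occurs in `target` still collects `_fp` counts; with the old `target`-only set such predictions were simply skipped. A small standalone illustration of the two index sets (plain torch, not the metric itself):

import torch

pred = torch.tensor([0, 0, 2, 2])     # class 2 is predicted ...
target = torch.tensor([0, 0, 0, 0])   # ... but never appears in target
old_idxes = set(target.reshape(-1).tolist())
new_idxes = set(target.reshape(-1).tolist() + pred.reshape(-1).tolist())
print(old_idxes)  # {0}    -> the two wrong predictions of class 2 were invisible
print(new_idxes)  # {0, 2} -> class 2 is now iterated and collects _fp == 2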
| @@ -227,7 +227,7 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): | |||||
| raise e | raise e | ||||
| def check_user_specific_params(user_params: Dict, fn: Callable): | |||||
| def check_user_specific_params(user_params: Dict, fn: Callable, fn_name=None): | |||||
| """ | """ | ||||
| 该函数使用用户的输入来对指定函数的参数进行赋值,主要用于一些用户无法直接调用函数的情况; | 该函数使用用户的输入来对指定函数的参数进行赋值,主要用于一些用户无法直接调用函数的情况; | ||||
| 主要作用在于帮助检查用户对使用函数 ``fn`` 的参数输入是否有误; | 主要作用在于帮助检查用户对使用函数 ``fn`` 的参数输入是否有误; | ||||
| @@ -235,13 +235,16 @@ def check_user_specific_params(user_params: Dict, fn: Callable): | |||||
| :param user_params: 用户指定的参数的值,应当是一个字典,其中 ``key`` 表示每一个参数的名字, | :param user_params: 用户指定的参数的值,应当是一个字典,其中 ``key`` 表示每一个参数的名字, | ||||
| ``value`` 为每一个参数的值; | ``value`` 为每一个参数的值; | ||||
| :param fn: 将要被调用的函数; | :param fn: 将要被调用的函数; | ||||
:param fn_name: 在打印提示信息时如何显示函数名;
| :return: 返回一个字典,其中为在之后调用函数 ``fn`` 时真正会被传进去的参数的值; | :return: 返回一个字典,其中为在之后调用函数 ``fn`` 时真正会被传进去的参数的值; | ||||
| """ | """ | ||||
| if fn_name is None: | |||||
| fn_name = fn.__name__ | |||||
| fn_arg_names = get_fn_arg_names(fn) | fn_arg_names = get_fn_arg_names(fn) | ||||
| for arg_name, arg_value in user_params.items(): | for arg_name, arg_value in user_params.items(): | ||||
| if arg_name not in fn_arg_names: | if arg_name not in fn_arg_names: | ||||
| logger.rank_zero_warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") | |||||
| logger.rank_zero_warning(f"Notice parameter `{arg_name}` may not be used by `{fn_name}`.") | |||||
| return user_params | return user_params | ||||
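A self-contained re-creation of the idea (not the fastNLP implementation) showing why the extra `fn_name` argument helps: when `fn` is a class's `__init__`, `fn.__name__` is just `__init__`, which makes the warning hard to attribute.

import inspect

def check_kwargs(user_params, fn, fn_name=None):
    fn_name = fn_name or fn.__name__
    valid = set(inspect.signature(fn).parameters)
    for name in user_params:
        if name not in valid:
            print(f"Notice parameter `{name}` may not be used by `{fn_name}`.")
    return user_params

class DataParallelStub:  # stand-in for e.g. DistributedDataParallel
    def __init__(self, module=None, broadcast_buffers=True): ...

check_kwargs({'broadcast_buffer': False}, DataParallelStub.__init__, DataParallelStub.__name__)
# -> Notice parameter `broadcast_buffer` may not be used by `DataParallelStub`.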
| @@ -18,7 +18,7 @@ else: | |||||
| _IS_WINDOWS = platform.system() == "Windows" | _IS_WINDOWS = platform.system() == "Windows" | ||||
| _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale.nn") and 'torch' in need_import | |||||
| _NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale") and 'torch' in need_import | |||||
| _NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import | _NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import | ||||
| _NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import | _NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import | ||||
| _NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import | _NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import | ||||
| @@ -277,13 +277,12 @@ def test_trainer_specific_params_1( | |||||
| model_wo_auto_param_call=True, | model_wo_auto_param_call=True, | ||||
| torch_kwargs={ | torch_kwargs={ | ||||
| "torch_non_blocking": False, | |||||
| "non_blocking": False, | |||||
| "set_grad_to_none": True | "set_grad_to_none": True | ||||
| } | } | ||||
| ) | ) | ||||
| assert trainer.set_grad_to_none is True | |||||
| assert trainer.driver.non_blocking is False | assert trainer.driver.non_blocking is False | ||||
| assert trainer.driver.wo_auto_param_call is True | assert trainer.driver.wo_auto_param_call is True | ||||
| @@ -320,13 +319,11 @@ def test_trainer_specific_params_2( | |||||
| "broadcast_buffers": True, | "broadcast_buffers": True, | ||||
| "find_unused_parameters": True | "find_unused_parameters": True | ||||
| }, | }, | ||||
| "torch_non_blocking": False, | |||||
| "set_grad_to_none": True | |||||
| "non_blocking": False, | |||||
| } | } | ||||
| ) | ) | ||||
| assert trainer.set_grad_to_none is True | |||||
| assert trainer.driver.non_blocking is False | assert trainer.driver.non_blocking is False | ||||
| assert trainer.driver.wo_auto_param_call is True | assert trainer.driver.wo_auto_param_call is True | ||||
| assert trainer.driver.output_from_new_proc == "all" | assert trainer.driver.output_from_new_proc == "all" | ||||
| @@ -682,7 +682,7 @@ class TestSaveLoad: | |||||
| # 3. 检查 fp16 是否被加载 | # 3. 检查 fp16 是否被加载 | ||||
| if fp16: | if fp16: | ||||
| assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
| assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
| # 4. 检查 model 的参数是否正确 | # 4. 检查 model 的参数是否正确 | ||||
| # 5. 检查 batch_idx | # 5. 检查 batch_idx | ||||
| @@ -731,7 +731,7 @@ class TestSaveLoad: | |||||
| """ | """ | ||||
| try: | try: | ||||
| path = "model.ckp" | |||||
| path = "checkpoints/" | |||||
| num_replicas = len(device) | num_replicas = len(device) | ||||
| @@ -764,6 +764,7 @@ class TestSaveLoad: | |||||
| driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | ||||
| else: | else: | ||||
| driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | ||||
| dist.barrier() # 等待save成功 | |||||
| # 加载 | # 加载 | ||||
| # 更改 batch_size | # 更改 batch_size | ||||
| dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | ||||
| @@ -788,7 +789,7 @@ class TestSaveLoad: | |||||
| assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | ||||
| # 3. 检查 fp16 是否被加载 | # 3. 检查 fp16 是否被加载 | ||||
| if fp16: | if fp16: | ||||
| assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
| assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
| # 4. 检查 model 的参数是否正确 | # 4. 检查 model 的参数是否正确 | ||||
| # 5. 检查 batch_idx | # 5. 检查 batch_idx | ||||
| @@ -617,7 +617,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
| # 3. 检查 fp16 是否被加载 | # 3. 检查 fp16 是否被加载 | ||||
| if fp16: | if fp16: | ||||
| assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
| assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
| # 4. 检查 model 的参数是否正确 | # 4. 检查 model 的参数是否正确 | ||||
| # 5. 检查 batch_idx | # 5. 检查 batch_idx | ||||
| @@ -689,7 +689,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
| # 3. 检查 fp16 是否被加载 | # 3. 检查 fp16 是否被加载 | ||||
| if fp16: | if fp16: | ||||
| assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
| assert not isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
| # 4. 检查 model 的参数是否正确 | # 4. 检查 model 的参数是否正确 | ||||
| # 5. 检查 batch_idx | # 5. 检查 batch_idx | ||||
| @@ -195,3 +195,21 @@ class TestClassfiyFPreRecMetric: | |||||
| pool.close() | pool.close() | ||||
| pool.join() | pool.join() | ||||
| def test_binary(self): | |||||
| pred = torch.randn(10, 2) | |||||
| target = torch.randint(1, size=(10,)) | |||||
| metric = ClassifyFPreRecMetric() | |||||
| metric.update(pred, target) | |||||
| results = metric.get_metric() | |||||
| print(target) | |||||
| print(metric._tp, metric._fp, metric._fn) | |||||
| assert results['f']==results['rec']==results['pre'] | |||||
| pred = torch.randn(10, 2) | |||||
| target = torch.randint(2, size=(10,)) | |||||
| metric = ClassifyFPreRecMetric() | |||||
| metric.update(pred, target) | |||||
| results = metric.get_metric() | |||||
| print(target) | |||||
| print(metric._tp, metric._fp, metric._fn) | |||||
| assert results['f']==results['rec']==results['pre'] | |||||
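The assertion relies on an identity that holds for single-label classification under micro averaging (assumed to be the metric's default): every misclassified sample contributes exactly one FP and one FN, so summed over all classes FP == FN and precision == recall == f. A standalone check of that identity:

import torch

pred = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 0, 1, 1])
classes = set(target.tolist() + pred.tolist())
tp = sum(((pred == c) & (target == c)).sum().item() for c in classes)
fp = sum(((pred == c) & (target != c)).sum().item() for c in classes)
fn = sum(((pred != c) & (target == c)).sum().item() for c in classes)
assert fp == fn
print(tp / (tp + fp), tp / (tp + fn))  # identical, hence f == pre == rec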