diff --git a/fastNLP/core/drivers/torch_driver/deepspeed.py b/fastNLP/core/drivers/torch_driver/deepspeed.py
index 79451b13..579a50f4 100644
--- a/fastNLP/core/drivers/torch_driver/deepspeed.py
+++ b/fastNLP/core/drivers/torch_driver/deepspeed.py
@@ -14,12 +14,13 @@ from fastNLP.envs import(
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED
 
 if _NEED_IMPORT_TORCH:
-    import pytorch_lightning
     import torch
     import torch.distributed as dist
+    from torch.optim import Optimizer
 
 if _NEED_IMPORT_DEEPSPEED:
     import deepspeed
+    from deepspeed import DeepSpeedEngine, DeepSpeedOptimizer
 
 __all__ = [
     "DeepSpeedDriver",
@@ -33,7 +34,6 @@ class DeepSpeedDriver(TorchDDPDriver):
         parallel_device: Union[List["torch.device"], "torch.device"],
         is_pull_by_torch_run = False,
         fp16: bool = False,
-        strategy= "deepspeed",
         **kwargs
     ):
         assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported."
@@ -56,8 +56,22 @@ class DeepSpeedDriver(TorchDDPDriver):
             # our model_device is always a torch.device, never a list;
             self.model_device = parallel_device[self.local_rank]
 
-        # for now, initializing deepspeed outside the driver is not allowed
+        # the user may have initialized deepspeed by themselves;
         self.outside_ddp = False
+        if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
+                "fastnlp_torch_launch_not_ddp" not in os.environ:
+            # if the user initialized deepspeed by themselves, the model they pass in must already be
+            # wrapped by DeepSpeedEngine;
+            if not isinstance(model, DeepSpeedEngine):
+                raise RuntimeError(
+                    "It is not allowed to input a normal model instead of `DeepSpeedEngine` when "
+                    "you initialize the ddp process out of our control.")
+
+            self.outside_ddp = True
+            self.config = model.config
+            # the model can only be wrapped with DistributedDataParallel after it has been moved to the
+            # corresponding machine, so if the user initialized DDP outside, TorchDDPDriver simply sets
+            # model_device to None;
+            self.model_device = None
+
         self._data_device = kwargs.get("data_device", None)
         if isinstance(self._data_device, int):
             if self._data_device < 0:
@@ -84,7 +98,6 @@ class DeepSpeedDriver(TorchDDPDriver):
         self._has_setup = False  # set because the evaluator also runs setup, which is neither needed nor wanted there;
         self._has_ddpwrapped = False  # whether the model passed in has already been ddp-wrapped;
-        self.strategy = strategy
         self.accumulation_steps = kwargs.get("accumulation_steps", 1)
         # read batch_size in order to set the train_micro_batch_size_per_gpu parameter
         train_dl = kwargs.get("train_dataloader", None)
@@ -96,6 +109,14 @@ class DeepSpeedDriver(TorchDDPDriver):
             self.train_micro_batch_size = 1
 
         self._ds_kwargs = kwargs.get("deepspeed_kwargs", {})
+        self.strategy = self._ds_kwargs.get("strategy", "deepspeed")
+
+    @staticmethod
+    def _check_optimizer_legality(optimizers):
+        for each_optimizer in optimizers:
+            if not isinstance(each_optimizer, (Optimizer, DeepSpeedOptimizer)):
+                raise TypeError(f"Each optimizer of parameter `optimizers` should be of type 'Optimizer' or "
+                                f"'DeepSpeedOptimizer', not {type(each_optimizer)}.")
 
     def setup(self):
         r"""
@@ -112,15 +133,19 @@ class DeepSpeedDriver(TorchDDPDriver):
         self.setup_config()
         # if the user wants multi-machine mode, execution necessarily enters here;
         if self.is_pull_by_torch_run:
-            # dist.get_world_size() can only be called after dist.init_process_group has run;
-            self.world_size = int(os.environ.get("WORLD_SIZE"))
-            self.global_rank = int(os.environ.get("RANK"))
-            logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
+            if self.outside_ddp:
+                self.world_size = dist.get_world_size()
+                self.global_rank = dist.get_rank()
+            else:
+                # dist.get_world_size() can only be called after dist.init_process_group has run;
+                self.world_size = int(os.environ.get("WORLD_SIZE"))
+                self.global_rank = int(os.environ.get("RANK"))
+                logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
 
-            if not dist.is_initialized():
-                deepspeed.init_distributed("nccl", distributed_port=self.master_port)
+                if not dist.is_initialized():
+                    deepspeed.init_distributed("nccl", distributed_port=self.master_port)
 
-            os.environ["fastnlp_torch_launch_not_ddp"] = "yes"
+                os.environ["fastnlp_torch_launch_not_ddp"] = "yes"
 
         # when execution reaches this point:
         # dist.is_initialized is necessarily False;
@@ -146,8 +171,9 @@ class DeepSpeedDriver(TorchDDPDriver):
             self.world_size = dist.get_world_size()
             self.global_rank = dist.get_rank()
 
-        torch.cuda.set_device(self.model_device)
-        self.configure_ddp()
+        if not self.outside_ddp:
+            torch.cuda.set_device(self.model_device)
+            self.configure_ddp()
 
         self.barrier()
         # initialize self._pids so that every process can receive rank 0's send operation;
@@ -166,7 +192,7 @@ class DeepSpeedDriver(TorchDDPDriver):
 
     def configure_ddp(self):
         # set up deepspeed
-        if not isinstance(self.model, deepspeed.DeepSpeedEngine):
+        if not isinstance(self.model, DeepSpeedEngine):
             model=_DeepSpeedWrappingModel(self.model, self.fp16)
             model_parameters = filter(lambda p: p.requires_grad, model.parameters())
             self.model, ds_optimizer, _, _ = deepspeed.initialize(
@@ -193,7 +219,6 @@ class DeepSpeedDriver(TorchDDPDriver):
 
         self.config = self._ds_kwargs.get("config")
         if self.config is not None:
-            # TODO decide which parameters come from the config and which from the trainer arguments
             logger.warn("Notice that you have defined a configuration for deepspeed and parameters like "
                         "`optimizers`, `strategy` and `fp16` may not take effect.")
             return
@@ -258,12 +283,6 @@ class DeepSpeedDriver(TorchDDPDriver):
     def step(self):
         self.model.step()
 
-    def unwrap_model(self):
-        r"""
-        :return: the original model;
-        """
-        return self.model.module.model
-
     def get_model_no_sync_context(self):
         r"""
         :return: a ``context`` manager used to disable synchronization between processes; with ``deepspeed``, this returns an empty context
diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
index 5d4d2ab5..b0a16112 100644
--- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
@@ -38,6 +38,9 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
         if driver == 'fairscale':
             return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
                                    is_pull_by_torch_run=True, **kwargs)
+        elif kwargs.get("deepspeed_kwargs") is not None:
+            return DeepSpeedDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
+                                   is_pull_by_torch_run=True, **kwargs)
         else:
             return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
                                   is_pull_by_torch_run=True, **kwargs)
@@ -73,6 +76,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
         raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
 
     if driver == "torch":  # single, ddp, launched directly.
+        if kwargs.get("deepspeed_kwargs") is not None:
+            # deepspeed was selected
+            if not isinstance(device, List):
+                if device.type == 'cpu':
+                    raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
+                logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
+                return DeepSpeedDriver(model, [device], **kwargs)
+            return DeepSpeedDriver(model, device, **kwargs)
         if not isinstance(device, List):
             return TorchSingleDriver(model, device, **kwargs)
         else:
@@ -84,11 +95,4 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
             logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.")
             return FairScaleDriver(model, [device], **kwargs)
         else:
-            return FairScaleDriver(model, device, **kwargs)
-    elif driver == "deepspeed":
-        if not isinstance(device, List):
-            if device.type == 'cpu':
-                raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
-            logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
-            return DeepSpeedDriver(model, [device], **kwargs)
-        return DeepSpeedDriver(model, device, **kwargs)
\ No newline at end of file
+            return FairScaleDriver(model, device, **kwargs)
\ No newline at end of file
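With the two hunks above, DeepSpeed is no longer selected through a dedicated `driver == "deepspeed"` branch: the caller keeps `driver="torch"`, and the presence of a `deepspeed_kwargs` argument is what routes dispatch to `DeepSpeedDriver`, both on the torchrun path and on the direct-launch path. A minimal sketch of the new entry point, mirroring the Trainer arguments used by the tests added below (the model, optimizer and dataloader names are placeholders for fastNLP-compatible objects):

    trainer = Trainer(
        model=model,                        # placeholder: a fastNLP-compatible torch.nn.Module
        driver="torch",                     # stays "torch"; deepspeed_kwargs does the routing
        device=[0, 1],                      # a cpu device raises ValueError for deepspeed
        optimizers=optimizer,               # placeholder: a torch.optim.Optimizer
        train_dataloader=train_dataloader,  # placeholder: a torch DataLoader
        deepspeed_kwargs={
            "strategy": "deepspeed",        # read back by the driver as self.strategy
            "config": None,                 # or a full DeepSpeed config dict
        },
    )
    trainer.run()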
diff --git a/tests/core/controllers/_test_trainer_deepspeed.py b/tests/core/controllers/_test_trainer_deepspeed.py
new file mode 100644
index 00000000..2dc6326c
--- /dev/null
+++ b/tests/core/controllers/_test_trainer_deepspeed.py
@@ -0,0 +1,95 @@
+"""
+This file tests deepspeed in the multi-GPU setting::
+
+    >>> # test launching on multiple GPUs directly
+    >>> python _test_trainer_deepspeed.py
+    >>> # test launching via deepspeed
+    >>> deepspeed _test_trainer_deepspeed.py
+
+"""
+import sys
+sys.path.append("../../../")
+from dataclasses import dataclass
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.metrics.accuracy import Accuracy
+from fastNLP.core.callbacks.progress_callback import RichCallback
+
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+
+from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+from tests.helpers.datasets.torch_data import TorchArgMaxDataset
+
+@dataclass
+class TrainDeepSpeedConfig:
+    num_labels: int = 3
+    feature_dimension: int = 3
+
+    batch_size: int = 2
+    shuffle: bool = True
+    evaluate_every = 2
+
+def test_trainer_deepspeed(
+    device,
+    callbacks,
+    strategy,
+    config,
+    n_epochs=2,
+):
+    model = TorchNormalModel_Classification_1(
+        num_labels=TrainDeepSpeedConfig.num_labels,
+        feature_dimension=TrainDeepSpeedConfig.feature_dimension
+    )
+    optimizers = Adam(params=model.parameters(), lr=0.0001)
+    train_dataloader = DataLoader(
+        dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20),
+        batch_size=TrainDeepSpeedConfig.batch_size,
+        shuffle=True
+    )
+    val_dataloader = DataLoader(
+        dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12),
+        batch_size=TrainDeepSpeedConfig.batch_size,
+        shuffle=True
+    )
+    train_dataloader = train_dataloader
+    evaluate_dataloaders = val_dataloader
+    evaluate_every = TrainDeepSpeedConfig.evaluate_every
+    metrics = {"acc": Accuracy()}
+    if config is not None:
+        config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size
+    trainer = Trainer(
+        model=model,
+        driver="torch",
+        device=device,
+        optimizers=optimizers,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders=evaluate_dataloaders,
+        evaluate_every=evaluate_every,
+        metrics=metrics,
+        output_mapping={"preds": "pred"},
+
+        n_epochs=n_epochs,
+        callbacks=callbacks,
+        deepspeed_kwargs={
+            "strategy": strategy,
+            "config": config
+        }
+    )
+    trainer.run()
+
+if __name__ == "__main__":
+    device = [0,1]
+    # device = [0,1,3]
+    callbacks = [
+        # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
+        RichCallback(5),
+    ]
+    config = None
+    test_trainer_deepspeed(
+        device=device,
+        callbacks=callbacks,
+        strategy="deepspeed",
+        config=config,
+        n_epochs=5,
+    )
\ No newline at end of file
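When `deepspeed_kwargs["config"]` is supplied, `setup_config` takes it verbatim and warns that `optimizers`, `strategy` and `fp16` may not take effect. For illustration, a hand-written stage-2 config could look like the sketch below; the key names follow DeepSpeed's documented JSON schema, and whether `_create_default_config(stage=2)` (used in the next test file) emits exactly these entries is an assumption:

    # illustrative DeepSpeed config; keys follow DeepSpeed's documented schema,
    # not necessarily the exact output of _create_default_config
    config = {
        "train_micro_batch_size_per_gpu": 2,  # the tests overwrite this with the dataloader batch_size
        "gradient_accumulation_steps": 1,
        "zero_optimization": {"stage": 2},
    }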
+""" +这个文件测试多卡情况下使用 deepspeed ,且用户自己调用了 deepspeed.initialize 的情况:: + + >>> deepspeed _test_trainer_deepspeed_outside.py + +""" +import os +import sys +sys.path.append("../../../") +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.core.callbacks.progress_callback import RichCallback +from fastNLP.core.drivers.torch_driver.utils import _create_default_config + +import deepspeed +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader + + +from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 +from tests.helpers.datasets.torch_data import TorchArgMaxDataset + +local_rank = int(os.environ["LOCAL_RANK"]) + +@dataclass +class TrainDeepSpeedConfig: + num_labels: int = 3 + feature_dimension: int = 3 + + batch_size: int = 2 + shuffle: bool = True + evaluate_every = 2 + +def test_trainer_deepspeed( + device, + callbacks, + strategy, + config, + n_epochs=2, +): + model = TorchNormalModel_Classification_2( + num_labels=TrainDeepSpeedConfig.num_labels, + feature_dimension=TrainDeepSpeedConfig.feature_dimension + ) + optimizers = Adam(params=model.parameters(), lr=0.0001) + train_dataloader = DataLoader( + dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20), + batch_size=TrainDeepSpeedConfig.batch_size, + shuffle=True + ) + val_dataloader = DataLoader( + dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12), + batch_size=TrainDeepSpeedConfig.batch_size, + shuffle=True + ) + train_dataloader = train_dataloader + evaluate_dataloaders = val_dataloader + evaluate_every = TrainDeepSpeedConfig.evaluate_every + metrics = {"acc": Accuracy()} + if config is not None: + config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size + model, optimizers, _, _ = deepspeed.initialize( + model=model, + optimizer=optimizers, + config=config, + ) + trainer = Trainer( + model=model, + driver="torch", + device=device, + data_device=torch.device(f"cuda:{local_rank}"), + optimizers=optimizers, + train_dataloader=train_dataloader, + evaluate_dataloaders=evaluate_dataloaders, + evaluate_every=evaluate_every, + metrics=metrics, + output_mapping={"preds": "pred"}, + + n_epochs=n_epochs, + callbacks=callbacks, + deepspeed_kwargs={ + "strategy": strategy, + "config": config + } + ) + trainer.run() + +if __name__ == "__main__": + device = [0,1] + # device = [0,1,3] + callbacks = [ + # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), + RichCallback(5), + ] + config = _create_default_config(stage=2) + test_trainer_deepspeed( + device=device, + callbacks=callbacks, + strategy="deepspeed", + config=config, + n_epochs=5, + ) \ No newline at end of file diff --git a/tests/core/controllers/test_trainer_deepspeed.py b/tests/core/controllers/test_trainer_deepspeed.py index e69de29b..c718e01d 100644 --- a/tests/core/controllers/test_trainer_deepspeed.py +++ b/tests/core/controllers/test_trainer_deepspeed.py @@ -0,0 +1,99 @@ +import pytest +from dataclasses import dataclass + +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.metrics.accuracy import Accuracy +from fastNLP.core.callbacks.progress_callback import RichCallback +from fastNLP.core.drivers.torch_driver import DeepSpeedDriver +from fastNLP.core.drivers.torch_driver.utils import _create_default_config +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + from torch.optim import Adam 
diff --git a/tests/core/controllers/test_trainer_deepspeed.py b/tests/core/controllers/test_trainer_deepspeed.py
index e69de29b..c718e01d 100644
--- a/tests/core/controllers/test_trainer_deepspeed.py
+++ b/tests/core/controllers/test_trainer_deepspeed.py
@@ -0,0 +1,99 @@
+import pytest
+from dataclasses import dataclass
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.metrics.accuracy import Accuracy
+from fastNLP.core.callbacks.progress_callback import RichCallback
+from fastNLP.core.drivers.torch_driver import DeepSpeedDriver
+from fastNLP.core.drivers.torch_driver.utils import _create_default_config
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.optim import Adam
+    from torch.utils.data import DataLoader
+
+
+from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+from tests.helpers.datasets.torch_data import TorchArgMaxDataset
+from tests.helpers.utils import magic_argv_env_context
+
+@dataclass
+class TrainDeepSpeedConfig:
+    num_labels: int = 3
+    feature_dimension: int = 3
+
+    batch_size: int = 2
+    shuffle: bool = True
+    evaluate_every = 2
+
+@pytest.mark.deepspeed
+class TestTrainer:
+    @classmethod
+    def setup_class(cls):
+        # without this initialization, every test case from the second one onward fails because of
+        # environment variables.
+        torch_model = TorchNormalModel_Classification_1(1, 1)
+        torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
+        device = [torch.device(i) for i in [0,1]]
+        driver = DeepSpeedDriver(
+            model=torch_model,
+            parallel_device=device,
+        )
+        driver.set_optimizers(torch_opt)
+        driver.setup()
+
+        return driver
+
+    @pytest.mark.parametrize("device", [[0, 1]])
+    @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
+    @pytest.mark.parametrize("strategy", ["deepspeed", "deepspeed_stage_1"])
+    @pytest.mark.parametrize("config", [None, _create_default_config(stage=1)])
+    @magic_argv_env_context
+    def test_trainer_deepspeed(
+        self,
+        device,
+        callbacks,
+        strategy,
+        config,
+        n_epochs=2,
+    ):
+        model = TorchNormalModel_Classification_1(
+            num_labels=TrainDeepSpeedConfig.num_labels,
+            feature_dimension=TrainDeepSpeedConfig.feature_dimension
+        )
+        optimizers = Adam(params=model.parameters(), lr=0.0001)
+        train_dataloader = DataLoader(
+            dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20),
+            batch_size=TrainDeepSpeedConfig.batch_size,
+            shuffle=True
+        )
+        val_dataloader = DataLoader(
+            dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12),
+            batch_size=TrainDeepSpeedConfig.batch_size,
+            shuffle=True
+        )
+        train_dataloader = train_dataloader
+        evaluate_dataloaders = val_dataloader
+        evaluate_every = TrainDeepSpeedConfig.evaluate_every
+        metrics = {"acc": Accuracy()}
+        if config is not None:
+            config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size
+        trainer = Trainer(
+            model=model,
+            driver="torch",
+            device=device,
+            optimizers=optimizers,
+            train_dataloader=train_dataloader,
+            evaluate_dataloaders=evaluate_dataloaders,
+            evaluate_every=evaluate_every,
+            metrics=metrics,
+            output_mapping={"preds": "pred"},
+
+            n_epochs=n_epochs,
+            callbacks=callbacks,
+            deepspeed_kwargs={
+                "strategy": strategy,
+                "config": config
+            }
+        )
+        trainer.run()
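The test class above is guarded by the `deepspeed` pytest marker and assumes two visible GPUs (`device=[0, 1]`); assuming the marker is registered in the project's pytest configuration, it can be run with::

    >>> pytest -m deepspeed tests/core/controllers/test_trainer_deepspeed.py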