@@ -14,12 +14,13 @@ from fastNLP.envs import(
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED
 
 if _NEED_IMPORT_TORCH:
-    import pytorch_lightning
     import torch
     import torch.distributed as dist
+    from torch.optim import Optimizer
 
 if _NEED_IMPORT_DEEPSPEED:
     import deepspeed
+    from deepspeed import DeepSpeedEngine, DeepSpeedOptimizer
 
 __all__ = [
     "DeepSpeedDriver",
@@ -33,7 +34,6 @@ class DeepSpeedDriver(TorchDDPDriver):
         parallel_device: Union[List["torch.device"], "torch.device"],
         is_pull_by_torch_run = False,
         fp16: bool = False,
-        strategy= "deepspeed",
         **kwargs
     ):
         assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported."
@@ -56,8 +56,22 @@ class DeepSpeedDriver(TorchDDPDriver):
         # our model_device is always a torch.device, never a list;
         self.model_device = parallel_device[self.local_rank]
 
-        # for now, initializing deepspeed outside the driver is not allowed
+        # in case the user initialized deepspeed outside the driver themselves;
         self.outside_ddp = False
+        if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \
+                "fastnlp_torch_launch_not_ddp" not in os.environ:
+            # if the user initialized deepspeed outside themselves, the model they pass in
+            # must already be wrapped by DeepSpeedEngine;
+            if not isinstance(model, DeepSpeedEngine):
+                raise RuntimeError(
+                    "It is not allowed to input a normal model instead of `DeepSpeedEngine` when "
+                    "you initialize the ddp process out of our control.")
+
+            self.outside_ddp = True
+            self.config = model.config
+            # the model can only be wrapped with DistributedDataParallel after it has been moved
+            # to the target device, so if the user initialized DDP outside, TorchDDPDriver simply
+            # sets model_device to None;
+            self.model_device = None
 
         self._data_device = kwargs.get("data_device", None)
         if isinstance(self._data_device, int):
             if self._data_device < 0:
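For context, the user-side initialization that this new branch expects looks roughly like the following sketch (it assumes a distributed environment set up by a launcher; the model, optimizer, and config are stand-ins)::

    import deepspeed
    import torch

    model = torch.nn.Linear(2, 2)                      # stand-in model
    optimizer = torch.optim.Adam(model.parameters())   # stand-in optimizer
    config = {"train_micro_batch_size_per_gpu": 2}     # minimal illustrative config

    # deepspeed.initialize returns a DeepSpeedEngine; passing `engine` (not the
    # raw model) to the driver is what satisfies the isinstance check above and
    # makes the driver take the outside_ddp path.
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config=config,
    )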
@@ -84,7 +98,6 @@ class DeepSpeedDriver(TorchDDPDriver):
         self._has_setup = False  # this flag exists because the evaluator also runs setup, which is clearly neither needed nor appropriate;
         self._has_ddpwrapped = False  # whether the model passed in has already been DDP-wrapped;
-        self.strategy = strategy
         self.accumulation_steps = kwargs.get("accumulation_steps", 1)
         # read the batch_size in order to set the train_micro_batch_size_per_gpu parameter
         train_dl = kwargs.get("train_dataloader", None)
@@ -96,6 +109,14 @@ class DeepSpeedDriver(TorchDDPDriver):
             self.train_micro_batch_size = 1
 
         self._ds_kwargs = kwargs.get("deepspeed_kwargs", {})
+        self.strategy = self._ds_kwargs.get("strategy", "deepspeed")
+
+    @staticmethod
+    def _check_optimizer_legality(optimizers):
+        for each_optimizer in optimizers:
+            if not isinstance(each_optimizer, (Optimizer, DeepSpeedOptimizer)):
+                raise TypeError(f"Each optimizer of parameter `optimizers` should be of "
+                                f"'Optimizer' or 'DeepSpeedOptimizer' type, not {type(each_optimizer)}.")
 
     def setup(self):
         r"""
@@ -112,15 +133,19 @@ class DeepSpeedDriver(TorchDDPDriver):
         self.setup_config()
         # if the user wants multi-machine mode, execution always enters this branch;
         if self.is_pull_by_torch_run:
-            # dist.get_world_size() can only be called after dist.init_process_group has been initialized;
-            self.world_size = int(os.environ.get("WORLD_SIZE"))
-            self.global_rank = int(os.environ.get("RANK"))
-            logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
-
-            if not dist.is_initialized():
-                deepspeed.init_distributed("nccl", distributed_port=self.master_port)
-
-                os.environ["fastnlp_torch_launch_not_ddp"] = "yes"
+            if self.outside_ddp:
+                self.world_size = dist.get_world_size()
+                self.global_rank = dist.get_rank()
+            else:
+                # dist.get_world_size() can only be called after dist.init_process_group has been initialized;
+                self.world_size = int(os.environ.get("WORLD_SIZE"))
+                self.global_rank = int(os.environ.get("RANK"))
+                logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}")
+
+                if not dist.is_initialized():
+                    deepspeed.init_distributed("nccl", distributed_port=self.master_port)
+
+                    os.environ["fastnlp_torch_launch_not_ddp"] = "yes"
 
         # the cases in which execution reaches this point:
         # dist.is_initialized is always False;
@@ -146,8 +171,9 @@ class DeepSpeedDriver(TorchDDPDriver):
             self.world_size = dist.get_world_size()
             self.global_rank = dist.get_rank()
 
-        torch.cuda.set_device(self.model_device)
-        self.configure_ddp()
+        if not self.outside_ddp:
+            torch.cuda.set_device(self.model_device)
+            self.configure_ddp()
 
         self.barrier()
         # initialize self._pids so that every process can receive the send operation from rank 0;
@@ -166,7 +192,7 @@ class DeepSpeedDriver(TorchDDPDriver):
 
     def configure_ddp(self):
         # set up deepspeed
-        if not isinstance(self.model, deepspeed.DeepSpeedEngine):
+        if not isinstance(self.model, DeepSpeedEngine):
             model = _DeepSpeedWrappingModel(self.model, self.fp16)
             model_parameters = filter(lambda p: p.requires_grad, model.parameters())
             self.model, ds_optimizer, _, _ = deepspeed.initialize(
@@ -193,7 +219,6 @@ class DeepSpeedDriver(TorchDDPDriver):
         self.config = self._ds_kwargs.get("config")
         if self.config is not None:
-            # TODO work out which parameters should follow the config and which should follow the trainer arguments
             logger.warn("Notice that you have defined a configuration for deepspeed, so parameters like "
                         "`optimizers`, `strategy` and `fp16` may not take effect.")
             return
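As the warning says, a user-supplied config is expected to win over trainer-level arguments. An illustrative config dict (field names follow deepspeed's documented JSON schema; the values are only examples)::

    # once such a config is passed via deepspeed_kwargs, `optimizers`,
    # `strategy` and `fp16` given to the Trainer may be ignored in its favour
    config = {
        "train_micro_batch_size_per_gpu": 2,
        "zero_optimization": {"stage": 1},
        "fp16": {"enabled": True},
    }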
@@ -258,12 +283,6 @@ class DeepSpeedDriver(TorchDDPDriver):
     def step(self):
         self.model.step()
 
-    def unwrap_model(self):
-        r"""
-        :return: the original model;
-        """
-        return self.model.module.model
-
     def get_model_no_sync_context(self):
         r"""
         :return: a ``context`` manager used to turn off synchronization between the processes; with ``deepspeed`` this returns an empty context
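The empty context the docstring refers to behaves like `contextlib.nullcontext`; a hedged equivalent, not the literal implementation::

    from contextlib import nullcontext

    # gradient accumulation is managed by the deepspeed engine itself, so the
    # no-sync context degenerates to a no-op here
    with nullcontext():
        ...  # forward/backward proceed exactly as they would without it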
@@ -38,6 +38,9 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
     if driver == 'fairscale':
         return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
                                is_pull_by_torch_run=True, **kwargs)
+    elif kwargs.get("deepspeed_kwargs") is not None:
+        return DeepSpeedDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
+                               is_pull_by_torch_run=True, **kwargs)
     else:
         return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"),
                               is_pull_by_torch_run=True, **kwargs)
@@ -73,6 +76,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
         raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
 
     if driver == "torch":  # single, ddp, launched directly.
+        if kwargs.get("deepspeed_kwargs") is not None:
+            # deepspeed was selected
+            if not isinstance(device, List):
+                if device.type == 'cpu':
+                    raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
+                logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
+                return DeepSpeedDriver(model, [device], **kwargs)
+            return DeepSpeedDriver(model, device, **kwargs)
         if not isinstance(device, List):
             return TorchSingleDriver(model, device, **kwargs)
         else:
@@ -84,11 +95,4 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
                 logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.")
                 return FairScaleDriver(model, [device], **kwargs)
         else:
-            return FairScaleDriver(model, device, **kwargs)
-    elif driver == "deepspeed":
-        if not isinstance(device, List):
-            if device.type == 'cpu':
-                raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.")
-            logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.")
-            return DeepSpeedDriver(model, [device], **kwargs)
-        return DeepSpeedDriver(model, device, **kwargs)
+            return FairScaleDriver(model, device, **kwargs)
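Net effect of the two hunks above: the `driver == "deepspeed"` entry point is gone, and driver selection now keys on whether `deepspeed_kwargs` is present. A usage sketch (model, dataloaders, and so on as in the test files added below)::

    trainer = Trainer(
        model=model,
        driver="torch",          # deepspeed is now reached through the torch driver
        device=[0, 1],
        deepspeed_kwargs={"strategy": "deepspeed"},  # presence of this dict selects DeepSpeedDriver
        # ... train_dataloader, optimizers, etc. ...
    )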
@@ -0,0 +1,95 @@
+"""
+This file tests using deepspeed on multiple GPUs::
+
+    >>> # test launching directly on multiple GPUs
+    >>> python _test_trainer_deepspeed.py
+    >>> # test launching via the deepspeed launcher
+    >>> deepspeed _test_trainer_deepspeed.py
+
+"""
+import sys
+sys.path.append("../../../")
+from dataclasses import dataclass
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.metrics.accuracy import Accuracy
+from fastNLP.core.callbacks.progress_callback import RichCallback
+
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+
+from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+from tests.helpers.datasets.torch_data import TorchArgMaxDataset
+
+@dataclass
+class TrainDeepSpeedConfig:
+    num_labels: int = 3
+    feature_dimension: int = 3
+
+    batch_size: int = 2
+    shuffle: bool = True
+    evaluate_every = 2
+
+def test_trainer_deepspeed(
+        device,
+        callbacks,
+        strategy,
+        config,
+        n_epochs=2,
+):
+    model = TorchNormalModel_Classification_1(
+        num_labels=TrainDeepSpeedConfig.num_labels,
+        feature_dimension=TrainDeepSpeedConfig.feature_dimension
+    )
+    optimizers = Adam(params=model.parameters(), lr=0.0001)
+    train_dataloader = DataLoader(
+        dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20),
+        batch_size=TrainDeepSpeedConfig.batch_size,
+        shuffle=True
+    )
+    val_dataloader = DataLoader(
+        dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12),
+        batch_size=TrainDeepSpeedConfig.batch_size,
+        shuffle=True
+    )
+    evaluate_dataloaders = val_dataloader
+    evaluate_every = TrainDeepSpeedConfig.evaluate_every
+    metrics = {"acc": Accuracy()}
+    if config is not None:
+        config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size
+    trainer = Trainer(
+        model=model,
+        driver="torch",
+        device=device,
+        optimizers=optimizers,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders=evaluate_dataloaders,
+        evaluate_every=evaluate_every,
+        metrics=metrics,
+        output_mapping={"preds": "pred"},
+        n_epochs=n_epochs,
+        callbacks=callbacks,
+        deepspeed_kwargs={
+            "strategy": strategy,
+            "config": config
+        }
+    )
+    trainer.run()
+
+if __name__ == "__main__":
+    device = [0, 1]
+    # device = [0, 1, 3]
+    callbacks = [
+        # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
+        RichCallback(5),
+    ]
+    config = None
+    test_trainer_deepspeed(
+        device=device,
+        callbacks=callbacks,
+        strategy="deepspeed",
+        config=config,
+        n_epochs=5,
+    )
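Note the coupling in the test above: the driver also reads `batch_size` from `train_dataloader` to fill in `train_micro_batch_size_per_gpu` (see the `__init__` hunk earlier), which is presumably why a supplied config pins that field to the same value. The invariant, as a one-line sketch::

    assert config["train_micro_batch_size_per_gpu"] == train_dataloader.batch_size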
@@ -0,0 +1,105 @@
+"""
+This file tests using deepspeed on multiple GPUs when the user has called deepspeed.initialize themselves::
+
+    >>> deepspeed _test_trainer_deepspeed_outside.py
+
+"""
+import os
+import sys
+sys.path.append("../../../")
+from dataclasses import dataclass
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.metrics.accuracy import Accuracy
+from fastNLP.core.callbacks.progress_callback import RichCallback
+from fastNLP.core.drivers.torch_driver.utils import _create_default_config
+
+import deepspeed
+import torch
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+
+from tests.helpers.models.torch_model import TorchNormalModel_Classification_2
+from tests.helpers.datasets.torch_data import TorchArgMaxDataset
+
+local_rank = int(os.environ["LOCAL_RANK"])
+
+@dataclass
+class TrainDeepSpeedConfig:
+    num_labels: int = 3
+    feature_dimension: int = 3
+
+    batch_size: int = 2
+    shuffle: bool = True
+    evaluate_every = 2
+
+def test_trainer_deepspeed(
+        device,
+        callbacks,
+        strategy,
+        config,
+        n_epochs=2,
+):
+    model = TorchNormalModel_Classification_2(
+        num_labels=TrainDeepSpeedConfig.num_labels,
+        feature_dimension=TrainDeepSpeedConfig.feature_dimension
+    )
+    optimizers = Adam(params=model.parameters(), lr=0.0001)
+    train_dataloader = DataLoader(
+        dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20),
+        batch_size=TrainDeepSpeedConfig.batch_size,
+        shuffle=True
+    )
+    val_dataloader = DataLoader(
+        dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12),
+        batch_size=TrainDeepSpeedConfig.batch_size,
+        shuffle=True
+    )
+    evaluate_dataloaders = val_dataloader
+    evaluate_every = TrainDeepSpeedConfig.evaluate_every
+    metrics = {"acc": Accuracy()}
+    if config is not None:
+        config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size
+
+    model, optimizers, _, _ = deepspeed.initialize(
+        model=model,
+        optimizer=optimizers,
+        config=config,
+    )
+    trainer = Trainer(
+        model=model,
+        driver="torch",
+        device=device,
+        data_device=torch.device(f"cuda:{local_rank}"),
+        optimizers=optimizers,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders=evaluate_dataloaders,
+        evaluate_every=evaluate_every,
+        metrics=metrics,
+        output_mapping={"preds": "pred"},
+        n_epochs=n_epochs,
+        callbacks=callbacks,
+        deepspeed_kwargs={
+            "strategy": strategy,
+            "config": config
+        }
+    )
+    trainer.run()
+
+if __name__ == "__main__":
+    device = [0, 1]
+    # device = [0, 1, 3]
+    callbacks = [
+        # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
+        RichCallback(5),
+    ]
+    config = _create_default_config(stage=2)
+    test_trainer_deepspeed(
+        device=device,
+        callbacks=callbacks,
+        strategy="deepspeed",
+        config=config,
+        n_epochs=5,
+    )
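Because the outside-initialized path sets `model_device = None` in the driver (see the `__init__` hunk above), this test has to tell fastNLP explicitly where batches should go. A sketch of that detail; `LOCAL_RANK` is set by the deepspeed launcher::

    import os
    import torch

    # each process pins its data to its own GPU, since the driver no longer
    # knows a model_device in this mode
    data_device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")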
@@ -0,0 +1,99 @@
+import pytest
+from dataclasses import dataclass
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.metrics.accuracy import Accuracy
+from fastNLP.core.callbacks.progress_callback import RichCallback
+from fastNLP.core.drivers.torch_driver import DeepSpeedDriver
+from fastNLP.core.drivers.torch_driver.utils import _create_default_config
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.optim import Adam
+    from torch.utils.data import DataLoader
+
+from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
+from tests.helpers.datasets.torch_data import TorchArgMaxDataset
+from tests.helpers.utils import magic_argv_env_context
+
+@dataclass
+class TrainDeepSpeedConfig:
+    num_labels: int = 3
+    feature_dimension: int = 3
+
+    batch_size: int = 2
+    shuffle: bool = True
+    evaluate_every = 2
+
+@pytest.mark.deepspeed
+class TestTrainer:
+    @classmethod
+    def setup_class(cls):
+        # without this initialization, every test case from the second one onward
+        # fails with environment-variable errors.
+        torch_model = TorchNormalModel_Classification_1(1, 1)
+        torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
+        device = [torch.device(i) for i in [0, 1]]
+        driver = DeepSpeedDriver(
+            model=torch_model,
+            parallel_device=device,
+        )
+        driver.set_optimizers(torch_opt)
+        driver.setup()
+        return driver
+
+    @pytest.mark.parametrize("device", [[0, 1]])
+    @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
+    @pytest.mark.parametrize("strategy", ["deepspeed", "deepspeed_stage_1"])
+    @pytest.mark.parametrize("config", [None, _create_default_config(stage=1)])
+    @magic_argv_env_context
+    def test_trainer_deepspeed(
+            self,
+            device,
+            callbacks,
+            strategy,
+            config,
+            n_epochs=2,
+    ):
+        model = TorchNormalModel_Classification_1(
+            num_labels=TrainDeepSpeedConfig.num_labels,
+            feature_dimension=TrainDeepSpeedConfig.feature_dimension
+        )
+        optimizers = Adam(params=model.parameters(), lr=0.0001)
+        train_dataloader = DataLoader(
+            dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20),
+            batch_size=TrainDeepSpeedConfig.batch_size,
+            shuffle=True
+        )
+        val_dataloader = DataLoader(
+            dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12),
+            batch_size=TrainDeepSpeedConfig.batch_size,
+            shuffle=True
+        )
+        evaluate_dataloaders = val_dataloader
+        evaluate_every = TrainDeepSpeedConfig.evaluate_every
+        metrics = {"acc": Accuracy()}
+        if config is not None:
+            config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size
+        trainer = Trainer(
+            model=model,
+            driver="torch",
+            device=device,
+            optimizers=optimizers,
+            train_dataloader=train_dataloader,
+            evaluate_dataloaders=evaluate_dataloaders,
+            evaluate_every=evaluate_every,
+            metrics=metrics,
+            output_mapping={"preds": "pred"},
+            n_epochs=n_epochs,
+            callbacks=callbacks,
+            deepspeed_kwargs={
+                "strategy": strategy,
+                "config": config
+            }
+        )
+        trainer.run()
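For `pytest -m deepspeed` to select these tests without unknown-marker warnings, the marker has to be registered; a minimal conftest.py sketch, assuming fastNLP does not already register it elsewhere::

    # conftest.py -- hypothetical registration; the repo may already do this
    def pytest_configure(config):
        config.addinivalue_line(
            "markers", "deepspeed: tests that require deepspeed and multiple GPUs"
        )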