@@ -14,12 +14,13 @@ from fastNLP.envs import( | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED | from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import pytorch_lightning | |||||
import torch | import torch | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from torch.optim import Optimizer | |||||
if _NEED_IMPORT_DEEPSPEED: | if _NEED_IMPORT_DEEPSPEED: | ||||
import deepspeed | import deepspeed | ||||
from deepspeed import DeepSpeedEngine, DeepSpeedOptimizer | |||||
__all__ = [ | __all__ = [ | ||||
"DeepSpeedDriver", | "DeepSpeedDriver", | ||||
@@ -33,7 +34,6 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
parallel_device: Union[List["torch.device"], "torch.device"], | parallel_device: Union[List["torch.device"], "torch.device"], | ||||
is_pull_by_torch_run = False, | is_pull_by_torch_run = False, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
strategy= "deepspeed", | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported." | assert _NEED_IMPORT_DEEPSPEED, "Deepspeed is not imported." | ||||
@@ -56,8 +56,22 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
# 我们的 model_device 一定是 torch.device,而不是一个 list; | # 我们的 model_device 一定是 torch.device,而不是一个 list; | ||||
self.model_device = parallel_device[self.local_rank] | self.model_device = parallel_device[self.local_rank] | ||||
# 暂时不允许在外面初始化 | |||||
# 如果用户自己在外面初始化了 deepspeed; | |||||
self.outside_ddp = False | self.outside_ddp = False | ||||
if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \ | |||||
"fastnlp_torch_launch_not_ddp" not in os.environ: | |||||
# 如果用户自己在外面初始化了 deepspeed,那么我们要求用户传入的模型一定是已经由 DeepSpeedEngine 包裹后的模型; | |||||
if not isinstance(model, DeepSpeedEngine): | |||||
raise RuntimeError( | |||||
"It is not allowed to input a normal model instead of `DeepSpeedEngine` when" | |||||
"you initialize the ddp process out of our control.") | |||||
self.outside_ddp = True | |||||
self.config = model.config | |||||
# 用户只有将模型上传到对应机器上后才能用 DistributedDataParallel 包裹,因此如果用户在外面初始化了 DDP,那么在 TorchDDPDriver 中 | |||||
# 我们就直接将 model_device 置为 None; | |||||
self.model_device = None | |||||
self._data_device = kwargs.get("data_device", None) | self._data_device = kwargs.get("data_device", None) | ||||
if isinstance(self._data_device, int): | if isinstance(self._data_device, int): | ||||
if self._data_device < 0: | if self._data_device < 0: | ||||
@@ -84,7 +98,6 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; | ||||
self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; | ||||
self.strategy = strategy | |||||
self.accumulation_steps = kwargs.get("accumulation_steps", 1) | self.accumulation_steps = kwargs.get("accumulation_steps", 1) | ||||
# 获取 batch_size 以设置 train_micro_batch_size_per_gpu 参数 | # 获取 batch_size 以设置 train_micro_batch_size_per_gpu 参数 | ||||
train_dl = kwargs.get("train_dataloader", None) | train_dl = kwargs.get("train_dataloader", None) | ||||
@@ -96,6 +109,14 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
self.train_micro_batch_size = 1 | self.train_micro_batch_size = 1 | ||||
self._ds_kwargs = kwargs.get("deepspeed_kwargs", {}) | self._ds_kwargs = kwargs.get("deepspeed_kwargs", {}) | ||||
self.strategy = self._ds_kwargs.get("strategy", "deepspeed") | |||||
@staticmethod | |||||
def _check_optimizer_legality(optimizers): | |||||
for each_optimizer in optimizers: | |||||
if not isinstance(each_optimizer, (Optimizer, DeepSpeedOptimizer)): | |||||
raise TypeError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' or " | |||||
f"'DeepSpeedOptimizer' type, not {type(each_optimizer)}.") | |||||
def setup(self): | def setup(self): | ||||
r""" | r""" | ||||
@@ -112,15 +133,19 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
self.setup_config() | self.setup_config() | ||||
# 如果用户需要使用多机模式,那么一定进入到这里; | # 如果用户需要使用多机模式,那么一定进入到这里; | ||||
if self.is_pull_by_torch_run: | if self.is_pull_by_torch_run: | ||||
# dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用; | |||||
self.world_size = int(os.environ.get("WORLD_SIZE")) | |||||
self.global_rank = int(os.environ.get("RANK")) | |||||
logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") | |||||
if self.outside_ddp: | |||||
self.world_size = dist.get_world_size() | |||||
self.global_rank = dist.get_rank() | |||||
else: | |||||
# dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用; | |||||
self.world_size = int(os.environ.get("WORLD_SIZE")) | |||||
self.global_rank = int(os.environ.get("RANK")) | |||||
logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") | |||||
if not dist.is_initialized(): | |||||
deepspeed.init_distributed("nccl", distributed_port=self.master_port) | |||||
if not dist.is_initialized(): | |||||
deepspeed.init_distributed("nccl", distributed_port=self.master_port) | |||||
os.environ["fastnlp_torch_launch_not_ddp"] = "yes" | |||||
os.environ["fastnlp_torch_launch_not_ddp"] = "yes" | |||||
# 进入到这里的情况时: | # 进入到这里的情况时: | ||||
# dist.is_initialized 一定为 False; | # dist.is_initialized 一定为 False; | ||||
@@ -146,8 +171,9 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
self.world_size = dist.get_world_size() | self.world_size = dist.get_world_size() | ||||
self.global_rank = dist.get_rank() | self.global_rank = dist.get_rank() | ||||
torch.cuda.set_device(self.model_device) | |||||
self.configure_ddp() | |||||
if not self.outside_ddp: | |||||
torch.cuda.set_device(self.model_device) | |||||
self.configure_ddp() | |||||
self.barrier() | self.barrier() | ||||
# 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; | # 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; | ||||
@@ -166,7 +192,7 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
def configure_ddp(self): | def configure_ddp(self): | ||||
# 设置 deepspeed | # 设置 deepspeed | ||||
if not isinstance(self.model, deepspeed.DeepSpeedEngine): | |||||
if not isinstance(self.model, DeepSpeedEngine): | |||||
model=_DeepSpeedWrappingModel(self.model, self.fp16) | model=_DeepSpeedWrappingModel(self.model, self.fp16) | ||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters()) | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) | ||||
self.model, ds_optimizer, _, _ = deepspeed.initialize( | self.model, ds_optimizer, _, _ = deepspeed.initialize( | ||||
@@ -193,7 +219,6 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
self.config = self._ds_kwargs.get("config") | self.config = self._ds_kwargs.get("config") | ||||
if self.config is not None: | if self.config is not None: | ||||
# TODO 究竟哪些参数按照config,哪些按照trainer参数 | |||||
logger.warn("Notice that you have defined a configuration for deepspeed and parameters like" | logger.warn("Notice that you have defined a configuration for deepspeed and parameters like" | ||||
"`optimizers`, `strategy` and `fp16` may not take effects.") | "`optimizers`, `strategy` and `fp16` may not take effects.") | ||||
return | return | ||||
@@ -258,12 +283,6 @@ class DeepSpeedDriver(TorchDDPDriver): | |||||
def step(self): | def step(self): | ||||
self.model.step() | self.model.step() | ||||
def unwrap_model(self): | |||||
r""" | |||||
:return: 返回原本的模型; | |||||
""" | |||||
return self.model.module.model | |||||
def get_model_no_sync_context(self): | def get_model_no_sync_context(self): | ||||
r""" | r""" | ||||
:return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;在 ``deepspeed`` 中,返回一个空的上下文 | :return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;在 ``deepspeed`` 中,返回一个空的上下文 | ||||
@@ -38,6 +38,9 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
if driver == 'fairscale': | if driver == 'fairscale': | ||||
return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | return FairScaleDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | ||||
is_pull_by_torch_run=True, **kwargs) | is_pull_by_torch_run=True, **kwargs) | ||||
elif kwargs.get("deepspeed_kwargs") is not None: | |||||
return DeepSpeedDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | |||||
is_pull_by_torch_run=True, **kwargs) | |||||
else: | else: | ||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), | ||||
is_pull_by_torch_run=True, **kwargs) | is_pull_by_torch_run=True, **kwargs) | ||||
@@ -73,6 +76,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | ||||
if driver == "torch": # single, ddp, 直接启动。 | if driver == "torch": # single, ddp, 直接启动。 | ||||
if kwargs.get("deepspeed_kwargs") is not None: | |||||
# 选择的是 deepspeed | |||||
if not isinstance(device, List): | |||||
if device.type == 'cpu': | |||||
raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.") | |||||
logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.") | |||||
return DeepSpeedDriver(model, [device], **kwargs) | |||||
return DeepSpeedDriver(model, device, **kwargs) | |||||
if not isinstance(device, List): | if not isinstance(device, List): | ||||
return TorchSingleDriver(model, device, **kwargs) | return TorchSingleDriver(model, device, **kwargs) | ||||
else: | else: | ||||
@@ -84,11 +95,4 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") | logger.warning_once("Notice you are using `fairscale`, but the `device` is only one gpu.") | ||||
return FairScaleDriver(model, [device], **kwargs) | return FairScaleDriver(model, [device], **kwargs) | ||||
else: | else: | ||||
return FairScaleDriver(model, device, **kwargs) | |||||
elif driver == "deepspeed": | |||||
if not isinstance(device, List): | |||||
if device.type == 'cpu': | |||||
raise ValueError("You are using `deepspeed` driver, but your chosen `device` is 'cpu'.") | |||||
logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.") | |||||
return DeepSpeedDriver(model, [device], **kwargs) | |||||
return DeepSpeedDriver(model, device, **kwargs) | |||||
return FairScaleDriver(model, device, **kwargs) |
@@ -0,0 +1,95 @@ | |||||
""" | |||||
这个文件测试多卡情况下使用 deepspeed 的情况:: | |||||
>>> # 测试直接使用多卡 | |||||
>>> python _test_trainer_deepspeed.py | |||||
>>> # 测试通过 deepspeed 拉起 | |||||
>>> deepspeed _test_trainer_deepspeed.py | |||||
""" | |||||
import sys | |||||
sys.path.append("../../../") | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from torch.optim import Adam | |||||
from torch.utils.data import DataLoader | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
@dataclass | |||||
class TrainDeepSpeedConfig: | |||||
num_labels: int = 3 | |||||
feature_dimension: int = 3 | |||||
batch_size: int = 2 | |||||
shuffle: bool = True | |||||
evaluate_every = 2 | |||||
def test_trainer_deepspeed( | |||||
device, | |||||
callbacks, | |||||
strategy, | |||||
config, | |||||
n_epochs=2, | |||||
): | |||||
model = TorchNormalModel_Classification_1( | |||||
num_labels=TrainDeepSpeedConfig.num_labels, | |||||
feature_dimension=TrainDeepSpeedConfig.feature_dimension | |||||
) | |||||
optimizers = Adam(params=model.parameters(), lr=0.0001) | |||||
train_dataloader = DataLoader( | |||||
dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20), | |||||
batch_size=TrainDeepSpeedConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
val_dataloader = DataLoader( | |||||
dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12), | |||||
batch_size=TrainDeepSpeedConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
train_dataloader = train_dataloader | |||||
evaluate_dataloaders = val_dataloader | |||||
evaluate_every = TrainDeepSpeedConfig.evaluate_every | |||||
metrics = {"acc": Accuracy()} | |||||
if config is not None: | |||||
config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver="torch", | |||||
device=device, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
evaluate_every=evaluate_every, | |||||
metrics=metrics, | |||||
output_mapping={"preds": "pred"}, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
deepspeed_kwargs={ | |||||
"strategy": strategy, | |||||
"config": config | |||||
} | |||||
) | |||||
trainer.run() | |||||
if __name__ == "__main__": | |||||
device = [0,1] | |||||
# device = [0,1,3] | |||||
callbacks = [ | |||||
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||||
RichCallback(5), | |||||
] | |||||
config = None | |||||
test_trainer_deepspeed( | |||||
device=device, | |||||
callbacks=callbacks, | |||||
strategy="deepspeed", | |||||
config=config, | |||||
n_epochs=5, | |||||
) |
@@ -0,0 +1,105 @@ | |||||
""" | |||||
这个文件测试多卡情况下使用 deepspeed ,且用户自己调用了 deepspeed.initialize 的情况:: | |||||
>>> deepspeed _test_trainer_deepspeed_outside.py | |||||
""" | |||||
import os | |||||
import sys | |||||
sys.path.append("../../../") | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from fastNLP.core.drivers.torch_driver.utils import _create_default_config | |||||
import deepspeed | |||||
import torch | |||||
from torch.optim import Adam | |||||
from torch.utils.data import DataLoader | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
local_rank = int(os.environ["LOCAL_RANK"]) | |||||
@dataclass | |||||
class TrainDeepSpeedConfig: | |||||
num_labels: int = 3 | |||||
feature_dimension: int = 3 | |||||
batch_size: int = 2 | |||||
shuffle: bool = True | |||||
evaluate_every = 2 | |||||
def test_trainer_deepspeed( | |||||
device, | |||||
callbacks, | |||||
strategy, | |||||
config, | |||||
n_epochs=2, | |||||
): | |||||
model = TorchNormalModel_Classification_2( | |||||
num_labels=TrainDeepSpeedConfig.num_labels, | |||||
feature_dimension=TrainDeepSpeedConfig.feature_dimension | |||||
) | |||||
optimizers = Adam(params=model.parameters(), lr=0.0001) | |||||
train_dataloader = DataLoader( | |||||
dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20), | |||||
batch_size=TrainDeepSpeedConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
val_dataloader = DataLoader( | |||||
dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12), | |||||
batch_size=TrainDeepSpeedConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
train_dataloader = train_dataloader | |||||
evaluate_dataloaders = val_dataloader | |||||
evaluate_every = TrainDeepSpeedConfig.evaluate_every | |||||
metrics = {"acc": Accuracy()} | |||||
if config is not None: | |||||
config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size | |||||
model, optimizers, _, _ = deepspeed.initialize( | |||||
model=model, | |||||
optimizer=optimizers, | |||||
config=config, | |||||
) | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver="torch", | |||||
device=device, | |||||
data_device=torch.device(f"cuda:{local_rank}"), | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
evaluate_every=evaluate_every, | |||||
metrics=metrics, | |||||
output_mapping={"preds": "pred"}, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
deepspeed_kwargs={ | |||||
"strategy": strategy, | |||||
"config": config | |||||
} | |||||
) | |||||
trainer.run() | |||||
if __name__ == "__main__": | |||||
device = [0,1] | |||||
# device = [0,1,3] | |||||
callbacks = [ | |||||
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||||
RichCallback(5), | |||||
] | |||||
config = _create_default_config(stage=2) | |||||
test_trainer_deepspeed( | |||||
device=device, | |||||
callbacks=callbacks, | |||||
strategy="deepspeed", | |||||
config=config, | |||||
n_epochs=5, | |||||
) |
@@ -0,0 +1,99 @@ | |||||
import pytest | |||||
from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from fastNLP.core.drivers.torch_driver import DeepSpeedDriver | |||||
from fastNLP.core.drivers.torch_driver.utils import _create_default_config | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.optim import Adam | |||||
from torch.utils.data import DataLoader | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
@dataclass | |||||
class TrainDeepSpeedConfig: | |||||
num_labels: int = 3 | |||||
feature_dimension: int = 3 | |||||
batch_size: int = 2 | |||||
shuffle: bool = True | |||||
evaluate_every = 2 | |||||
@pytest.mark.deepspeed | |||||
class TestTrainer: | |||||
@classmethod | |||||
def setup_class(cls): | |||||
# 不初始化的话从第二个测试例开始会因为环境变量报错。 | |||||
torch_model = TorchNormalModel_Classification_1(1, 1) | |||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||||
device = [torch.device(i) for i in [0,1]] | |||||
driver = DeepSpeedDriver( | |||||
model=torch_model, | |||||
parallel_device=device, | |||||
) | |||||
driver.set_optimizers(torch_opt) | |||||
driver.setup() | |||||
return driver | |||||
@pytest.mark.parametrize("device", [[0, 1]]) | |||||
@pytest.mark.parametrize("callbacks", [[RichCallback(5)]]) | |||||
@pytest.mark.parametrize("strategy", ["deepspeed", "deepspeed_stage_1"]) | |||||
@pytest.mark.parametrize("config", [None, _create_default_config(stage=1)]) | |||||
@magic_argv_env_context | |||||
def test_trainer_deepspeed( | |||||
self, | |||||
device, | |||||
callbacks, | |||||
strategy, | |||||
config, | |||||
n_epochs=2, | |||||
): | |||||
model = TorchNormalModel_Classification_1( | |||||
num_labels=TrainDeepSpeedConfig.num_labels, | |||||
feature_dimension=TrainDeepSpeedConfig.feature_dimension | |||||
) | |||||
optimizers = Adam(params=model.parameters(), lr=0.0001) | |||||
train_dataloader = DataLoader( | |||||
dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 20), | |||||
batch_size=TrainDeepSpeedConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
val_dataloader = DataLoader( | |||||
dataset=TorchArgMaxDataset(TrainDeepSpeedConfig.feature_dimension, 12), | |||||
batch_size=TrainDeepSpeedConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
train_dataloader = train_dataloader | |||||
evaluate_dataloaders = val_dataloader | |||||
evaluate_every = TrainDeepSpeedConfig.evaluate_every | |||||
metrics = {"acc": Accuracy()} | |||||
if config is not None: | |||||
config["train_micro_batch_size_per_gpu"] = TrainDeepSpeedConfig.batch_size | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver="torch", | |||||
device=device, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
evaluate_dataloaders=evaluate_dataloaders, | |||||
evaluate_every=evaluate_every, | |||||
metrics=metrics, | |||||
output_mapping={"preds": "pred"}, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
deepspeed_kwargs={ | |||||
"strategy": strategy, | |||||
"config": config | |||||
} | |||||
) | |||||
trainer.run() |