@@ -59,11 +59,12 @@ class Trainer(TrainerEventTrigger): | |||
1. 值为 ``"auto"`` 时,**fastNLP** 会根据传入模型的类型自行判断使用哪一种模式; | |||
2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`; | |||
3. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`; | |||
4. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`; | |||
5. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`; | |||
6. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`; | |||
7. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`; | |||
3. 其值为 ``"torch_fsdp"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchFSDPDriver`; | |||
4. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`; | |||
5. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`; | |||
6. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`; | |||
7. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`; | |||
8. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`; | |||
在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; | |||
@@ -301,6 +302,21 @@ class Trainer(TrainerEventTrigger): | |||
:kwargs: | |||
* *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和 | |||
:class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`; | |||
.. note:: | |||
注意如果在 ``TorchDDPDriver`` 中初始化 ``DistributedDataParallel`` 时有特别的参数,您可以通过在 ``torch_kwargs`` 中传入 | |||
``ddp_kwargs`` 来实现,例如: | |||
.. code-block:: | |||
trainer = Trainer( | |||
..., | |||
torch_kwargs = {'ddp_kwargs': {'find_unused_parameters': True, ...}} | |||
) | |||
对于 ``TorchFSDPDriver`` 也是类似,只是对应的 ``**_kwargs`` 修改为 ``fsdp_kwargs``; | |||
* *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和 | |||
:class:`~fastNLP.core.drivers.paddle_driver.PaddleFleetDriver`; | |||
* *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`; | |||
@@ -48,6 +48,12 @@ class TorchFSDPDriver(TorchDDPDriver): | |||
``TorchFSDPDriver`` 现在还不支持断点重训功能,但是支持保存模型和加载模型; | |||
注意当您使用 checkpoint callback 加载和保存模型的时候,您可以通过在初始化 ``Trainer`` 时传入 | |||
``torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True/False, 'load_on_rank0': True/False}}`` 来指定保存模型的行为: | |||
1. save/load_on_rank0 = True:表示在加载和保存模型时将所有 rank 上的模型参数全部聚合到 rank0 上,注意这样可能会造成 OOM; | |||
2. save/load_on_rank0 = False:表示每个 rank 分别保存加载自己独有的模型参数; | |||
""" | |||
def __init__( | |||
@@ -318,7 +318,7 @@ def test_torch_distributed_launch_2(version): | |||
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1])]) | |||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1]), ("torch_fsdp", [0, 1])]) | |||
@magic_argv_env_context | |||
def test_torch_wo_auto_param_call( | |||
driver, | |||
@@ -363,7 +363,7 @@ def test_torch_wo_auto_param_call( | |||
# 测试 accumulation_steps; | |||
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) | |||
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1]), ("torch_fsdp", [0, 1])]) | |||
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) | |||
@magic_argv_env_context | |||
def test_trainer_overfit_torch( | |||
@@ -71,7 +71,6 @@ def model_and_optimizers(request): | |||
@magic_argv_env_context | |||
def test_trainer_torch_without_evaluator( | |||
model_and_optimizers: TrainerParameters, | |||
n_epochs=3, | |||
): | |||
callbacks = [RecordLossCallback(loss_threshold=0.5)] | |||
trainer = Trainer( | |||
@@ -98,14 +97,14 @@ def test_trainer_torch_without_evaluator( | |||
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [4, 5])]) | |||
@pytest.mark.parametrize("save_on_rank0", [True, False]) | |||
@magic_argv_env_context(timeout=100) | |||
def test_model_checkpoint_callback_1( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device | |||
save_on_rank0 | |||
): | |||
for version in [0]: | |||
device = [6, 7] | |||
for version in [0, 1]: | |||
# 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型; | |||
model_and_optimizers.model = TorchNormalModel_Classification_1( | |||
num_labels=ArgMaxDatasetConfig.num_labels, | |||
@@ -128,7 +127,7 @@ def test_model_checkpoint_callback_1( | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
driver="torch_fsdp", | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
@@ -139,7 +138,7 @@ def test_model_checkpoint_callback_1( | |||
n_epochs=10, | |||
callbacks=callbacks, | |||
output_from_new_proc="all", | |||
# torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True}} | |||
torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}} if save_on_rank0 else None | |||
) | |||
trainer.run() | |||
@@ -165,7 +164,7 @@ def test_model_checkpoint_callback_1( | |||
step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] | |||
assert len(all_saved_model_paths) == 11 | |||
all_state_dicts = [epoch_save_path]#, step_save_path] | |||
all_state_dicts = [epoch_save_path, step_save_path] | |||
elif version == 1: | |||
@@ -214,7 +213,7 @@ def test_model_checkpoint_callback_1( | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
driver="torch_fsdp", | |||
device=device, | |||
optimizers=model_and_optimizers.optimizers, | |||
train_dataloader=model_and_optimizers.train_dataloader, | |||
@@ -223,9 +222,10 @@ def test_model_checkpoint_callback_1( | |||
output_mapping=model_and_optimizers.output_mapping, | |||
metrics=model_and_optimizers.metrics, | |||
n_epochs=20, | |||
n_epochs=2, | |||
output_from_new_proc="all", | |||
torch_kwargs={ | |||
"fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}} if save_on_rank0 else None | |||
) | |||
trainer.load_model(folder, only_state_dict=True) | |||
@@ -238,9 +238,6 @@ def test_model_checkpoint_callback_1( | |||
dist.destroy_process_group() | |||
@pytest.mark.skip("现在 fsdp 还不支持断点重训;") | |||
@pytest.mark.torch | |||
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||