diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index b993912e..822fce27 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -59,11 +59,12 @@ class Trainer(TrainerEventTrigger):
         1. When the value is ``"auto"``, **fastNLP** decides on its own which mode to use based on the type of the model passed in;
         2. When the value is ``"torch"``, :class:`~fastNLP.core.drivers.TorchSingleDriver` or :class:`~fastNLP.core.drivers.TorchDDPDriver` is used;
-        3. When the value is ``"paddle"``, :class:`~fastNLP.core.drivers.PaddleSingleDriver` or :class:`~fastNLP.core.drivers.PaddleFleetDriver` is used;
-        4. When the value is ``"jittor"``, :class:`~fastNLP.core.drivers.JittorSingleDriver` or :class:`~fastNLP.core.drivers.JittorMPIDriver` is used;
-        5. When the value is ``"fairscale"``, :class:`~fastNLP.core.drivers.FairScaleDriver` is used;
-        6. When the value is ``"deepspeed"``, :class:`~fastNLP.core.drivers.DeepSpeedDriver` is used;
-        7. When the value is ``"oneflow"``, :class:`~fastNLP.core.drivers.OneflowSingleDriver` or :class:`~fastNLP.core.drivers.OneflowDDPDriver` is used;
+        3. When the value is ``"torch_fsdp"``, :class:`~fastNLP.core.drivers.TorchFSDPDriver` is used;
+        4. When the value is ``"paddle"``, :class:`~fastNLP.core.drivers.PaddleSingleDriver` or :class:`~fastNLP.core.drivers.PaddleFleetDriver` is used;
+        5. When the value is ``"jittor"``, :class:`~fastNLP.core.drivers.JittorSingleDriver` or :class:`~fastNLP.core.drivers.JittorMPIDriver` is used;
+        6. When the value is ``"fairscale"``, :class:`~fastNLP.core.drivers.FairScaleDriver` is used;
+        7. When the value is ``"deepspeed"``, :class:`~fastNLP.core.drivers.DeepSpeedDriver` is used;
+        8. When the value is ``"oneflow"``, :class:`~fastNLP.core.drivers.OneflowSingleDriver` or :class:`~fastNLP.core.drivers.OneflowDDPDriver` is used;
 
         When a framework is specified, which of its drivers is actually used depends on the ``device`` argument;
 
@@ -301,6 +302,21 @@ class Trainer(TrainerEventTrigger):
     :kwargs:
         * *torch_kwargs* -- additional parameters required by ``TorchDriver``; see :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` and :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver` for details;
+
+            .. note::
+
+                If you need special arguments when ``TorchDDPDriver`` initializes ``DistributedDataParallel``, you can pass them
+                via ``ddp_kwargs`` inside ``torch_kwargs``, for example:
+
+                .. code-block::
+
+                    trainer = Trainer(
+                        ...,
+                        torch_kwargs = {'ddp_kwargs': {'find_unused_parameters': True, ...}}
+                    )
+
+                The same applies to ``TorchFSDPDriver``, except that the corresponding key is ``fsdp_kwargs``;
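+
+                For example, a minimal sketch (``save_on_rank0`` is shown purely for illustration; it is one of the keys
+                documented in :class:`~fastNLP.core.drivers.TorchFSDPDriver`, and ``...`` stands for your remaining
+                ``Trainer`` arguments):
+
+                .. code-block::
+
+                    trainer = Trainer(
+                        ...,
+                        driver="torch_fsdp",
+                        # gather full model weights onto rank 0 when saving
+                        torch_kwargs={'fsdp_kwargs': {'save_on_rank0': True}}
+                    )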
+
         * *paddle_kwargs* -- additional parameters required by ``PaddleDriver``; see :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` and :class:`~fastNLP.core.drivers.paddle_driver.PaddleFleetDriver` for details;
         * *fairscale_kwargs* -- additional parameters required by ``FairScaleDriver``; see :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver` for details;
diff --git a/fastNLP/core/drivers/torch_driver/torch_fsdp.py b/fastNLP/core/drivers/torch_driver/torch_fsdp.py
index e6011603..c6d2c5d0 100644
--- a/fastNLP/core/drivers/torch_driver/torch_fsdp.py
+++ b/fastNLP/core/drivers/torch_driver/torch_fsdp.py
@@ -48,6 +48,12 @@ class TorchFSDPDriver(TorchDDPDriver):
     ``TorchFSDPDriver`` does not support resuming training from a checkpoint yet, but it does support saving and loading models;
 
+    Note that when a checkpoint callback saves or loads the model, you can control the behavior by passing
+    ``torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True/False, 'load_on_rank0': True/False}}`` when initializing the ``Trainer``:
+
+    1. save/load_on_rank0 = True: when saving and loading, the model parameters from all ranks are gathered onto rank 0; note that this may cause OOM;
+    2. save/load_on_rank0 = False: each rank saves and loads only its own shard of the model parameters;
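+
+    For example, a minimal sketch (``...`` stands for your remaining ``Trainer`` arguments; the values shown match the
+    usage exercised in this PR's tests):
+
+    .. code-block::
+
+        trainer = Trainer(
+            ...,
+            driver="torch_fsdp",
+            # gather full weights onto rank 0 for both saving and loading
+            torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}}
+        )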
+
     """
 
     def __init__(
diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
index ce67814e..2cdbe189 100644
--- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py
@@ -318,7 +318,7 @@ def test_torch_distributed_launch_2(version):
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1])])
+@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1]), ("torch_fsdp", [0, 1])])
 @magic_argv_env_context
 def test_torch_wo_auto_param_call(
     driver,
@@ -363,7 +363,7 @@
 
 # test accumulation_steps;
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1]), ("torch_fsdp", [0, 1])])
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_overfit_torch(
diff --git a/tests/core/drivers/torch_driver/test_fsdp.py b/tests/core/drivers/torch_driver/test_fsdp.py
index 9ba890ca..df8f9e91 100644
--- a/tests/core/drivers/torch_driver/test_fsdp.py
+++ b/tests/core/drivers/torch_driver/test_fsdp.py
@@ -71,7 +71,6 @@ def model_and_optimizers(request):
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
     model_and_optimizers: TrainerParameters,
-    n_epochs=3,
 ):
     callbacks = [RecordLossCallback(loss_threshold=0.5)]
     trainer = Trainer(
@@ -98,14 +97,14 @@
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch_fsdp", [4, 5])])
+@pytest.mark.parametrize("save_on_rank0", [True, False])
 @magic_argv_env_context(timeout=100)
 def test_model_checkpoint_callback_1(
     model_and_optimizers: TrainerParameters,
-    driver,
-    device
+    save_on_rank0
 ):
-    for version in [0]:
+    device = [6, 7]
+    for version in [0, 1]:
         # The model has to be re-initialized at the start of each iteration, because fsdp deletes the model on the current device, so this reference would otherwise point to an empty model;
         model_and_optimizers.model = TorchNormalModel_Classification_1(
             num_labels=ArgMaxDatasetConfig.num_labels,
@@ -128,7 +127,7 @@
 
         trainer = Trainer(
             model=model_and_optimizers.model,
-            driver=driver,
+            driver="torch_fsdp",
             device=device,
             optimizers=model_and_optimizers.optimizers,
             train_dataloader=model_and_optimizers.train_dataloader,
@@ -139,7 +138,7 @@
             n_epochs=10,
             callbacks=callbacks,
             output_from_new_proc="all",
-            # torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True}}
+            torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}} if save_on_rank0 else None
         )
 
         trainer.run()
@@ -165,7 +164,7 @@
         step_save_path = all_saved_model_paths["model-epoch_9-batch_123"]
 
         assert len(all_saved_model_paths) == 11
-        all_state_dicts = [epoch_save_path]#, step_save_path]
+        all_state_dicts = [epoch_save_path, step_save_path]
 
     elif version == 1:
@@ -214,7 +213,7 @@
 
         trainer = Trainer(
             model=model_and_optimizers.model,
-            driver=driver,
+            driver="torch_fsdp",
             device=device,
             optimizers=model_and_optimizers.optimizers,
             train_dataloader=model_and_optimizers.train_dataloader,
             evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
             input_mapping=model_and_optimizers.input_mapping,
             output_mapping=model_and_optimizers.output_mapping,
             metrics=model_and_optimizers.metrics,
-            n_epochs=20,
+            n_epochs=2,
             output_from_new_proc="all",
-
+            torch_kwargs={
+                "fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}} if save_on_rank0 else None
         )
 
         trainer.load_model(folder, only_state_dict=True)
@@ -238,9 +238,6 @@
 
     dist.destroy_process_group()
 
-
-
-
 @pytest.mark.skip("fsdp does not support resuming training from a checkpoint yet;")
 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])])  # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)