@@ -59,11 +59,12 @@ class Trainer(TrainerEventTrigger): | |||||
1. 值为 ``"auto"`` 时,**fastNLP** 会根据传入模型的类型自行判断使用哪一种模式; | 1. 值为 ``"auto"`` 时,**fastNLP** 会根据传入模型的类型自行判断使用哪一种模式; | ||||
2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`; | 2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`; | ||||
3. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`; | |||||
4. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`; | |||||
5. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`; | |||||
6. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`; | |||||
7. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`; | |||||
3. 其值为 ``"torch_fsdp"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchFSDPDriver`; | |||||
4. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`; | |||||
5. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`; | |||||
6. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`; | |||||
7. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`; | |||||
8. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`; | |||||
在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; | 在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; | ||||
@@ -301,6 +302,21 @@ class Trainer(TrainerEventTrigger): | |||||
:kwargs: | :kwargs: | ||||
* *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和 | * *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和 | ||||
:class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`; | :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`; | ||||
.. note:: | |||||
注意如果对于 ``TorchDDPDriver`` 中初始化 ``DistributedDataParallel`` 时有特别的参数,您可以通过在 ``torch_kwargs`` 中传入 | |||||
``ddp_kwargs`` 来实现,例如: | |||||
.. code-block:: | |||||
trainer = Trainer( | |||||
..., | |||||
torch_kwargs = {'ddp_kwargs': {'find_unused_parameters': True, ...}} | |||||
) | |||||
对于 ``TorchFSDPDriver`` 也是类似,只是对应的 ``**_kwargs`` 修改为 ``fsdp_kwargs``; | |||||
* *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和 | * *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和 | ||||
:class:`~fastNLP.core.drivers.paddle_driver.PaddleFleetDriver`; | :class:`~fastNLP.core.drivers.paddle_driver.PaddleFleetDriver`; | ||||
* *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`; | * *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`; | ||||
@@ -48,6 +48,12 @@ class TorchFSDPDriver(TorchDDPDriver): | |||||
``TorchFSDPDriver`` 现在还不支持断点重训功能,但是支持保存模型和加载模型; | ``TorchFSDPDriver`` 现在还不支持断点重训功能,但是支持保存模型和加载模型; | ||||
注意当您在加载和保存模型的 checkpointcallback 的时候,您可以通过在初始化 ``Trainer`` 时传入 | |||||
``torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True/False, 'load_on_rank0': True/False}}`` 来指定保存模型的行为: | |||||
1. save/load_on_rank0 = True:表示在加载和保存模型时将所有 rank 上的模型参数全部聚合到 rank0 上,注意这样可能会造成 OOM; | |||||
2. save/load_on_rank0 = False:表示每个 rank 分别保存加载自己独有的模型参数; | |||||
""" | """ | ||||
def __init__( | def __init__( | ||||
@@ -318,7 +318,7 @@ def test_torch_distributed_launch_2(version): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1])]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1]), ("torch_fsdp", [0, 1])]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_wo_auto_param_call( | def test_torch_wo_auto_param_call( | ||||
driver, | driver, | ||||
@@ -363,7 +363,7 @@ def test_torch_wo_auto_param_call( | |||||
# 测试 accumulation_steps; | # 测试 accumulation_steps; | ||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1]), ("torch_fsdp", [0, 1])]) | |||||
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) | @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_overfit_torch( | def test_trainer_overfit_torch( | ||||
@@ -71,7 +71,6 @@ def model_and_optimizers(request): | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator( | def test_trainer_torch_without_evaluator( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
n_epochs=3, | |||||
): | ): | ||||
callbacks = [RecordLossCallback(loss_threshold=0.5)] | callbacks = [RecordLossCallback(loss_threshold=0.5)] | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -98,14 +97,14 @@ def test_trainer_torch_without_evaluator( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [4, 5])]) | |||||
@pytest.mark.parametrize("save_on_rank0", [True, False]) | |||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||
def test_model_checkpoint_callback_1( | def test_model_checkpoint_callback_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | |||||
device | |||||
save_on_rank0 | |||||
): | ): | ||||
for version in [0]: | |||||
device = [6, 7] | |||||
for version in [0, 1]: | |||||
# 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型; | # 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型; | ||||
model_and_optimizers.model = TorchNormalModel_Classification_1( | model_and_optimizers.model = TorchNormalModel_Classification_1( | ||||
num_labels=ArgMaxDatasetConfig.num_labels, | num_labels=ArgMaxDatasetConfig.num_labels, | ||||
@@ -128,7 +127,7 @@ def test_model_checkpoint_callback_1( | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | |||||
driver="torch_fsdp", | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
@@ -139,7 +138,7 @@ def test_model_checkpoint_callback_1( | |||||
n_epochs=10, | n_epochs=10, | ||||
callbacks=callbacks, | callbacks=callbacks, | ||||
output_from_new_proc="all", | output_from_new_proc="all", | ||||
# torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True}} | |||||
torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}} if save_on_rank0 else None | |||||
) | ) | ||||
trainer.run() | trainer.run() | ||||
@@ -165,7 +164,7 @@ def test_model_checkpoint_callback_1( | |||||
step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] | step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] | ||||
assert len(all_saved_model_paths) == 11 | assert len(all_saved_model_paths) == 11 | ||||
all_state_dicts = [epoch_save_path]#, step_save_path] | |||||
all_state_dicts = [epoch_save_path, step_save_path] | |||||
elif version == 1: | elif version == 1: | ||||
@@ -214,7 +213,7 @@ def test_model_checkpoint_callback_1( | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | |||||
driver="torch_fsdp", | |||||
device=device, | device=device, | ||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
@@ -223,9 +222,10 @@ def test_model_checkpoint_callback_1( | |||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
n_epochs=20, | |||||
n_epochs=2, | |||||
output_from_new_proc="all", | output_from_new_proc="all", | ||||
torch_kwargs={ | |||||
"fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}} if save_on_rank0 else None | |||||
) | ) | ||||
trainer.load_model(folder, only_state_dict=True) | trainer.load_model(folder, only_state_dict=True) | ||||
@@ -238,9 +238,6 @@ def test_model_checkpoint_callback_1( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.skip("现在 fsdp 还不支持断点重训;") | @pytest.mark.skip("现在 fsdp 还不支持断点重训;") | ||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | ||||