Browse Source

更新了文档

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
4b6e455247
4 changed files with 40 additions and 21 deletions
  1. +21
    -5
      fastNLP/core/controllers/trainer.py
  2. +6
    -0
      fastNLP/core/drivers/torch_driver/torch_fsdp.py
  3. +2
    -2
      tests/core/controllers/test_trainer_wo_evaluator_torch.py
  4. +11
    -14
      tests/core/drivers/torch_driver/test_fsdp.py

+ 21
- 5
fastNLP/core/controllers/trainer.py View File

@@ -59,11 +59,12 @@ class Trainer(TrainerEventTrigger):
1. 值为 ``"auto"`` 时,**fastNLP** 会根据传入模型的类型自行判断使用哪一种模式; 1. 值为 ``"auto"`` 时,**fastNLP** 会根据传入模型的类型自行判断使用哪一种模式;
2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`; 2. 其值为 ``"torch"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchSingleDriver` 或者 :class:`~fastNLP.core.drivers.TorchDDPDriver`;
3. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`;
4. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`;
5. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`;
6. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`;
7. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`;
3. 其值为 ``"torch_fsdp"`` 时,表示使用 :class:`~fastNLP.core.drivers.TorchFSDPDriver`;
4. 其值为 ``"paddle"`` 时,表示使用 :class:`~fastNLP.core.drivers.PaddleSingleDriver` 或者 :class:`~fastNLP.core.drivers.PaddleFleetDriver`;
5. 其值为 ``"jittor"`` 时,表示使用 :class:`~fastNLP.core.drivers.JittorSingleDriver` 或者 :class:`~fastNLP.core.drivers.JittorMPIDriver`;
6. 其值为 ``"fairscale"`` 时,表示使用 :class:`~fastNLP.core.drivers.FairScaleDriver`;
7. 其值为 ``"deepspeed"`` 时,表示使用 :class:`~fastNLP.core.drivers.DeepSpeedDriver`;
8. 其值为 ``"oneflow"`` 时,表示使用 :class:`~fastNLP.core.drivers.OneflowSingleDriver` 或者 :class:`~fastNLP.core.drivers.OneflowDDPDriver`;
在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置; 在指定了框架的情况下,具体使用哪一种取决于参数 ``device`` 的设置;


@@ -301,6 +302,21 @@ class Trainer(TrainerEventTrigger):
:kwargs: :kwargs:
* *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和 * *torch_kwargs* -- ``TorchDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.TorchSingleDriver` 和
:class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`; :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`;

.. note::

注意如果对于 ``TorchDDPDriver`` 中初始化 ``DistributedDataParallel`` 时有特别的参数,您可以通过在 ``torch_kwargs`` 中传入
``ddp_kwargs`` 来实现,例如:

.. code-block::

trainer = Trainer(
...,
torch_kwargs = {'ddp_kwargs': {'find_unused_parameters': True, ...}}
)

对于 ``TorchFSDPDriver`` 也是类似,只是将对应的 ``ddp_kwargs`` 修改为 ``fsdp_kwargs``;

* *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和 * *paddle_kwargs* -- ``PaddleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver` 和
:class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver`; :class:`~fastNLP.core.drivers.paddle_driver.PaddleSingleDriver`;
* *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`; * *fairscale_kwargs* -- ``FairScaleDriver`` 所需的其它参数,详见 :class:`~fastNLP.core.drivers.torch_driver.FairScaleDriver`;


+ 6
- 0
fastNLP/core/drivers/torch_driver/torch_fsdp.py View File

@@ -48,6 +48,12 @@ class TorchFSDPDriver(TorchDDPDriver):


``TorchFSDPDriver`` 现在还不支持断点重训功能,但是支持保存模型和加载模型; ``TorchFSDPDriver`` 现在还不支持断点重训功能,但是支持保存模型和加载模型;


注意当您在使用 checkpoint callback 加载和保存模型的时候,您可以通过在初始化 ``Trainer`` 时传入
``torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True/False, 'load_on_rank0': True/False}}`` 来指定保存和加载模型的行为:

1. save/load_on_rank0 = True:表示在加载和保存模型时将所有 rank 上的模型参数全部聚合到 rank0 上,注意这样可能会造成 OOM;
2. save/load_on_rank0 = False:表示每个 rank 分别保存加载自己独有的模型参数;

""" """


def __init__( def __init__(


+ 2
- 2
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -318,7 +318,7 @@ def test_torch_distributed_launch_2(version):




@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1])])
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1]), ("torch_fsdp", [0, 1])])
@magic_argv_env_context @magic_argv_env_context
def test_torch_wo_auto_param_call( def test_torch_wo_auto_param_call(
driver, driver,
@@ -363,7 +363,7 @@ def test_torch_wo_auto_param_call(


# 测试 accumulation_steps; # 测试 accumulation_steps;
@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1]), ("torch_fsdp", [0, 1])])
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context @magic_argv_env_context
def test_trainer_overfit_torch( def test_trainer_overfit_torch(


+ 11
- 14
tests/core/drivers/torch_driver/test_fsdp.py View File

@@ -71,7 +71,6 @@ def model_and_optimizers(request):
@magic_argv_env_context @magic_argv_env_context
def test_trainer_torch_without_evaluator( def test_trainer_torch_without_evaluator(
model_and_optimizers: TrainerParameters, model_and_optimizers: TrainerParameters,
n_epochs=3,
): ):
callbacks = [RecordLossCallback(loss_threshold=0.5)] callbacks = [RecordLossCallback(loss_threshold=0.5)]
trainer = Trainer( trainer = Trainer(
@@ -98,14 +97,14 @@ def test_trainer_torch_without_evaluator(




@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [4, 5])])
@pytest.mark.parametrize("save_on_rank0", [True, False])
@magic_argv_env_context(timeout=100) @magic_argv_env_context(timeout=100)
def test_model_checkpoint_callback_1( def test_model_checkpoint_callback_1(
model_and_optimizers: TrainerParameters, model_and_optimizers: TrainerParameters,
driver,
device
save_on_rank0
): ):
for version in [0]:
device = [6, 7]
for version in [0, 1]:
# 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型; # 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型;
model_and_optimizers.model = TorchNormalModel_Classification_1( model_and_optimizers.model = TorchNormalModel_Classification_1(
num_labels=ArgMaxDatasetConfig.num_labels, num_labels=ArgMaxDatasetConfig.num_labels,
@@ -128,7 +127,7 @@ def test_model_checkpoint_callback_1(


trainer = Trainer( trainer = Trainer(
model=model_and_optimizers.model, model=model_and_optimizers.model,
driver=driver,
driver="torch_fsdp",
device=device, device=device,
optimizers=model_and_optimizers.optimizers, optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader, train_dataloader=model_and_optimizers.train_dataloader,
@@ -139,7 +138,7 @@ def test_model_checkpoint_callback_1(
n_epochs=10, n_epochs=10,
callbacks=callbacks, callbacks=callbacks,
output_from_new_proc="all", output_from_new_proc="all",
# torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True}}
torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}} if save_on_rank0 else None
) )


trainer.run() trainer.run()
@@ -165,7 +164,7 @@ def test_model_checkpoint_callback_1(
step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] step_save_path = all_saved_model_paths["model-epoch_9-batch_123"]


assert len(all_saved_model_paths) == 11 assert len(all_saved_model_paths) == 11
all_state_dicts = [epoch_save_path]#, step_save_path]
all_state_dicts = [epoch_save_path, step_save_path]


elif version == 1: elif version == 1:


@@ -214,7 +213,7 @@ def test_model_checkpoint_callback_1(


trainer = Trainer( trainer = Trainer(
model=model_and_optimizers.model, model=model_and_optimizers.model,
driver=driver,
driver="torch_fsdp",
device=device, device=device,
optimizers=model_and_optimizers.optimizers, optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader, train_dataloader=model_and_optimizers.train_dataloader,
@@ -223,9 +222,10 @@ def test_model_checkpoint_callback_1(
output_mapping=model_and_optimizers.output_mapping, output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics, metrics=model_and_optimizers.metrics,


n_epochs=20,
n_epochs=2,
output_from_new_proc="all", output_from_new_proc="all",

torch_kwargs={
"fsdp_kwargs": {'save_on_rank0': True, 'load_on_rank0': True}} if save_on_rank0 else None
) )
trainer.load_model(folder, only_state_dict=True) trainer.load_model(folder, only_state_dict=True)


@@ -238,9 +238,6 @@ def test_model_checkpoint_callback_1(
dist.destroy_process_group() dist.destroy_process_group()







@pytest.mark.skip("现在 fsdp 还不支持断点重训;") @pytest.mark.skip("现在 fsdp 还不支持断点重训;")
@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)


Loading…
Cancel
Save