diff --git a/fastNLP/models/torch/sequence_labeling.py b/fastNLP/models/torch/sequence_labeling.py index d868cd7a..cb459ed7 100755 --- a/fastNLP/models/torch/sequence_labeling.py +++ b/fastNLP/models/torch/sequence_labeling.py @@ -86,7 +86,7 @@ class BiLSTMCRF(nn.Module): :param seq_len: 每个句子的长度,形状为 ``[batch,]`` :return: 如果 ``target`` 为 ``None``,则返回预测结果 ``{'pred': torch.Tensor}``,否则返回 loss ``{'loss': torch.Tensor}`` """ - return self(words, seq_len, target) + return self(words, target, seq_len) def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"): """ @@ -94,7 +94,7 @@ class BiLSTMCRF(nn.Module): :param seq_len: 每个句子的长度,形状为 ``[batch,]`` :return: 预测结果 ``{'pred': torch.Tensor}`` """ - return self(words, seq_len) + return self(words, seq_len=seq_len) class SeqLabeling(nn.Module): @@ -286,7 +286,7 @@ class AdvSeqLabel(nn.Module): :param seq_len: 每个句子的长度,形状为 ``[batch,]`` :return: 如果 ``target`` 为 ``None``,则返回预测结果 ``{'pred': torch.Tensor}``,否则返回 loss ``{'loss': torch.Tensor}`` """ - return self(words, seq_len, target) + return self(words, target, seq_len) def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"): """ @@ -294,4 +294,4 @@ class AdvSeqLabel(nn.Module): :param seq_len: 每个句子的长度,形状为 ``[batch,]`` :return: 预测结果 ``{'pred': torch.Tensor}`` """ - return self(words, seq_len) + return self(words, seq_len=seq_len) diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 4d31e5f8..87ccb613 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -103,8 +103,8 @@ def model_and_optimizers(request): # 测试一下普通的情况; @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), - ("torch", [4, 5])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), + ("torch", [0, 
1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) @pytest.mark.parametrize("evaluate_every", [-3, -1, 2]) @magic_argv_env_context def test_trainer_torch_with_evaluator( @@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) @pytest.mark.parametrize("fp16", [True, False]) @pytest.mark.parametrize("accumulation_steps", [1, 3]) @magic_argv_env_context @@ -250,7 +250,7 @@ def test_trainer_on( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) @magic_argv_env_context def test_trainer_specific_params_1( model_and_optimizers: TrainerParameters, @@ -291,7 +291,7 @@ def test_trainer_specific_params_1( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) @magic_argv_env_context def test_trainer_specific_params_2( model_and_optimizers: TrainerParameters, @@ -340,7 +340,7 @@ def test_trainer_specific_params_2( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) @magic_argv_env_context def test_trainer_w_evaluator_overfit_torch( diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 2cdbe189..176f646a 100644 --- 
a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -14,7 +14,7 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.utils import magic_argv_env_context, Capturing from fastNLP.envs.distributed import rank_zero_rm -from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12 if _NEED_IMPORT_TORCH: import torch.distributed as dist from torch.optim import SGD @@ -290,7 +290,7 @@ def test_trainer_on_exception( @pytest.mark.torch -@pytest.mark.parametrize("version", [0, 1, 2, 3]) +@pytest.mark.parametrize("version", [0, 1]) @magic_argv_env_context def test_torch_distributed_launch_1(version): """ @@ -304,7 +304,7 @@ def test_torch_distributed_launch_1(version): @pytest.mark.torch -@pytest.mark.parametrize("version", [0, 1, 2, 3]) +@pytest.mark.parametrize("version", [0, 1]) @magic_argv_env_context def test_torch_distributed_launch_2(version): """ @@ -325,6 +325,8 @@ def test_torch_wo_auto_param_call( device, n_epochs=2, ): + if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12: + pytest.skip("fsdp 需要 torch 在 1.12 及以上") model = TorchNormalModel_Classification_3( num_labels=NormalClassificationTrainTorchConfig.num_labels, @@ -373,6 +375,9 @@ def test_trainer_overfit_torch( overfit_batches, num_train_batch_per_epoch ): + if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12: + pytest.skip("fsdp 需要 torch 在 1.12 及以上") + trainer = Trainer( model=model_and_optimizers.model, driver=driver, diff --git a/tests/core/drivers/torch_driver/test_fsdp.py b/tests/core/drivers/torch_driver/test_fsdp.py index 0046251a..586a97ea 100644 --- a/tests/core/drivers/torch_driver/test_fsdp.py +++ b/tests/core/drivers/torch_driver/test_fsdp.py @@ -11,7 +11,7 @@ from tests.helpers.models.torch_model import 
TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.utils import magic_argv_env_context -from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12 from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm if _NEED_IMPORT_TORCH: import torch.distributed as dist @@ -67,6 +67,7 @@ def model_and_optimizers(request): return trainer_params +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上") @pytest.mark.torch @magic_argv_env_context def test_trainer_torch_without_evaluator( @@ -76,7 +77,7 @@ def test_trainer_torch_without_evaluator( trainer = Trainer( model=model_and_optimizers.model, driver="torch_fsdp", - device=[4, 5], + device=[0, 1], optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, @@ -96,6 +97,7 @@ def test_trainer_torch_without_evaluator( dist.destroy_process_group() +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp 需要 torch 版本在 1.12 及以上") @pytest.mark.torch @pytest.mark.parametrize("save_on_rank0", [True, False]) @magic_argv_env_context(timeout=100)