@@ -86,7 +86,7 @@ class BiLSTMCRF(nn.Module):
         :param seq_len: the length of each sentence, with shape ``[batch,]``
         :return: if ``target`` is ``None``, returns the prediction ``{'pred': torch.Tensor}``; otherwise returns the loss ``{'loss': torch.Tensor}``
         """
-        return self(words, seq_len, target)
+        return self(words, target, seq_len)
 
     def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
         """
@@ -94,7 +94,7 @@ class BiLSTMCRF(nn.Module):
         :param seq_len: the length of each sentence, with shape ``[batch,]``
         :return: the prediction ``{'pred': torch.Tensor}``
         """
-        return self(words, seq_len)
+        return self(words, seq_len=seq_len)
 
 
 class SeqLabeling(nn.Module):
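These two hunks fix `BiLSTMCRF.train_step`/`evaluate_step` so that positional arguments line up with `forward`. A minimal sketch of the mismatch, assuming (as the fix implies) that `forward` takes `target` before `seq_len`:

```python
class Sketch:
    """Hypothetical stand-in for BiLSTMCRF; the parameter order below is
    an assumption inferred from the fix, not fastNLP's actual code."""

    def __call__(self, words, target=None, seq_len=None):
        # Stands in for nn.Module.__call__ dispatching to forward().
        return {"loss": "..."} if target is not None else {"pred": "..."}

    def train_step(self, words, target, seq_len):
        # Before the fix, self(words, seq_len, target) bound seq_len to the
        # ``target`` parameter and target to ``seq_len``, so the loss was
        # computed against the wrong tensor.
        return self(words, target, seq_len)

    def evaluate_step(self, words, seq_len):
        # Before the fix, self(words, seq_len) bound seq_len to ``target``,
        # making evaluation return a loss dict instead of predictions.
        return self(words, seq_len=seq_len)
```

Passing `seq_len` by keyword in `evaluate_step` also keeps the call robust if the optional parameters are ever reordered again.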
@@ -286,7 +286,7 @@ class AdvSeqLabel(nn.Module):
         :param seq_len: the length of each sentence, with shape ``[batch,]``
         :return: if ``target`` is ``None``, returns the prediction ``{'pred': torch.Tensor}``; otherwise returns the loss ``{'loss': torch.Tensor}``
         """
-        return self(words, seq_len, target)
+        return self(words, target, seq_len)
 
     def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
         """
@@ -294,4 +294,4 @@ class AdvSeqLabel(nn.Module):
         :param seq_len: the length of each sentence, with shape ``[batch,]``
         :return: the prediction ``{'pred': torch.Tensor}``
         """
-        return self(words, seq_len)
+        return self(words, seq_len=seq_len)
@@ -103,8 +103,8 @@ def model_and_optimizers(request):
 
 # Test the ordinary case
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4),
-                                           ("torch", [4, 5])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1),
+                                           ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
 @pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
 @magic_argv_env_context
 def test_trainer_torch_with_evaluator(
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("fp16", [True, False])
 @pytest.mark.parametrize("accumulation_steps", [1, 3])
 @magic_argv_env_context
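The device hunks in this file swap hard-coded GPU indices 4 and 5 for 0 and 1, so the parametrized matrix runs on any machine with at least two GPUs. A small guard along these lines (a sketch, not a fastNLP helper) could additionally skip rather than fail when the current machine has fewer devices than a parametrization asks for:

```python
import pytest
import torch

def skip_if_missing_gpus(device):
    """Hypothetical helper: skip when the parametrized device(s) are absent."""
    if device == "cpu":
        return
    indices = device if isinstance(device, list) else [device]
    available = torch.cuda.device_count()
    if available <= max(indices):
        pytest.skip(f"needs CUDA device(s) {indices}, only {available} visible")
```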
@@ -250,7 +250,7 @@ def test_trainer_on(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_specific_params_1(
     model_and_optimizers: TrainerParameters,
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_specific_params_2(
     model_and_optimizers: TrainerParameters,
@@ -340,7 +340,7 @@ def test_trainer_specific_params_2(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_w_evaluator_overfit_torch(
@@ -14,7 +14,7 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
 from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
 from tests.helpers.utils import magic_argv_env_context, Capturing
 from fastNLP.envs.distributed import rank_zero_rm
-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
 if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.optim import SGD
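This file now also imports `_TORCH_GREATER_EQUAL_1_12`, used below to gate the fsdp tests. A sketch of how such a flag is commonly derived (the real definition lives in `fastNLP.envs.imports` and may differ):

```python
# Assumption: a version flag computed from torch.__version__; parse() copes
# with local version suffixes such as "1.12.1+cu113".
from packaging.version import parse as parse_version

try:
    import torch
    _TORCH_GREATER_EQUAL_1_12 = parse_version(torch.__version__) >= parse_version("1.12")
except ImportError:
    _TORCH_GREATER_EQUAL_1_12 = False
```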
@@ -290,7 +290,7 @@ def test_trainer_on_exception(
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("version", [0, 1, 2, 3])
+@pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_torch_distributed_launch_1(version):
     """
@@ -304,7 +304,7 @@ def test_torch_distributed_launch_1(version):
 
 
 @pytest.mark.torch
-@pytest.mark.parametrize("version", [0, 1, 2, 3])
+@pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_torch_distributed_launch_2(version):
     """
@@ -325,6 +325,8 @@ def test_torch_wo_auto_param_call(
     device,
     n_epochs=2,
 ):
+    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+        pytest.skip("fsdp requires torch 1.12 or above")
 
     model = TorchNormalModel_Classification_3(
         num_labels=NormalClassificationTrainTorchConfig.num_labels,
@@ -373,6 +375,9 @@ def test_trainer_overfit_torch(
     overfit_batches,
     num_train_batch_per_epoch
 ):
+    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+        pytest.skip("fsdp requires torch 1.12 or above")
+
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
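Both hunks add the version guard inside the test body rather than as a decorator because `driver` is parametrized: a module-level `skipif` would drop the plain torch cases as well, while the runtime check skips only the fsdp parametrization on older torch. A condensed sketch of the pattern:

```python
import pytest
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12

@pytest.mark.parametrize("driver", ["torch", "torch_fsdp"])
def test_example(driver):
    # Skips only the "torch_fsdp" case; the plain "torch" case still runs.
    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
        pytest.skip("fsdp requires torch 1.12 or above")
```

An equivalent alternative is to attach the mark per case, e.g. `pytest.param("torch_fsdp", marks=pytest.mark.skipif(...))`.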
@@ -11,7 +11,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
 from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
 from tests.helpers.utils import magic_argv_env_context
-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
 from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm
 if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
@@ -67,6 +67,7 @@ def model_and_optimizers(request):
     return trainer_params
 
 
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
 @pytest.mark.torch
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
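One detail of the `skipif` decorator used here: pytest treats every positional argument as a condition, and string conditions are `eval()`-ed as Python expressions at collection time, so the human-readable explanation must go in the `reason=` keyword. A minimal sketch:

```python
import pytest
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12

# reason= must be a keyword argument; a positional string would be eval()-ed
# as a (failing) condition expression during collection.
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12,
                    reason="fsdp requires torch 1.12 or above")
def test_fsdp_only():
    ...
```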
@@ -76,7 +77,7 @@ def test_trainer_torch_without_evaluator(
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver="torch_fsdp",
-        device=[4, 5],
+        device=[0, 1],
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
         evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
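`device=[0, 1]` gives the `torch_fsdp` driver two ranks to shard across, again using indices that exist on any two-GPU machine. Roughly what happens per rank, as a sketch under stated assumptions (fastNLP's driver manages launching, wrapping, and checkpointing itself; torch >= 1.12 assumed):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_for_rank(model: torch.nn.Module, local_rank: int) -> FSDP:
    # Assumes one process per device with RANK/WORLD_SIZE/MASTER_ADDR set by
    # the launcher (e.g. torchrun), as for device=[0, 1] -> ranks 0 and 1.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    # Parameters are sharded across ranks instead of replicated as in DDP.
    return FSDP(model.to(local_rank))
```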
@@ -96,6 +97,7 @@ def test_trainer_torch_without_evaluator(
     dist.destroy_process_group()
 
 
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
 @pytest.mark.torch
 @pytest.mark.parametrize("save_on_rank0", [True, False])
 @magic_argv_env_context(timeout=100)