@@ -86,7 +86,7 @@ class BiLSTMCRF(nn.Module):
         :param seq_len: length of each sentence, of shape ``[batch,]``
         :return: if ``target`` is ``None``, the predictions ``{'pred': torch.Tensor}``; otherwise the loss ``{'loss': torch.Tensor}``
         """
-        return self(words, seq_len, target)
+        return self(words, target, seq_len)
 
     def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
         """
@@ -94,7 +94,7 @@ class BiLSTMCRF(nn.Module):
         :param seq_len: length of each sentence, of shape ``[batch,]``
         :return: the predictions ``{'pred': torch.Tensor}``
         """
-        return self(words, seq_len)
+        return self(words, seq_len=seq_len)
 
 
 class SeqLabeling(nn.Module):
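
Note: the call-site fixes in this hunk and in the matching AdvSeqLabel hunks below all follow from forward taking target before seq_len, so the old positional calls silently bound seq_len to the target slot. A minimal sketch of the mix-up (the signature is inferred from the corrected calls, not copied from fastNLP):

    # Stand-in for BiLSTMCRF/AdvSeqLabel; the real forward runs the LSTM/CRF,
    # but the parameter order below is what the corrected calls imply.
    class Sketch:
        def __call__(self, words, target=None, seq_len=None):
            return {"target": target, "seq_len": seq_len}

    m = Sketch()
    m("w", "SEQ_LEN", "TARGET")   # old train_step: seq_len fills the target slot
    m("w", "TARGET", "SEQ_LEN")   # fixed train_step: positions match the signature
    m("w", seq_len="SEQ_LEN")     # fixed evaluate_step: keyword skips the target slot
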
@@ -286,7 +286,7 @@ class AdvSeqLabel(nn.Module):
         :param seq_len: length of each sentence, of shape ``[batch,]``
         :return: if ``target`` is ``None``, the predictions ``{'pred': torch.Tensor}``; otherwise the loss ``{'loss': torch.Tensor}``
         """
-        return self(words, seq_len, target)
+        return self(words, target, seq_len)
 
     def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
         """
@@ -294,4 +294,4 @@ class AdvSeqLabel(nn.Module):
         :param seq_len: length of each sentence, of shape ``[batch,]``
         :return: the predictions ``{'pred': torch.Tensor}``
         """
-        return self(words, seq_len)
+        return self(words, seq_len=seq_len)
@@ -103,8 +103,8 @@ def model_and_optimizers(request):
 # Test the ordinary case.
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4),
-                                           ("torch", [4, 5])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1),
+                                           ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
 @pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
 @magic_argv_env_context
 def test_trainer_torch_with_evaluator(
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator(
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)])  # ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)])  # ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("fp16", [True, False])
 @pytest.mark.parametrize("accumulation_steps", [1, 3])
 @magic_argv_env_context
@@ -250,7 +250,7 @@ def test_trainer_on(
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)])  # ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 1)])  # ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_trainer_specific_params_1(
         model_and_optimizers: TrainerParameters,
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1(
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])])  # ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_trainer_specific_params_2(
         model_and_optimizers: TrainerParameters,
@@ -340,7 +340,7 @@ def test_trainer_specific_params_2(
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])])  # ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_w_evaluator_overfit_torch(
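
The device updates in the parametrize hunks above retarget the GPU cases from cards 4/5 (a dev-machine choice) to the conventional 0/1. A possible hardening, not part of this diff, is to gate multi-GPU cases on the visible device count instead of hard-coding IDs; requires_two_gpus below is a hypothetical helper name:

    import pytest
    import torch

    # Hypothetical guard (not in this PR): skip when fewer than two CUDA
    # devices are visible, so hard-coded [0, 1] cases skip instead of failing.
    requires_two_gpus = pytest.mark.skipif(
        not torch.cuda.is_available() or torch.cuda.device_count() < 2,
        reason="needs at least two visible CUDA devices",
    )

It would be stacked as an extra decorator on the multi-GPU parametrizations.
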
@@ -14,7 +14,7 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
 from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
 from tests.helpers.utils import magic_argv_env_context, Capturing
 from fastNLP.envs.distributed import rank_zero_rm
-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
 
 if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.optim import SGD
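
The newly imported _TORCH_GREATER_EQUAL_1_12 flag is not defined in this diff. A plausible sketch of how such a flag could be computed in fastNLP.envs.imports, assuming a packaging-based version compare (the real definition may differ):

    # Sketch only: guard the comparison behind the torch import, mirroring
    # the _NEED_IMPORT_TORCH pattern already used by this module.
    from packaging.version import parse as parse_version

    try:
        import torch
        # strip local build tags such as "+cu113" before comparing
        _TORCH_GREATER_EQUAL_1_12 = (
            parse_version(torch.__version__.split("+")[0]) >= parse_version("1.12")
        )
    except ImportError:
        _TORCH_GREATER_EQUAL_1_12 = False
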
@@ -290,7 +290,7 @@ def test_trainer_on_exception(
 @pytest.mark.torch
-@pytest.mark.parametrize("version", [0, 1, 2, 3])
+@pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_torch_distributed_launch_1(version):
     """
@@ -304,7 +304,7 @@ def test_torch_distributed_launch_1(version):
 @pytest.mark.torch
-@pytest.mark.parametrize("version", [0, 1, 2, 3])
+@pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_torch_distributed_launch_2(version):
     """
@@ -325,6 +325,8 @@ def test_torch_wo_auto_param_call(
         device,
         n_epochs=2,
 ):
+    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+        pytest.skip("fsdp requires torch 1.12 or above")
     model = TorchNormalModel_Classification_3(
         num_labels=NormalClassificationTrainTorchConfig.num_labels,
@@ -373,6 +375,9 @@ def test_trainer_overfit_torch(
         overfit_batches,
         num_train_batch_per_epoch
 ):
+    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+        pytest.skip("fsdp requires torch 1.12 or above")
+
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
@@ -11,7 +11,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
 from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
 from tests.helpers.utils import magic_argv_env_context
-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
 from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm
 
 if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
@@ -67,6 +67,7 @@ def model_and_optimizers(request):
     return trainer_params
 
 
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
 @pytest.mark.torch
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
@@ -76,7 +77,7 @@ def test_trainer_torch_without_evaluator(
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver="torch_fsdp",
-        device=[4, 5],
+        device=[0, 1],
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
         evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
@@ -96,6 +97,7 @@ def test_trainer_torch_without_evaluator(
         dist.destroy_process_group()
 
 
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
 @pytest.mark.torch
 @pytest.mark.parametrize("save_on_rank0", [True, False])
 @magic_argv_env_context(timeout=100)
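
Two skip styles carry the torch 1.12 requirement in this PR: a runtime pytest.skip(...) inside test bodies where only the torch_fsdp parametrization is affected, and a collection-time @pytest.mark.skipif(...) on tests that are fsdp-only. A minimal illustration of the distinction (TORCH_OK stands in for _TORCH_GREATER_EQUAL_1_12; note skipif takes its reason as a keyword argument):

    import pytest

    TORCH_OK = False  # stand-in for _TORCH_GREATER_EQUAL_1_12

    @pytest.mark.skipif(not TORCH_OK, reason="whole test needs torch >= 1.12")
    def test_fsdp_only():
        ...  # skipped entirely when the flag is unset

    @pytest.mark.parametrize("driver", ["torch", "torch_fsdp"])
    def test_mixed(driver):
        if driver == "torch_fsdp" and not TORCH_OK:
            pytest.skip("only the fsdp case needs torch >= 1.12")
        ...  # the plain "torch" case still runs
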