@@ -86,7 +86,7 @@ class BiLSTMCRF(nn.Module): | |||||
:param seq_len: 每个句子的长度,形状为 ``[batch,]`` | :param seq_len: 每个句子的长度,形状为 ``[batch,]`` | ||||
:return: 如果 ``target`` 为 ``None``,则返回预测结果 ``{'pred': torch.Tensor}``,否则返回 loss ``{'loss': torch.Tensor}`` | :return: 如果 ``target`` 为 ``None``,则返回预测结果 ``{'pred': torch.Tensor}``,否则返回 loss ``{'loss': torch.Tensor}`` | ||||
""" | """ | ||||
return self(words, seq_len, target) | |||||
return self(words, target, seq_len) | |||||
def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"): | def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"): | ||||
""" | """ | ||||
@@ -94,7 +94,7 @@ class BiLSTMCRF(nn.Module): | |||||
:param seq_len: 每个句子的长度,形状为 ``[batch,]`` | :param seq_len: 每个句子的长度,形状为 ``[batch,]`` | ||||
:return: 预测结果 ``{'pred': torch.Tensor}`` | :return: 预测结果 ``{'pred': torch.Tensor}`` | ||||
""" | """ | ||||
return self(words, seq_len) | |||||
return self(words, seq_len=seq_len) | |||||
class SeqLabeling(nn.Module): | class SeqLabeling(nn.Module): | ||||
@@ -286,7 +286,7 @@ class AdvSeqLabel(nn.Module): | |||||
:param seq_len: 每个句子的长度,形状为 ``[batch,]`` | :param seq_len: 每个句子的长度,形状为 ``[batch,]`` | ||||
:return: 如果 ``target`` 为 ``None``,则返回预测结果 ``{'pred': torch.Tensor}``,否则返回 loss ``{'loss': torch.Tensor}`` | :return: 如果 ``target`` 为 ``None``,则返回预测结果 ``{'pred': torch.Tensor}``,否则返回 loss ``{'loss': torch.Tensor}`` | ||||
""" | """ | ||||
return self(words, seq_len, target) | |||||
return self(words, target, seq_len) | |||||
def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"): | def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"): | ||||
""" | """ | ||||
@@ -294,4 +294,4 @@ class AdvSeqLabel(nn.Module): | |||||
:param seq_len: 每个句子的长度,形状为 ``[batch,]`` | :param seq_len: 每个句子的长度,形状为 ``[batch,]`` | ||||
:return: 预测结果 ``{'pred': torch.Tensor}`` | :return: 预测结果 ``{'pred': torch.Tensor}`` | ||||
""" | """ | ||||
return self(words, seq_len) | |||||
return self(words, seq_len=seq_len) |
@@ -103,8 +103,8 @@ def model_and_optimizers(request): | |||||
# 测试一下普通的情况; | # 测试一下普通的情况; | ||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), | |||||
("torch", [4, 5])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), | |||||
("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("evaluate_every", [-3, -1, 2]) | @pytest.mark.parametrize("evaluate_every", [-3, -1, 2]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_with_evaluator( | def test_trainer_torch_with_evaluator( | ||||
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | |||||
@pytest.mark.parametrize("fp16", [True, False]) | @pytest.mark.parametrize("fp16", [True, False]) | ||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | @pytest.mark.parametrize("accumulation_steps", [1, 3]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -250,7 +250,7 @@ def test_trainer_on( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_specific_params_1( | def test_trainer_specific_params_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_specific_params_2( | def test_trainer_specific_params_2( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
@@ -340,7 +340,7 @@ def test_trainer_specific_params_2( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) | |||||
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) | @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_w_evaluator_overfit_torch( | def test_trainer_w_evaluator_overfit_torch( | ||||
@@ -14,7 +14,7 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | |||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12 | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from torch.optim import SGD | from torch.optim import SGD | ||||
@@ -290,7 +290,7 @@ def test_trainer_on_exception( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("version", [0, 1, 2, 3]) | |||||
@pytest.mark.parametrize("version", [0, 1]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_distributed_launch_1(version): | def test_torch_distributed_launch_1(version): | ||||
""" | """ | ||||
@@ -304,7 +304,7 @@ def test_torch_distributed_launch_1(version): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("version", [0, 1, 2, 3]) | |||||
@pytest.mark.parametrize("version", [0, 1]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_distributed_launch_2(version): | def test_torch_distributed_launch_2(version): | ||||
""" | """ | ||||
@@ -325,6 +325,8 @@ def test_torch_wo_auto_param_call( | |||||
device, | device, | ||||
n_epochs=2, | n_epochs=2, | ||||
): | ): | ||||
if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12: | |||||
pytest.skip("fsdp 需要 torch 在 1.12 及以上") | |||||
model = TorchNormalModel_Classification_3( | model = TorchNormalModel_Classification_3( | ||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | num_labels=NormalClassificationTrainTorchConfig.num_labels, | ||||
@@ -373,6 +375,9 @@ def test_trainer_overfit_torch( | |||||
overfit_batches, | overfit_batches, | ||||
num_train_batch_per_epoch | num_train_batch_per_epoch | ||||
): | ): | ||||
if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12: | |||||
pytest.skip("fsdp 需要 torch 在 1.12 及以上") | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver=driver, | driver=driver, | ||||
@@ -11,7 +11,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset | ||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12 | |||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm | from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
@@ -67,6 +67,7 @@ def model_and_optimizers(request): | |||||
return trainer_params | return trainer_params | ||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上") | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator( | def test_trainer_torch_without_evaluator( | ||||
@@ -76,7 +77,7 @@ def test_trainer_torch_without_evaluator( | |||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model_and_optimizers.model, | model=model_and_optimizers.model, | ||||
driver="torch_fsdp", | driver="torch_fsdp", | ||||
device=[4, 5], | |||||
device=[0, 1], | |||||
optimizers=model_and_optimizers.optimizers, | optimizers=model_and_optimizers.optimizers, | ||||
train_dataloader=model_and_optimizers.train_dataloader, | train_dataloader=model_and_optimizers.train_dataloader, | ||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | ||||
@@ -96,6 +97,7 @@ def test_trainer_torch_without_evaluator( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, "fsdp 需要 torch 版本在 1.12 及以上") | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("save_on_rank0", [True, False]) | @pytest.mark.parametrize("save_on_rank0", [True, False]) | ||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||