@@ -86,7 +86,7 @@ class BiLSTMCRF(nn.Module):
         :param seq_len: length of each sentence, shape ``[batch,]``
         :return: if ``target`` is ``None``, returns the prediction ``{'pred': torch.Tensor}``; otherwise returns the loss ``{'loss': torch.Tensor}``
         """
-        return self(words, seq_len, target)
+        return self(words, target, seq_len)

     def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
         """
@@ -94,7 +94,7 @@ class BiLSTMCRF(nn.Module):
         :param seq_len: length of each sentence, shape ``[batch,]``
         :return: the prediction ``{'pred': torch.Tensor}``
         """
-        return self(words, seq_len)
+        return self(words, seq_len=seq_len)


 class SeqLabeling(nn.Module):
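Both hunks above fix the same class of bug: judging from the corrected calls, the model's ``forward`` orders its parameters as ``(words, target=None, seq_len=None)``, so the old positional call ``self(words, seq_len, target)`` silently bound the lengths to the ``target`` slot and the targets to ``seq_len``. A minimal sketch of the failure mode (``TinyTagger`` is a hypothetical stand-in, not fastNLP code):

```python
import torch
from torch import nn

class TinyTagger(nn.Module):  # hypothetical stand-in for BiLSTMCRF
    def forward(self, words, target=None, seq_len=None):
        # training path: a target is given, return a loss
        if target is not None:
            return {'loss': torch.zeros(())}
        # inference path: no target, return predictions
        return {'pred': words.argmax(dim=-1)}

    def train_step(self, words, target, seq_len):
        # buggy form: self(words, seq_len, target) binds seq_len to ``target``,
        # so the "loss" is computed against sentence lengths
        return self(words, target, seq_len)  # correct positional order
```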
@@ -286,7 +286,7 @@ class AdvSeqLabel(nn.Module):
         :param seq_len: length of each sentence, shape ``[batch,]``
         :return: if ``target`` is ``None``, returns the prediction ``{'pred': torch.Tensor}``; otherwise returns the loss ``{'loss': torch.Tensor}``
         """
-        return self(words, seq_len, target)
+        return self(words, target, seq_len)

     def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
         """
@@ -294,4 +294,4 @@ class AdvSeqLabel(nn.Module):
         :param seq_len: length of each sentence, shape ``[batch,]``
         :return: the prediction ``{'pred': torch.Tensor}``
         """
-        return self(words, seq_len)
+        return self(words, seq_len=seq_len)
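The ``evaluate_step`` half of the fix is the same trap from the other direction: with no ``target`` available, ``self(words, seq_len)`` fed the lengths into the ``target`` slot and turned evaluation into a bogus training pass. Passing ``seq_len`` by keyword sidesteps the positional binding entirely. Continuing the hypothetical sketch above:

```python
model = TinyTagger()
words = torch.randn(2, 5, 7)                      # made-up shape: [batch, seq, num_tags]
out = model(words, seq_len=torch.tensor([5, 3]))  # keyword call stays on the inference path
assert set(out) == {'pred'}
```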
@@ -103,8 +103,8 @@ def model_and_optimizers(request):

 # test the ordinary case
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4),
-                                           ("torch", [4, 5])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1),
+                                           ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
 @pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
 @magic_argv_env_context
 def test_trainer_torch_with_evaluator(
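This hunk, and the four like it below, replaces machine-specific GPU ids 4/5 with the conventional 0/1 so the suite runs on any two-GPU box. If hard-coding even 0/1 is undesirable, one alternative (a sketch, not existing fastNLP test infrastructure) is to derive the parametrize list from an environment variable:

```python
import os

def _gpu_params():
    # FASTNLP_TEST_GPUS is a hypothetical variable; defaults to devices 0 and 1
    ids = [int(x) for x in os.environ.get("FASTNLP_TEST_GPUS", "0,1").split(",")]
    return [("torch", "cpu"), ("torch", ids[0]), ("torch", ids)]

# usage: @pytest.mark.parametrize("driver,device", _gpu_params())
```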
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator(

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("fp16", [True, False])
 @pytest.mark.parametrize("accumulation_steps", [1, 3])
 @magic_argv_env_context
@@ -250,7 +250,7 @@ def test_trainer_on(

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_specific_params_1(
         model_and_optimizers: TrainerParameters,
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1(

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_specific_params_2(
         model_and_optimizers: TrainerParameters,
@@ -340,7 +340,7 @@ def test_trainer_specific_params_2(

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])  # ("torch", [0, 1]),("torch", 1)
 @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
 @magic_argv_env_context
 def test_trainer_w_evaluator_overfit_torch(
@@ -14,7 +14,7 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
 from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
 from tests.helpers.utils import magic_argv_env_context, Capturing
 from fastNLP.envs.distributed import rank_zero_rm
-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12

 if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.optim import SGD
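The new ``_TORCH_GREATER_EQUAL_1_12`` flag gates the FSDP tests below. Its real definition lives in ``fastNLP.envs.imports``; a minimal sketch of how such a flag is commonly computed (an assumption about the shape of that code, not the library's actual source):

```python
from packaging.version import parse as parse_version

try:
    import torch
    # strip any local build tag such as "+cu113" before comparing
    _TORCH_GREATER_EQUAL_1_12 = parse_version(torch.__version__.split("+")[0]) >= parse_version("1.12")
except ImportError:
    _TORCH_GREATER_EQUAL_1_12 = False
```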
@@ -290,7 +290,7 @@ def test_trainer_on_exception(

 @pytest.mark.torch
-@pytest.mark.parametrize("version", [0, 1, 2, 3])
+@pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_torch_distributed_launch_1(version):
     """
@@ -304,7 +304,7 @@ def test_torch_distributed_launch_1(version):

 @pytest.mark.torch
-@pytest.mark.parametrize("version", [0, 1, 2, 3])
+@pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_torch_distributed_launch_2(version):
     """
@@ -325,6 +325,8 @@ def test_torch_wo_auto_param_call(
         device,
         n_epochs=2,
 ):
+    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+        pytest.skip("fsdp requires torch >= 1.12")

     model = TorchNormalModel_Classification_3(
         num_labels=NormalClassificationTrainTorchConfig.num_labels,
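The guard sits inside the test body rather than in a decorator because ``driver`` is a parametrized argument whose value is only known once the test runs: ``pytest.skip()`` handles exactly this case, while ``@pytest.mark.skipif`` is evaluated at collection time. In schematic form:

```python
import pytest
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12

@pytest.mark.parametrize("driver", ["torch", "torch_fsdp"])
def test_example(driver):
    # per-parametrization condition: must be checked at run time
    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
        pytest.skip("fsdp requires torch >= 1.12")
```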
@@ -373,6 +375,9 @@ def test_trainer_overfit_torch(
         overfit_batches,
         num_train_batch_per_epoch
 ):
+    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+        pytest.skip("fsdp requires torch >= 1.12")
+
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
@@ -11,7 +11,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
 from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
 from tests.helpers.utils import magic_argv_env_context
-from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
 from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm

 if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
@@ -67,6 +67,7 @@ def model_and_optimizers(request):
     return trainer_params


+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch >= 1.12")
 @pytest.mark.torch
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
@@ -76,7 +77,7 @@ def test_trainer_torch_without_evaluator(
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver="torch_fsdp",
-        device=[4, 5],
+        device=[0, 1],
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
         evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
@@ -96,6 +97,7 @@ def test_trainer_torch_without_evaluator(
     dist.destroy_process_group()


+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch >= 1.12")
 @pytest.mark.torch
 @pytest.mark.parametrize("save_on_rank0", [True, False])
 @magic_argv_env_context(timeout=100)
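Note that ``pytest.mark.skipif`` accepts its reason only as a keyword argument; a bare positional string would be treated as an additional condition string and ``eval()``-ed at collection time, which is why the decorators above spell out ``reason=``. In isolation:

```python
import pytest
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12

# correct: reason is keyword-only for skipif
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch >= 1.12")
def test_fsdp_smoke():
    ...
```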