
Restrict the devices used in tests; add a torch version requirement for the fsdp tests

dev0.8.0
x54-729 2 years ago
parent commit 82fe167ea3
4 changed files with 22 additions and 15 deletions
  1. +4 -4  fastNLP/models/torch/sequence_labeling.py
  2. +6 -6  tests/core/controllers/test_trainer_w_evaluator_torch.py
  3. +8 -3  tests/core/controllers/test_trainer_wo_evaluator_torch.py
  4. +4 -2  tests/core/drivers/torch_driver/test_fsdp.py

+4 -4  fastNLP/models/torch/sequence_labeling.py

@@ -86,7 +86,7 @@ class BiLSTMCRF(nn.Module):
:param seq_len: the length of each sentence, shape ``[batch,]``
:return: if ``target`` is ``None``, the prediction ``{'pred': torch.Tensor}``; otherwise the loss ``{'loss': torch.Tensor}``
"""
- return self(words, seq_len, target)
+ return self(words, target, seq_len)

def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
"""
@@ -94,7 +94,7 @@ class BiLSTMCRF(nn.Module):
:param seq_len: the length of each sentence, shape ``[batch,]``
:return: the prediction ``{'pred': torch.Tensor}``
"""
- return self(words, seq_len)
+ return self(words, seq_len=seq_len)


class SeqLabeling(nn.Module):
@@ -286,7 +286,7 @@ class AdvSeqLabel(nn.Module):
:param seq_len: the length of each sentence, shape ``[batch,]``
:return: if ``target`` is ``None``, the prediction ``{'pred': torch.Tensor}``; otherwise the loss ``{'loss': torch.Tensor}``
"""
- return self(words, seq_len, target)
+ return self(words, target, seq_len)
def evaluate_step(self, words: "torch.LongTensor", seq_len: "torch.LongTensor"):
"""
@@ -294,4 +294,4 @@ class AdvSeqLabel(nn.Module):
:param seq_len: the length of each sentence, shape ``[batch,]``
:return: the prediction ``{'pred': torch.Tensor}``
"""
- return self(words, seq_len)
+ return self(words, seq_len=seq_len)
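All four one-line fixes in this file address the same positional-argument mix-up: judging from the corrected calls, the underlying ``forward`` takes ``target`` before ``seq_len``, so the old ``self(words, seq_len, target)`` silently bound each tensor to the wrong parameter. A minimal sketch of the effect, with the ``forward`` signature assumed (it is not shown in this diff):

def forward(words, target=None, seq_len=None):
    # Assumed signature, inferred from the corrected calls above.
    return {"target": target, "seq_len": seq_len}

print(forward("w", "LEN", "TGT"))   # old train_step order: seq_len lands in `target`
print(forward("w", "TGT", "LEN"))   # fixed order: each value reaches the right parameter
print(forward("w", seq_len="LEN"))  # evaluate_step fix: keyword keeps target=None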

+6 -6  tests/core/controllers/test_trainer_w_evaluator_torch.py

@@ -103,8 +103,8 @@ def model_and_optimizers(request):

# test the ordinary case
@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4),
-     ("torch", [4, 5])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
+ @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1),
+     ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
@pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
@magic_argv_env_context
def test_trainer_torch_with_evaluator(
@@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator(


@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1)
+ @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("accumulation_steps", [1, 3])
@magic_argv_env_context
@@ -250,7 +250,7 @@ def test_trainer_on(


@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1)
+ @pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_specific_params_1(
model_and_optimizers: TrainerParameters,
@@ -291,7 +291,7 @@ def test_trainer_specific_params_1(


@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1)
+ @pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@magic_argv_env_context
def test_trainer_specific_params_2(
model_and_optimizers: TrainerParameters,
@@ -340,7 +340,7 @@ def test_trainer_specific_params_2(


@pytest.mark.torch
- @pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1)
+ @pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context
def test_trainer_w_evaluator_overfit_torch(
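Every change in this file swaps the hard-coded GPU ids 4/5 for 0/1 so the suite only needs two visible devices. Because the device list sits in a ``pytest.mark.parametrize`` that is stacked with other parametrizations, narrowing one axis shrinks every generated combination. A hedged, self-contained illustration of that cross-product behaviour (not taken from the test file):

import pytest

@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])])
@pytest.mark.parametrize("evaluate_every", [-3, -1, 2])
def test_cross_product(driver, device, evaluate_every):
    # Stacked parametrize decorators multiply: 3 device setups x 3 evaluate_every
    # values = 9 collected cases, and the CUDA cases only touch devices 0 and 1.
    assert driver == "torch"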


+8 -3  tests/core/controllers/test_trainer_wo_evaluator_torch.py

@@ -14,7 +14,7 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.envs.distributed import rank_zero_rm
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+ from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.optim import SGD
@@ -290,7 +290,7 @@ def test_trainer_on_exception(


@pytest.mark.torch
- @pytest.mark.parametrize("version", [0, 1, 2, 3])
+ @pytest.mark.parametrize("version", [0, 1])
@magic_argv_env_context
def test_torch_distributed_launch_1(version):
"""
@@ -304,7 +304,7 @@ def test_torch_distributed_launch_1(version):


@pytest.mark.torch
- @pytest.mark.parametrize("version", [0, 1, 2, 3])
+ @pytest.mark.parametrize("version", [0, 1])
@magic_argv_env_context
def test_torch_distributed_launch_2(version):
"""
@@ -325,6 +325,8 @@ def test_torch_wo_auto_param_call(
device,
n_epochs=2,
):
+ if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+     pytest.skip("fsdp requires torch 1.12 or above")

model = TorchNormalModel_Classification_3(
num_labels=NormalClassificationTrainTorchConfig.num_labels,
@@ -373,6 +375,9 @@ def test_trainer_overfit_torch(
overfit_batches,
num_train_batch_per_epoch
):
+ if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
+     pytest.skip("fsdp requires torch 1.12 or above")

trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
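Both runtime guards added in this file key off ``_TORCH_GREATER_EQUAL_1_12``, now imported from ``fastNLP.envs.imports``. Its definition is not part of this diff; a plausible equivalent check, shown only as a sketch (the real helper may be implemented differently), is:

from packaging.version import parse as parse_version

try:
    import torch
    # Strip any local build suffix such as "+cu113" before comparing versions.
    _TORCH_GREATER_EQUAL_1_12 = parse_version(torch.__version__.split("+")[0]) >= parse_version("1.12")
except ImportError:
    _TORCH_GREATER_EQUAL_1_12 = False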


+4 -2  tests/core/drivers/torch_driver/test_fsdp.py

@@ -11,7 +11,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.utils import magic_argv_env_context
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+ from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_12
from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
@@ -67,6 +67,7 @@ def model_and_optimizers(request):

return trainer_params

+ @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_torch_without_evaluator(
@@ -76,7 +77,7 @@ def test_trainer_torch_without_evaluator(
trainer = Trainer(
model=model_and_optimizers.model,
driver="torch_fsdp",
- device=[4, 5],
+ device=[0, 1],
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
@@ -96,6 +97,7 @@ def test_trainer_torch_without_evaluator(
dist.destroy_process_group()


+ @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
@pytest.mark.torch
@pytest.mark.parametrize("save_on_rank0", [True, False])
@magic_argv_env_context(timeout=100)
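Taken together, the commit uses two complementary skip patterns: a module-level ``skipif`` marker for tests that are FSDP-only (as in this file), and an in-body ``pytest.skip`` for parametrized tests where only the ``torch_fsdp`` case needs the version gate. A short hedged recap (not copied from the commit):

import pytest
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12

@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_12, reason="fsdp requires torch 1.12 or above")
def test_fsdp_only():
    ...  # the whole test depends on FSDP, so it is skipped at collection time

@pytest.mark.parametrize("driver", ["torch", "torch_fsdp"])
def test_mixed_drivers(driver):
    if driver == "torch_fsdp" and not _TORCH_GREATER_EQUAL_1_12:
        pytest.skip("fsdp requires torch 1.12 or above")  # only this one case is skipped
    ...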

