diff --git a/fastNLP/core/utils/dummy_class.py b/fastNLP/core/utils/dummy_class.py index 2856b656..42200cbb 100644 --- a/fastNLP/core/utils/dummy_class.py +++ b/fastNLP/core/utils/dummy_class.py @@ -1,5 +1,5 @@ import functools class DummyClass: - def __call__(self, *args, **kwargs): - return + def __init__(self, *args, **kwargs): + pass diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 976b68ba..a8ce451b 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -2,9 +2,6 @@ import os import pytest from typing import Any from dataclasses import dataclass -from torch.utils.data import DataLoader -from torch.optim import SGD -import torch.distributed as dist from pathlib import Path import re import time @@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy from fastNLP.core.log import logger +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist @dataclass class ArgMaxDatasetConfig: @@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2( if version == 0: callbacks = [ - TrainerCheckpointCallback( + CheckpointCallback( monitor="acc", folder=path, every_n_epochs=None, @@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2( topk=None, last=False, on_exception=None, - model_save_fn=model_save_fn + model_save_fn=model_save_fn, + save_object="trainer" ) ] elif version == 1: callbacks = [ - TrainerCheckpointCallback( + CheckpointCallback( monitor="acc", folder=path, every_n_epochs=None, @@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2( topk=1, last=True, on_exception=None, - model_save_fn=model_save_fn + model_save_fn=model_save_fn, + save_object="trainer" ) ] diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 2b59ccd5..08c6f8e2 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -12,9 +12,7 @@ import os import pytest from typing import Any from dataclasses import dataclass -from torch.utils.data import DataLoader -from torch.optim import SGD -import torch.distributed as dist + from pathlib import Path import re @@ -29,7 +27,11 @@ from torchmetrics import Accuracy from fastNLP.core.metrics import Metric from fastNLP.core.log import logger from fastNLP.core.callbacks import MoreEvaluateCallback - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist @dataclass class ArgMaxDatasetConfig: diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index bcd89614..0d484ac7 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -1,10 +1,7 @@ import pytest from typing import Any from dataclasses import dataclass -from torch.optim import SGD -from torch.utils.data import DataLoader -from torchmetrics import Accuracy -import torch.distributed as dist + from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.callbacks.callback_events import Events @@ -12,6 +9,12 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback from tests.helpers.utils import magic_argv_env_context, Capturing +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + from torchmetrics import Accuracy + import torch.distributed as dist @dataclass diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 891626b5..f44bd735 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -2,9 +2,7 @@ 注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; """ import pytest -from torch.optim import SGD -from torch.utils.data import DataLoader -import torch.distributed as dist + from dataclasses import dataclass from typing import Any from torchmetrics import Accuracy @@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + import torch.distributed as dist @dataclass class NormalClassificationTrainTorchConfig: diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index aa86ef92..74e5058e 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -2,9 +2,7 @@ import os.path import subprocess import sys import pytest -import torch.distributed as dist -from torch.optim import SGD -from torch.utils.data import DataLoader + from dataclasses import dataclass from typing import Any from pathlib import Path @@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.utils import magic_argv_env_context, Capturing from fastNLP.core import rank_zero_rm +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch.distributed as dist + from torch.optim import SGD + from torch.utils.data import DataLoader @dataclass diff --git a/tests/core/drivers/jittor_driver/test_single_device.py b/tests/core/drivers/jittor_driver/test_single_device.py index 8bbceed9..2e220974 100644 --- a/tests/core/drivers/jittor_driver/test_single_device.py +++ b/tests/core/drivers/jittor_driver/test_single_device.py @@ -15,7 +15,7 @@ else: -class Model (Module): +class Model(Module): def __init__ (self): super (Model, self).__init__() self.conv1 = nn.Conv (3, 32, 3, 1) # no padding @@ -45,6 +45,7 @@ class Model (Module): return x @pytest.mark.jittor +@pytest.mark.skip("Skip jittor tests now.") class TestSingleDevice: def test_on_gpu_without_fp16(self): diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index 11799515..d6f0ee77 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -92,7 +92,6 @@ def test_multi_drivers(): dist.destroy_process_group() @pytest.mark.torch -@pytest.mark.torchtemp class TestDDPDriverFunction: """ 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 @@ -176,7 +175,6 @@ class TestDDPDriverFunction: ############################################################################ @pytest.mark.torch -@pytest.mark.torchtemp class TestSetDistReproDataloader: @classmethod diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py index 8992867e..9c3bd8f9 100644 --- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -8,6 +8,9 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch import torch.distributed as dist + from torch import device as torchdevice +else: + from fastNLP.core.utils.dummy_class import DummyClass as torchdevice @pytest.mark.torch def test_incorrect_driver(): @@ -20,7 +23,7 @@ def test_incorrect_driver(): @pytest.mark.torch @pytest.mark.parametrize( "device", - ["cpu", "cuda:0", 0, torch.device("cuda:0")] + ["cpu", "cuda:0", 0, torchdevice("cuda:0")] ) @pytest.mark.parametrize( "driver", @@ -101,7 +104,7 @@ def test_get_ddp_cpu(driver, device): @pytest.mark.torch @pytest.mark.parametrize( "device", - [-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] + [-2, [0, 20, 3], [-2], 20] ) @pytest.mark.parametrize( "driver", diff --git a/tests/core/metrics/test_accuracy_torch.py b/tests/core/metrics/test_accuracy_torch.py index b89d15db..cadf4e0e 100644 --- a/tests/core/metrics/test_accuracy_torch.py +++ b/tests/core/metrics/test_accuracy_torch.py @@ -7,15 +7,20 @@ import copy import socket import pytest import numpy as np -import torch -import torch.distributed -from torch.multiprocessing import Pool, set_start_method + from sklearn.metrics import accuracy_score as sklearn_accuracy from fastNLP.core.dataset import DataSet from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.metrics.metric import Metric from .utils import find_free_network_port, setup_ddp, _assert_allclose +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) @@ -26,7 +31,7 @@ pool = None def _test(local_rank: int, world_size: int, - device: torch.device, + device: "torch.device", dataset: DataSet, metric_class: Type[Metric], metric_kwargs: Dict[str, Any], diff --git a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py index bc006cb1..75203a3e 100644 --- a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py +++ b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py @@ -2,18 +2,23 @@ from functools import partial import copy import pytest -import torch + import numpy as np -from torch.multiprocessing import Pool, set_start_method from fastNLP.core.metrics import ClassifyFPreRecMetric from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) -def _test(local_rank: int, world_size: int, device: torch.device, +def _test(local_rank: int, world_size: int, device: "torch.device", dataset: DataSet, metric_class, metric_kwargs, metric_result): metric = metric_class(**metric_kwargs) # dataset 也类似(每个进程有自己的一个) diff --git a/tests/core/metrics/test_span_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py index 72db05fc..0ebb9bdd 100644 --- a/tests/core/metrics/test_span_f1_rec_acc_torch.py +++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py @@ -5,16 +5,21 @@ import os, sys import copy from functools import partial -import torch -import torch.distributed import numpy as np import socket -from torch.multiprocessing import Pool, set_start_method + # from multiprocessing import Pool, set_start_method from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) @@ -44,7 +49,7 @@ pool = None def _test(local_rank: int, world_size: int, - device: torch.device, + device: "torch.device", dataset: DataSet, metric_class, metric_kwargs, diff --git a/tests/core/metrics/utils.py b/tests/core/metrics/utils.py index 10157438..4126dc97 100644 --- a/tests/core/metrics/utils.py +++ b/tests/core/metrics/utils.py @@ -2,9 +2,11 @@ import os, sys import socket from typing import Union -import torch -from torch import distributed import numpy as np +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch import distributed def setup_ddp(rank: int, world_size: int, master_port: int) -> None: diff --git a/tests/helpers/callbacks/helper_callbacks_torch.py b/tests/helpers/callbacks/helper_callbacks_torch.py index a197bb33..4b9730da 100644 --- a/tests/helpers/callbacks/helper_callbacks_torch.py +++ b/tests/helpers/callbacks/helper_callbacks_torch.py @@ -1,7 +1,9 @@ -import torch from copy import deepcopy from fastNLP.core.callbacks.callback import Callback +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch class RecordAccumulationStepsCallback_Torch(Callback): diff --git a/tests/helpers/datasets/torch_data.py b/tests/helpers/datasets/torch_data.py index 9a0af019..7c9056cd 100644 --- a/tests/helpers/datasets/torch_data.py +++ b/tests/helpers/datasets/torch_data.py @@ -1,7 +1,11 @@ import torch from functools import reduce -from torch.utils.data import Dataset, DataLoader, DistributedSampler -from torch.utils.data.sampler import SequentialSampler, BatchSampler +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import Dataset, DataLoader, DistributedSampler + from torch.utils.data.sampler import SequentialSampler, BatchSampler +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class TorchNormalDataset(Dataset): diff --git a/tests/helpers/models/torch_model.py b/tests/helpers/models/torch_model.py index 236ffda5..afb441ce 100644 --- a/tests/helpers/models/torch_model.py +++ b/tests/helpers/models/torch_model.py @@ -1,9 +1,14 @@ -import torch -import torch.nn as nn +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch.nn import Module + import torch.nn as nn +else: + from fastNLP.core.utils.dummy_class import DummyClass as Module # 1. 最为基础的分类模型 -class TorchNormalModel_Classification_1(nn.Module): +class TorchNormalModel_Classification_1(Module): """ 单独实现 train_step 和 evaluate_step; """ @@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module): return {"preds": x, "target": y} -class TorchNormalModel_Classification_2(nn.Module): +class TorchNormalModel_Classification_2(Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; """ @@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module): return {"loss": loss, "preds": x, "target": y} -class TorchNormalModel_Classification_3(nn.Module): +class TorchNormalModel_Classification_3(Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 关闭 auto_param_call,forward 只有一个 batch 参数;