diff --git a/fastNLP/embeddings/torch/static_embedding.py b/fastNLP/embeddings/torch/static_embedding.py
index 8b555c6d..add5dcfc 100644
--- a/fastNLP/embeddings/torch/static_embedding.py
+++ b/fastNLP/embeddings/torch/static_embedding.py
@@ -14,14 +14,17 @@ import json
 from typing import Union
 
 import numpy as np
-import torch
-import torch.nn as nn
 
 from .embedding import TokenEmbedding
 from ...core import logger
 from ...core.vocabulary import Vocabulary
 from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
 from ...io.file_utils import _get_file_name_base_on_postfix
+from ...envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.nn as nn
 
 VOCAB_FILENAME = 'vocab.txt'
 
diff --git a/fastNLP/modules/torch/encoder/lstm.py b/fastNLP/modules/torch/encoder/lstm.py
index bd0d844d..d5bccb82 100644
--- a/fastNLP/modules/torch/encoder/lstm.py
+++ b/fastNLP/modules/torch/encoder/lstm.py
@@ -7,12 +7,18 @@ __all__ = [
     "LSTM"
 ]
 
-import torch
-import torch.nn as nn
-import torch.nn.utils.rnn as rnn
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.nn as nn
+    import torch.nn.utils.rnn as rnn
+    from torch.nn import Module
+else:
+    from fastNLP.core.utils.dummy_class import DummyClass as Module
 
 
-class LSTM(nn.Module):
+
+class LSTM(Module):
     r"""
     LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化
     为1; 且可以应对DataParallel中LSTM的使用问题。
diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py
index e148a1af..1fc2f9ee 100644
--- a/tests/core/callbacks/test_checkpoint_callback_torch.py
+++ b/tests/core/callbacks/test_checkpoint_callback_torch.py
@@ -15,7 +15,6 @@ from fastNLP.envs.distributed import rank_zero_rm
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
 from tests.helpers.utils import Capturing
-from torchmetrics import Accuracy
 from fastNLP.core.log import logger
 
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
@@ -23,6 +22,7 @@ if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader
     from torch.optim import SGD
     import torch.distributed as dist
+    from torchmetrics import Accuracy
 
 @dataclass
 class ArgMaxDatasetConfig:
diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py
index 3de0c422..43cbf231 100644
--- a/tests/core/callbacks/test_more_evaluate_callback.py
+++ b/tests/core/callbacks/test_more_evaluate_callback.py
@@ -18,14 +18,15 @@ from tests.helpers.utils import magic_argv_env_context
 from fastNLP.envs.distributed import rank_zero_rm
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
-from torchmetrics import Accuracy
 from fastNLP.core.metrics import Metric
 from fastNLP.core.callbacks import MoreEvaluateCallback
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
 if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader
     from torch.optim import SGD
     import torch.distributed as dist
+    from torchmetrics import Accuracy
 
 @dataclass
 class ArgMaxDatasetConfig:
diff --git a/tests/core/controllers/_test_distributed_launch_torch_1.py b/tests/core/controllers/_test_distributed_launch_torch_1.py
index 5e543770..4692b008 100644
--- a/tests/core/controllers/_test_distributed_launch_torch_1.py
+++ b/tests/core/controllers/_test_distributed_launch_torch_1.py
@@ -19,18 +19,20 @@ for folder in list(folders[::-1]):
 path = os.sep.join(folders)
 sys.path.extend([path, os.path.join(path, 'fastNLP')])
 
-import torch
-from torch.nn.parallel import DistributedDataParallel
-from torch.utils.data import DataLoader
-from torch.optim import SGD
-import torch.distributed as dist
 from dataclasses import dataclass
-from torchmetrics import Accuracy
 
 from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_2
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
 
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.nn.parallel import DistributedDataParallel
+    from torch.utils.data import DataLoader
+    from torch.optim import SGD
+    import torch.distributed as dist
+    from torchmetrics import Accuracy
 
 @dataclass
 class NormalClassificationTrainTorchConfig:
diff --git a/tests/core/controllers/_test_distributed_launch_torch_2.py b/tests/core/controllers/_test_distributed_launch_torch_2.py
index ac753adc..a800ce2b 100644
--- a/tests/core/controllers/_test_distributed_launch_torch_2.py
+++ b/tests/core/controllers/_test_distributed_launch_torch_2.py
@@ -19,17 +19,18 @@ for folder in list(folders[::-1]):
 path = os.sep.join(folders)
 sys.path.extend([path, os.path.join(path, 'fastNLP')])
 
-
-from torch.utils.data import DataLoader
-from torch.optim import SGD
-import torch.distributed as dist
 from dataclasses import dataclass
-from torchmetrics import Accuracy
 
 from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 
+if _NEED_IMPORT_TORCH:
+    from torch.utils.data import DataLoader
+    from torch.optim import SGD
+    import torch.distributed as dist
+    from torchmetrics import Accuracy
 
 @dataclass
 class NormalClassificationTrainTorchConfig:
diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py
index c0348caa..2d525260 100644
--- a/tests/core/controllers/test_trainer_w_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py
@@ -5,7 +5,6 @@ import pytest
 from dataclasses import dataclass
 from typing import Any
 
-from torchmetrics import Accuracy
 
 from fastNLP.core.controllers.trainer import Trainer
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
@@ -18,6 +17,7 @@ if _NEED_IMPORT_TORCH:
     from torch.optim import SGD
     from torch.utils.data import DataLoader
     import torch.distributed as dist
+    from torchmetrics import Accuracy
 
 
 @dataclass
diff --git a/tests/core/drivers/torch_driver/test_dist_utils.py b/tests/core/drivers/torch_driver/test_dist_utils.py
index a118e562..7248fe75 100644
--- a/tests/core/drivers/torch_driver/test_dist_utils.py
+++ b/tests/core/drivers/torch_driver/test_dist_utils.py
@@ -1,16 +1,18 @@
 import os
 import pytest
 
-import torch
-import torch.distributed as dist
 import numpy as np
 
 # print(isinstance((1,), tuple))
 # exit()
 
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
 from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context
 
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed as dist
 
 @pytest.mark.torch
 @magic_argv_env_context
diff --git a/tests/core/drivers/torch_driver/test_utils.py b/tests/core/drivers/torch_driver/test_utils.py
index 8d5d3267..2bc2887a 100644
--- a/tests/core/drivers/torch_driver/test_utils.py
+++ b/tests/core/drivers/torch_driver/test_utils.py
@@ -5,10 +5,11 @@ from fastNLP.core.drivers.torch_driver.utils import (
     replace_sampler,
 )
 from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
-from torch.utils.data import DataLoader, BatchSampler
-
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from tests.helpers.datasets.torch_data import TorchNormalDataset
 
+if _NEED_IMPORT_TORCH:
+    from torch.utils.data import DataLoader, BatchSampler
 
 @pytest.mark.torch
 def test_replace_batch_sampler():
diff --git a/tests/core/samplers/test_reproducible_batch_sampler_torch.py b/tests/core/samplers/test_reproducible_batch_sampler_torch.py
index af180f56..11a63ef9 100644
--- a/tests/core/samplers/test_reproducible_batch_sampler_torch.py
+++ b/tests/core/samplers/test_reproducible_batch_sampler_torch.py
@@ -1,13 +1,16 @@
 from array import array
-import torch
-from torch.utils.data import DataLoader
 import pytest
 
 from fastNLP.core.samplers import ReproduceBatchSampler
 from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 
 from tests.helpers.datasets.torch_data import TorchNormalDataset
 
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.utils.data import DataLoader
+
 
 @pytest.mark.torch
 class TestReproducibleBatchSamplerTorch:
diff --git a/tests/helpers/datasets/torch_data.py b/tests/helpers/datasets/torch_data.py
index 7c9056cd..1244a2f6 100644
--- a/tests/helpers/datasets/torch_data.py
+++ b/tests/helpers/datasets/torch_data.py
@@ -1,9 +1,9 @@
-import torch
 from functools import reduce
 
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
 if _NEED_IMPORT_TORCH:
-    from torch.utils.data import Dataset, DataLoader, DistributedSampler
-    from torch.utils.data.sampler import SequentialSampler, BatchSampler
+    import torch
+    from torch.utils.data import Dataset
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Dataset
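
Every hunk above applies the same guarded-import pattern, so a minimal standalone sketch may help when reviewing (this is not part of the patch; the module name new_module.py and the class MyEncoder are invented for illustration, and it assumes _NEED_IMPORT_TORCH and DummyClass behave as they do in fastNLP):

# new_module.py -- illustrative sketch only, not a file in this diff
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    # torch is expected to be importable: bring in the real symbols.
    import torch
    from torch.nn import Module
else:
    # torch is absent: substitute a no-op placeholder so this module can
    # still be imported; using MyEncoder then requires installing torch.
    from fastNLP.core.utils.dummy_class import DummyClass as Module


class MyEncoder(Module):
    """Example class whose base resolves to torch.nn.Module only when torch exists."""

    def __init__(self):
        super().__init__()

The point of the pattern is that plain `import torch` at module top level would make the whole package unimportable without torch, whereas the guard defers the hard dependency to the code paths that actually need it.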