@@ -14,14 +14,17 @@ import json | |||||
from typing import Union | from typing import Union | ||||
import numpy as np | import numpy as np | ||||
import torch | |||||
import torch.nn as nn | |||||
from .embedding import TokenEmbedding | from .embedding import TokenEmbedding | ||||
from ...core import logger | from ...core import logger | ||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | ||||
from ...io.file_utils import _get_file_name_base_on_postfix | from ...io.file_utils import _get_file_name_base_on_postfix | ||||
from ...envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.nn as nn | |||||
VOCAB_FILENAME = 'vocab.txt' | VOCAB_FILENAME = 'vocab.txt' | ||||
@@ -7,12 +7,18 @@ __all__ = [ | |||||
"LSTM" | "LSTM" | ||||
] | ] | ||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.utils.rnn as rnn | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.utils.rnn as rnn | |||||
from torch.nn import Module | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
class LSTM(nn.Module): | |||||
class LSTM(Module): | |||||
r""" | r""" | ||||
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | ||||
为1; 且可以应对DataParallel中LSTM的使用问题。 | 为1; 且可以应对DataParallel中LSTM的使用问题。 | ||||
@@ -15,7 +15,6 @@ from fastNLP.envs.distributed import rank_zero_rm | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
@@ -23,6 +22,7 @@ if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
from torch.optim import SGD | from torch.optim import SGD | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from torchmetrics import Accuracy | |||||
@dataclass | @dataclass | ||||
class ArgMaxDatasetConfig: | class ArgMaxDatasetConfig: | ||||
@@ -18,14 +18,15 @@ from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
from fastNLP.core.callbacks import MoreEvaluateCallback | from fastNLP.core.callbacks import MoreEvaluateCallback | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
from torch.optim import SGD | from torch.optim import SGD | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from torchmetrics import Accuracy | |||||
@dataclass | @dataclass | ||||
class ArgMaxDatasetConfig: | class ArgMaxDatasetConfig: | ||||
@@ -19,18 +19,20 @@ for folder in list(folders[::-1]): | |||||
path = os.sep.join(folders) | path = os.sep.join(folders) | ||||
sys.path.extend([path, os.path.join(path, 'fastNLP')]) | sys.path.extend([path, os.path.join(path, 'fastNLP')]) | ||||
import torch | |||||
from torch.nn.parallel import DistributedDataParallel | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_2 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.nn.parallel import DistributedDataParallel | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from torchmetrics import Accuracy | |||||
@dataclass | @dataclass | ||||
class NormalClassificationTrainTorchConfig: | class NormalClassificationTrainTorchConfig: | ||||
@@ -19,17 +19,18 @@ for folder in list(folders[::-1]): | |||||
path = os.sep.join(folders) | path = os.sep.join(folders) | ||||
sys.path.extend([path, os.path.join(path, 'fastNLP')]) | sys.path.extend([path, os.path.join(path, 'fastNLP')]) | ||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from torchmetrics import Accuracy | |||||
@dataclass | @dataclass | ||||
class NormalClassificationTrainTorchConfig: | class NormalClassificationTrainTorchConfig: | ||||
@@ -5,7 +5,6 @@ import pytest | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from typing import Any | from typing import Any | ||||
from torchmetrics import Accuracy | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
@@ -18,6 +17,7 @@ if _NEED_IMPORT_TORCH: | |||||
from torch.optim import SGD | from torch.optim import SGD | ||||
from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from torchmetrics import Accuracy | |||||
@dataclass | @dataclass | ||||
@@ -1,16 +1,18 @@ | |||||
import os | import os | ||||
import pytest | import pytest | ||||
import torch | |||||
import torch.distributed as dist | |||||
import numpy as np | import numpy as np | ||||
# print(isinstance((1,), tuple)) | # print(isinstance((1,), tuple)) | ||||
# exit() | # exit() | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | ||||
from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context | ||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed as dist | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -5,10 +5,11 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||||
replace_sampler, | replace_sampler, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | ||||
from torch.utils.data import DataLoader, BatchSampler | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
def test_replace_batch_sampler(): | def test_replace_batch_sampler(): | ||||
@@ -1,13 +1,16 @@ | |||||
from array import array | from array import array | ||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
import pytest | import pytest | ||||
from fastNLP.core.samplers import ReproduceBatchSampler | from fastNLP.core.samplers import ReproduceBatchSampler | ||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestReproducibleBatchSamplerTorch: | class TestReproducibleBatchSamplerTorch: | ||||
@@ -1,9 +1,9 @@ | |||||
import torch | |||||
from functools import reduce | from functools import reduce | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
from torch.utils.data import Dataset, DataLoader, DistributedSampler | |||||
from torch.utils.data.sampler import SequentialSampler, BatchSampler | |||||
import torch | |||||
from torch.utils.data import Dataset | |||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||