
Add the _NEED_IMPORT_TORCH guard in the remaining places that need it

tags/v1.0.0alpha
x54-729 committed 2 years ago
commit c977d3be02
11 changed files with 48 additions and 29 deletions
  1. fastNLP/embeddings/torch/static_embedding.py (+5, -2)
  2. fastNLP/modules/torch/encoder/lstm.py (+10, -4)
  3. tests/core/callbacks/test_checkpoint_callback_torch.py (+1, -1)
  4. tests/core/callbacks/test_more_evaluate_callback.py (+2, -1)
  5. tests/core/controllers/_test_distributed_launch_torch_1.py (+8, -6)
  6. tests/core/controllers/_test_distributed_launch_torch_2.py (+6, -5)
  7. tests/core/controllers/test_trainer_w_evaluator_torch.py (+1, -1)
  8. tests/core/drivers/torch_driver/test_dist_utils.py (+4, -2)
  9. tests/core/drivers/torch_driver/test_utils.py (+3, -2)
  10. tests/core/samplers/test_reproducible_batch_sampler_torch.py (+5, -2)
  11. tests/helpers/datasets/torch_data.py (+3, -3)
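
Every file in the list applies the same recipe: import torch, and torch-dependent extras such as torchmetrics, only when _NEED_IMPORT_TORCH is true, so the modules stay importable on machines without torch. Below is a minimal sketch of the pattern; the flag is approximated with find_spec here, while the real one is defined in fastNLP/envs/imports.py and may be computed differently (for example, taking the configured backend into account).

    # Sketch of the guarded-import pattern used throughout this commit.
    # Assumption: _NEED_IMPORT_TORCH is approximated with find_spec; fastNLP's
    # real flag may be computed differently.
    import importlib.util

    _NEED_IMPORT_TORCH = importlib.util.find_spec("torch") is not None

    if _NEED_IMPORT_TORCH:
        import torch
        import torch.nn as nn
        from torchmetrics import Accuracy  # third-party torch extras move under the guard too

A module written this way imports cleanly whether or not torch is installed; only code paths that actually touch torch objects require it.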

fastNLP/embeddings/torch/static_embedding.py (+5, -2)

@@ -14,14 +14,17 @@ import json
 from typing import Union

 import numpy as np
-import torch
-import torch.nn as nn

 from .embedding import TokenEmbedding
 from ...core import logger
 from ...core.vocabulary import Vocabulary
 from ...io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
 from ...io.file_utils import _get_file_name_base_on_postfix
+from ...envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.nn as nn


 VOCAB_FILENAME = 'vocab.txt'


fastNLP/modules/torch/encoder/lstm.py (+10, -4)

@@ -7,12 +7,18 @@ __all__ = [
     "LSTM"
 ]

-import torch
-import torch.nn as nn
-import torch.nn.utils.rnn as rnn
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.nn as nn
+    import torch.nn.utils.rnn as rnn
+    from torch.nn import Module
+else:
+    from fastNLP.core.utils.dummy_class import DummyClass as Module


-class LSTM(nn.Module):
+class LSTM(Module):
     r"""
     LSTM module: a lightweight wrapper around PyTorch's LSTM. When seq_len is provided, pack_padded_sequence is used automatically;
     the forget gate bias is initialized to 1 by default; and it works around the issues of using LSTM inside DataParallel.
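
The lstm.py hunk shows the recipe for classes whose base type lives in torch: resolve the base class at import time and fall back to a dummy when torch is absent. A self-contained sketch of that recipe follows; the DummyClass body below is an assumption about what fastNLP.core.utils.dummy_class provides (an empty class accepting any constructor arguments), and the flag is again approximated with find_spec.

    # Sketch of the conditional-base-class recipe from lstm.py. DummyClass here
    # is an assumed stand-in for fastNLP.core.utils.dummy_class.DummyClass.
    import importlib.util

    _NEED_IMPORT_TORCH = importlib.util.find_spec("torch") is not None

    if _NEED_IMPORT_TORCH:
        import torch.nn as nn
        from torch.nn import Module
    else:
        class DummyClass:
            """Empty placeholder base so the module still imports without torch."""
            def __init__(self, *args, **kwargs):
                pass

        Module = DummyClass


    class LSTM(Module):
        """Importable with or without torch; only instantiable when torch is present."""

        def __init__(self, input_size: int = 128, hidden_size: int = 256):
            super().__init__()
            # nn only exists when torch was imported above, so constructing the
            # class without torch fails here rather than at import time.
            self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

Anything that merely imports the module, such as documentation builds or test collection on a torch-free environment, never hits the torch-only branch.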


tests/core/callbacks/test_checkpoint_callback_torch.py (+1, -1)

@@ -15,7 +15,6 @@ from fastNLP.envs.distributed import rank_zero_rm
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
 from tests.helpers.utils import Capturing
-from torchmetrics import Accuracy
 from fastNLP.core.log import logger

 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
@@ -23,6 +22,7 @@ if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader
     from torch.optim import SGD
     import torch.distributed as dist
+    from torchmetrics import Accuracy

 @dataclass
 class ArgMaxDatasetConfig:


tests/core/callbacks/test_more_evaluate_callback.py (+2, -1)

@@ -18,14 +18,15 @@ from tests.helpers.utils import magic_argv_env_context
 from fastNLP.envs.distributed import rank_zero_rm
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
-from torchmetrics import Accuracy
 from fastNLP.core.metrics import Metric
 from fastNLP.core.callbacks import MoreEvaluateCallback
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH

 if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader
     from torch.optim import SGD
     import torch.distributed as dist
+    from torchmetrics import Accuracy

 @dataclass
 class ArgMaxDatasetConfig:


tests/core/controllers/_test_distributed_launch_torch_1.py (+8, -6)

@@ -19,18 +19,20 @@ for folder in list(folders[::-1]):
 path = os.sep.join(folders)
 sys.path.extend([path, os.path.join(path, 'fastNLP')])

-import torch
-from torch.nn.parallel import DistributedDataParallel
-from torch.utils.data import DataLoader
-from torch.optim import SGD
-import torch.distributed as dist
 from dataclasses import dataclass
-from torchmetrics import Accuracy

 from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_2
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.nn.parallel import DistributedDataParallel
+    from torch.utils.data import DataLoader
+    from torch.optim import SGD
+    import torch.distributed as dist
+    from torchmetrics import Accuracy

 @dataclass
 class NormalClassificationTrainTorchConfig:


tests/core/controllers/_test_distributed_launch_torch_2.py (+6, -5)

@@ -19,17 +19,18 @@ for folder in list(folders[::-1]):
 path = os.sep.join(folders)
 sys.path.extend([path, os.path.join(path, 'fastNLP')])

-
-from torch.utils.data import DataLoader
-from torch.optim import SGD
-import torch.distributed as dist
 from dataclasses import dataclass
-from torchmetrics import Accuracy

 from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

+if _NEED_IMPORT_TORCH:
+    from torch.utils.data import DataLoader
+    from torch.optim import SGD
+    import torch.distributed as dist
+    from torchmetrics import Accuracy

 @dataclass
 class NormalClassificationTrainTorchConfig:


tests/core/controllers/test_trainer_w_evaluator_torch.py (+1, -1)

@@ -5,7 +5,6 @@ import pytest

 from dataclasses import dataclass
 from typing import Any
-from torchmetrics import Accuracy

 from fastNLP.core.controllers.trainer import Trainer
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
@@ -18,6 +17,7 @@ if _NEED_IMPORT_TORCH:
     from torch.optim import SGD
     from torch.utils.data import DataLoader
     import torch.distributed as dist
+    from torchmetrics import Accuracy


 @dataclass


tests/core/drivers/torch_driver/test_dist_utils.py (+4, -2)

@@ -1,16 +1,18 @@
 import os
 import pytest

-import torch
-import torch.distributed as dist
 import numpy as np

 # print(isinstance((1,), tuple))
 # exit()

+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
 from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context

+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed as dist

 @pytest.mark.torch
 @magic_argv_env_context
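
The test hunks pair the guard with the pytest marker torch that is already attached to these tests, so collection succeeds even without torch installed. One hedged way such a marker could be wired up is a conftest hook like the sketch below; fastNLP's actual pytest configuration may handle it differently.

    # conftest.py sketch: skip tests marked "torch" when torch is not installed.
    # This is an assumption about how the marker could be wired, not fastNLP's
    # actual configuration.
    import importlib.util

    import pytest


    def pytest_configure(config):
        config.addinivalue_line("markers", "torch: test requires a working torch install")


    def pytest_collection_modifyitems(config, items):
        if importlib.util.find_spec("torch") is not None:
            return  # torch available: run everything as usual
        skip_torch = pytest.mark.skip(reason="torch not installed")
        for item in items:
            if "torch" in item.keywords:
                item.add_marker(skip_torch)

With something like this in place, a run on a torch-free machine reports the torch-marked tests as skipped instead of failing at import time.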


tests/core/drivers/torch_driver/test_utils.py (+3, -2)

@@ -5,10 +5,11 @@ from fastNLP.core.drivers.torch_driver.utils import (
     replace_sampler,
 )
 from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
-from torch.utils.data import DataLoader, BatchSampler

+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from tests.helpers.datasets.torch_data import TorchNormalDataset

+if _NEED_IMPORT_TORCH:
+    from torch.utils.data import DataLoader, BatchSampler

 @pytest.mark.torch
 def test_replace_batch_sampler():


tests/core/samplers/test_reproducible_batch_sampler_torch.py (+5, -2)

@@ -1,13 +1,16 @@
 from array import array
-import torch
-from torch.utils.data import DataLoader

 import pytest

 from fastNLP.core.samplers import ReproduceBatchSampler
 from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from tests.helpers.datasets.torch_data import TorchNormalDataset
+
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.utils.data import DataLoader


 @pytest.mark.torch
 class TestReproducibleBatchSamplerTorch:


tests/helpers/datasets/torch_data.py (+3, -3)

@@ -1,9 +1,9 @@
-import torch
 from functools import reduce
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH

 if _NEED_IMPORT_TORCH:
-    from torch.utils.data import Dataset, DataLoader, DistributedSampler
-    from torch.utils.data.sampler import SequentialSampler, BatchSampler
+    import torch
+    from torch.utils.data import Dataset
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Dataset
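
Because torch_data.py now resolves Dataset conditionally, the helpers defined in it stay importable everywhere. A hypothetical helper in the same style, just to illustrate the shape (the real TorchNormalDataset and TorchArgMaxDataset referenced by the tests above may look different):

    # Hypothetical test-data helper following the torch_data.py pattern. The
    # names RandomFeatureDataset, num_samples and num_features are illustrative.
    from fastNLP.envs.imports import _NEED_IMPORT_TORCH

    if _NEED_IMPORT_TORCH:
        import torch
        from torch.utils.data import Dataset
    else:
        from fastNLP.core.utils.dummy_class import DummyClass as Dataset


    class RandomFeatureDataset(Dataset):
        """Random (feature, label) pairs; defined everywhere, usable only with torch."""

        def __init__(self, num_samples: int = 100, num_features: int = 3):
            self.x = torch.randn(num_samples, num_features)
            self.y = torch.randint(0, 2, (num_samples,))

        def __len__(self):
            return len(self.y)

        def __getitem__(self, index):
            return self.x[index], self.y[index]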


