Browse Source

为torch测试例添加_NEED_IMPORT_TORCH

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
f79ee04956
16 changed files with 94 additions and 50 deletions
  1. +2
    -2
      fastNLP/core/utils/dummy_class.py
  2. +11
    -7
      tests/core/callbacks/test_checkpoint_callback_torch.py
  3. +6
    -4
      tests/core/callbacks/test_more_evaluate_callback.py
  4. +7
    -4
      tests/core/controllers/test_trainer_event_trigger.py
  5. +6
    -4
      tests/core/controllers/test_trainer_w_evaluator_torch.py
  6. +6
    -3
      tests/core/controllers/test_trainer_wo_evaluator_torch.py
  7. +2
    -1
      tests/core/drivers/jittor_driver/test_single_device.py
  8. +0
    -2
      tests/core/drivers/torch_driver/test_ddp.py
  9. +5
    -2
      tests/core/drivers/torch_driver/test_initialize_torch_driver.py
  10. +9
    -4
      tests/core/metrics/test_accuracy_torch.py
  11. +8
    -3
      tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py
  12. +9
    -4
      tests/core/metrics/test_span_f1_rec_acc_torch.py
  13. +4
    -2
      tests/core/metrics/utils.py
  14. +3
    -1
      tests/helpers/callbacks/helper_callbacks_torch.py
  15. +6
    -2
      tests/helpers/datasets/torch_data.py
  16. +10
    -5
      tests/helpers/models/torch_model.py

+ 2
- 2
fastNLP/core/utils/dummy_class.py View File

@@ -1,5 +1,5 @@
import functools import functools


class DummyClass: class DummyClass:
def __call__(self, *args, **kwargs):
return
def __init__(self, *args, **kwargs):
pass

+ 11
- 7
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -2,9 +2,6 @@ import os
import pytest import pytest
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
from pathlib import Path from pathlib import Path
import re import re
import time import time
@@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from torchmetrics import Accuracy from torchmetrics import Accuracy
from fastNLP.core.log import logger from fastNLP.core.log import logger


from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist


@dataclass @dataclass
class ArgMaxDatasetConfig: class ArgMaxDatasetConfig:
@@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2(


if version == 0: if version == 0:
callbacks = [ callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc", monitor="acc",
folder=path, folder=path,
every_n_epochs=None, every_n_epochs=None,
@@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2(
topk=None, topk=None,
last=False, last=False,
on_exception=None, on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
) )
] ]
elif version == 1: elif version == 1:
callbacks = [ callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc", monitor="acc",
folder=path, folder=path,
every_n_epochs=None, every_n_epochs=None,
@@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2(
topk=1, topk=1,
last=True, last=True,
on_exception=None, on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
) )
] ]




+ 6
- 4
tests/core/callbacks/test_more_evaluate_callback.py View File

@@ -12,9 +12,7 @@ import os
import pytest import pytest
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist

from pathlib import Path from pathlib import Path
import re import re


@@ -29,7 +27,11 @@ from torchmetrics import Accuracy
from fastNLP.core.metrics import Metric from fastNLP.core.metrics import Metric
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.core.callbacks import MoreEvaluateCallback from fastNLP.core.callbacks import MoreEvaluateCallback

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist


@dataclass @dataclass
class ArgMaxDatasetConfig: class ArgMaxDatasetConfig:


+ 7
- 4
tests/core/controllers/test_trainer_event_trigger.py View File

@@ -1,10 +1,7 @@
import pytest import pytest
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist



from fastNLP.core.controllers.trainer import Trainer from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks.callback_events import Events from fastNLP.core.callbacks.callback_events import Events
@@ -12,6 +9,12 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback
from tests.helpers.utils import magic_argv_env_context, Capturing from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist




@dataclass @dataclass


+ 6
- 4
tests/core/controllers/test_trainer_w_evaluator_torch.py View File

@@ -2,9 +2,7 @@
注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; 注意这一文件中的测试函数都应当是在 `test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成;
""" """
import pytest import pytest
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist

from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from torchmetrics import Accuracy from torchmetrics import Accuracy
@@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context from tests.helpers.utils import magic_argv_env_context

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist


@dataclass @dataclass
class NormalClassificationTrainTorchConfig: class NormalClassificationTrainTorchConfig:


+ 6
- 3
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -2,9 +2,7 @@ import os.path
import subprocess import subprocess
import sys import sys
import pytest import pytest
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader

from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from pathlib import Path from pathlib import Path
@@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
from tests.helpers.utils import magic_argv_env_context, Capturing from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.core import rank_zero_rm from fastNLP.core import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader




@dataclass @dataclass


+ 2
- 1
tests/core/drivers/jittor_driver/test_single_device.py View File

@@ -15,7 +15,7 @@ else:






class Model (Module):
class Model(Module):
def __init__ (self): def __init__ (self):
super (Model, self).__init__() super (Model, self).__init__()
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding self.conv1 = nn.Conv (3, 32, 3, 1) # no padding
@@ -45,6 +45,7 @@ class Model (Module):
return x return x


@pytest.mark.jittor @pytest.mark.jittor
@pytest.mark.skip("Skip jittor tests now.")
class TestSingleDevice: class TestSingleDevice:


def test_on_gpu_without_fp16(self): def test_on_gpu_without_fp16(self):


+ 0
- 2
tests/core/drivers/torch_driver/test_ddp.py View File

@@ -92,7 +92,6 @@ def test_multi_drivers():
dist.destroy_process_group() dist.destroy_process_group()


@pytest.mark.torch @pytest.mark.torch
@pytest.mark.torchtemp
class TestDDPDriverFunction: class TestDDPDriverFunction:
""" """
测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题
@@ -176,7 +175,6 @@ class TestDDPDriverFunction:
############################################################################ ############################################################################


@pytest.mark.torch @pytest.mark.torch
@pytest.mark.torchtemp
class TestSetDistReproDataloader: class TestSetDistReproDataloader:


@classmethod @classmethod


+ 5
- 2
tests/core/drivers/torch_driver/test_initialize_torch_driver.py View File

@@ -8,6 +8,9 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import device as torchdevice
else:
from fastNLP.core.utils.dummy_class import DummyClass as torchdevice


@pytest.mark.torch @pytest.mark.torch
def test_incorrect_driver(): def test_incorrect_driver():
@@ -20,7 +23,7 @@ def test_incorrect_driver():
@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device", "device",
["cpu", "cuda:0", 0, torch.device("cuda:0")]
["cpu", "cuda:0", 0, torchdevice("cuda:0")]
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",
@@ -101,7 +104,7 @@ def test_get_ddp_cpu(driver, device):
@pytest.mark.torch @pytest.mark.torch
@pytest.mark.parametrize( @pytest.mark.parametrize(
"device", "device",
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1]
[-2, [0, 20, 3], [-2], 20]
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",


+ 9
- 4
tests/core/metrics/test_accuracy_torch.py View File

@@ -7,15 +7,20 @@ import copy
import socket import socket
import pytest import pytest
import numpy as np import numpy as np
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method

from sklearn.metrics import accuracy_score as sklearn_accuracy from sklearn.metrics import accuracy_score as sklearn_accuracy


from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.metrics.accuracy import Accuracy
from fastNLP.core.metrics.metric import Metric from fastNLP.core.metrics.metric import Metric
from .utils import find_free_network_port, setup_ddp, _assert_allclose from .utils import find_free_network_port, setup_ddp, _assert_allclose
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method


set_start_method("spawn", force=True) set_start_method("spawn", force=True)


@@ -26,7 +31,7 @@ pool = None


def _test(local_rank: int, def _test(local_rank: int,
world_size: int, world_size: int,
device: torch.device,
device: "torch.device",
dataset: DataSet, dataset: DataSet,
metric_class: Type[Metric], metric_class: Type[Metric],
metric_kwargs: Dict[str, Any], metric_kwargs: Dict[str, Any],


+ 8
- 3
tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py View File

@@ -2,18 +2,23 @@ from functools import partial
import copy import copy


import pytest import pytest
import torch
import numpy as np import numpy as np
from torch.multiprocessing import Pool, set_start_method


from fastNLP.core.metrics import ClassifyFPreRecMetric from fastNLP.core.metrics import ClassifyFPreRecMetric
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp from .utils import find_free_network_port, setup_ddp
if _NEED_IMPORT_TORCH:
import torch
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method


set_start_method("spawn", force=True) set_start_method("spawn", force=True)




def _test(local_rank: int, world_size: int, device: torch.device,
def _test(local_rank: int, world_size: int, device: "torch.device",
dataset: DataSet, metric_class, metric_kwargs, metric_result): dataset: DataSet, metric_class, metric_kwargs, metric_result):
metric = metric_class(**metric_kwargs) metric = metric_class(**metric_kwargs)
# dataset 也类似(每个进程有自己的一个) # dataset 也类似(每个进程有自己的一个)


+ 9
- 4
tests/core/metrics/test_span_f1_rec_acc_torch.py View File

@@ -5,16 +5,21 @@ import os, sys
import copy import copy
from functools import partial from functools import partial


import torch
import torch.distributed
import numpy as np import numpy as np
import socket import socket
from torch.multiprocessing import Pool, set_start_method
# from multiprocessing import Pool, set_start_method # from multiprocessing import Pool, set_start_method
from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp from .utils import find_free_network_port, setup_ddp
if _NEED_IMPORT_TORCH:
import torch
import torch.distributed
from torch.multiprocessing import Pool, set_start_method
else:
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method


set_start_method("spawn", force=True) set_start_method("spawn", force=True)


@@ -44,7 +49,7 @@ pool = None


def _test(local_rank: int, def _test(local_rank: int,
world_size: int, world_size: int,
device: torch.device,
device: "torch.device",
dataset: DataSet, dataset: DataSet,
metric_class, metric_class,
metric_kwargs, metric_kwargs,


+ 4
- 2
tests/core/metrics/utils.py View File

@@ -2,9 +2,11 @@ import os, sys
import socket import socket
from typing import Union from typing import Union


import torch
from torch import distributed
import numpy as np import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed




def setup_ddp(rank: int, world_size: int, master_port: int) -> None: def setup_ddp(rank: int, world_size: int, master_port: int) -> None:


+ 3
- 1
tests/helpers/callbacks/helper_callbacks_torch.py View File

@@ -1,7 +1,9 @@
import torch
from copy import deepcopy from copy import deepcopy


from fastNLP.core.callbacks.callback import Callback from fastNLP.core.callbacks.callback import Callback
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch




class RecordAccumulationStepsCallback_Torch(Callback): class RecordAccumulationStepsCallback_Torch(Callback):


+ 6
- 2
tests/helpers/datasets/torch_data.py View File

@@ -1,7 +1,11 @@
import torch import torch
from functools import reduce from functools import reduce
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.data.sampler import SequentialSampler, BatchSampler
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.data.sampler import SequentialSampler, BatchSampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset




class TorchNormalDataset(Dataset): class TorchNormalDataset(Dataset):


+ 10
- 5
tests/helpers/models/torch_model.py View File

@@ -1,9 +1,14 @@
import torch
import torch.nn as nn
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch.nn import Module
import torch.nn as nn
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module




# 1. 最为基础的分类模型 # 1. 最为基础的分类模型
class TorchNormalModel_Classification_1(nn.Module):
class TorchNormalModel_Classification_1(Module):
""" """
单独实现 train_step 和 evaluate_step; 单独实现 train_step 和 evaluate_step;
""" """
@@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module):
return {"preds": x, "target": y} return {"preds": x, "target": y}




class TorchNormalModel_Classification_2(nn.Module):
class TorchNormalModel_Classification_2(Module):
""" """
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
""" """
@@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module):
return {"loss": loss, "preds": x, "target": y} return {"loss": loss, "preds": x, "target": y}




class TorchNormalModel_Classification_3(nn.Module):
class TorchNormalModel_Classification_3(Module):
""" """
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
关闭 auto_param_call,forward 只有一个 batch 参数; 关闭 auto_param_call,forward 只有一个 batch 参数;


Loading…
Cancel
Save