
Modify tests that required more than 2 GPUs

tags/v1.0.0alpha
yh 2 years ago
commit 6d2dca421d
7 changed files with 10 additions and 7 deletions
  1. fastNLP/core/callbacks/topk_saver.py (+1, -1)
  2. tests/core/controllers/_test_distributed_launch_torch_1.py (+1, -1)
  3. tests/core/controllers/test_trainer_event_trigger.py (+1, -1)
  4. tests/core/dataloaders/jittor_dataloader/test_fdl.py (+4, -1)
  5. tests/core/drivers/torch_driver/test_initialize_torch_driver.py (+1, -1)
  6. tests/core/metrics/test_accuracy_torch.py (+1, -1)
  7. tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py (+1, -1)

fastNLP/core/callbacks/topk_saver.py (+1, -1)

@@ -36,7 +36,7 @@ class Saver:
                  model_save_fn:Callable=None, **kwargs):
         if folder is None:
             folder = Path.cwd().absolute()
-            logger.info(f"Parameter `folder` is None, and we will use {folder} to save and load your model.")
+            logger.info(f"Parameter `folder` is None, and fastNLP will use {folder} to save and load your model.")
         folder = Path(folder)
         if not folder.exists():
             folder.mkdir(parents=True, exist_ok=True)


tests/core/controllers/_test_distributed_launch_torch_1.py (+1, -1)

@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te
 
 import argparse
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
 
 import sys
 path = os.path.abspath(__file__)

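For context: CUDA_VISIBLE_DEVICES restricts which physical GPUs the process can see, and the visible devices are re-indexed from 0 inside that process, so the launch script now runs on any machine with two GPUs instead of requiring physical indices 1 and 2. A minimal sketch of the effect (illustration only, assuming PyTorch is installed and at least two GPUs are present):

# Illustration only: the variable must be set before CUDA is first initialized,
# which is why the test assigns it right after importing os and before torch.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"   # expose only physical GPUs 0 and 1

import torch

if torch.cuda.is_available():
    print(torch.cuda.device_count())                        # 2 on a machine with two or more GPUs
    print(torch.device("cuda:0"), torch.device("cuda:1"))   # visible GPUs are re-indexed from 0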

tests/core/controllers/test_trainer_event_trigger.py (+1, -1)

@@ -224,7 +224,7 @@ def test_trainer_event_trigger_2(
             assert k in output[0]
 
 
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 0)])
 @pytest.mark.torch
 @magic_argv_env_context
 def test_trainer_event_trigger_3(


tests/core/dataloaders/jittor_dataloader/test_fdl.py (+4, -1)

@@ -1,6 +1,9 @@
 import pytest
 import numpy as np
-from datasets import Dataset as HfDataset
+from fastNLP.envs import _module_available
+
+if _module_available('datasets'):
+    from datasets import Dataset as HfDataset
 
 from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
 from fastNLP.core.dataset import DataSet as Fdataset

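The guarded import keeps this test module importable when the optional datasets package is not installed. _module_available is fastNLP's own helper; a generic standard-library equivalent of the same pattern looks roughly like this (a sketch, not fastNLP's actual implementation):

import importlib.util

def module_available(name: str) -> bool:
    # True if the top-level module can be found, without actually importing it.
    return importlib.util.find_spec(name) is not None

if module_available("datasets"):
    from datasets import Dataset as HfDataset
else:
    HfDataset = None  # tests that rely on it can then be skipped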

tests/core/drivers/torch_driver/test_initialize_torch_driver.py (+1, -1)

@@ -40,7 +40,7 @@ def test_get_single_device(driver, device):
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "device",
-    [[0, 2, 3], -1]
+    [[0, 1], -1]
 )
 @pytest.mark.parametrize(
     "driver",


tests/core/metrics/test_accuracy_torch.py (+1, -1)

@@ -102,7 +102,7 @@ class TestAccuracy:
                     metric_kwargs=metric_kwargs,
                     sklearn_metric=sklearn_accuracy,
                 ),
-                [(rank, processes, torch.device(f'cuda:{rank+4}')) for rank in range(processes)]
+                [(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)]
             )
         else:
             device = torch.device(

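Using cuda:{rank} ties each worker to the GPU matching its rank, so the multi-process metric test needs only as many GPUs as it spawns workers, whereas rank+4 assumed GPUs at indices 4 and above. A small illustration of the argument tuples the test builds, with a hypothetical worker count of 2 (constructing torch.device objects does not itself require CUDA):

import torch

processes = 2  # hypothetical worker count, for illustration only
args = [(rank, processes, torch.device(f"cuda:{rank}")) for rank in range(processes)]
print(args)
# [(0, 2, device(type='cuda', index=0)), (1, 2, device(type='cuda', index=1))]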

tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py (+1, -1)

@@ -177,6 +177,6 @@ class TestClassfiyFPreRecMetric:
                                          metric_class=ClassifyFPreRecMetric,
                                          metric_kwargs=metric_kwargs,
                                          metric_result=ground_truth),
-                [(rank, NUM_PROCESSES, torch.device(f'cuda:{rank+4}')) for rank in range(NUM_PROCESSES)])
+                [(rank, NUM_PROCESSES, torch.device(f'cuda:{rank}')) for rank in range(NUM_PROCESSES)])
             pool.close()
             pool.join()
