Browse Source

修正 initialize_torch_driver 的测试

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
175ced3905
1 changed files with 0 additions and 7 deletions
  1. +0
    -7
      tests/core/drivers/torch_driver/test_initialize_torch_driver.py

+ 0
- 7
tests/core/drivers/torch_driver/test_initialize_torch_driver.py View File

@@ -7,7 +7,6 @@ from tests.helpers.utils import magic_argv_env_context
from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
import torch import torch
import torch.distributed as dist
from torch import device as torchdevice from torch import device as torchdevice
else: else:
from fastNLP.core.utils.dummy_class import DummyClass as torchdevice from fastNLP.core.utils.dummy_class import DummyClass as torchdevice
@@ -58,9 +57,6 @@ def test_get_ddp_2(driver, device):
driver = initialize_torch_driver(driver, device, model) driver = initialize_torch_driver(driver, device, model)


assert isinstance(driver, TorchDDPDriver) assert isinstance(driver, TorchDDPDriver)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()




@pytest.mark.torch @pytest.mark.torch
@@ -82,9 +78,6 @@ def test_get_ddp(driver, device):
driver = initialize_torch_driver(driver, device, model) driver = initialize_torch_driver(driver, device, model)


assert isinstance(driver, TorchDDPDriver) assert isinstance(driver, TorchDDPDriver)
dist.barrier()
if dist.is_initialized():
dist.destroy_process_group()




@pytest.mark.torch @pytest.mark.torch


Loading…
Cancel
Save