|
|
@@ -7,7 +7,6 @@ from tests.helpers.utils import magic_argv_env_context |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_TORCH |
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
import torch |
|
|
|
import torch.distributed as dist |
|
|
|
from torch import device as torchdevice |
|
|
|
else: |
|
|
|
from fastNLP.core.utils.dummy_class import DummyClass as torchdevice |
|
|
@@ -58,9 +57,6 @@ def test_get_ddp_2(driver, device): |
|
|
|
driver = initialize_torch_driver(driver, device, model) |
|
|
|
|
|
|
|
assert isinstance(driver, TorchDDPDriver) |
|
|
|
dist.barrier() |
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
@@ -82,9 +78,6 @@ def test_get_ddp(driver, device): |
|
|
|
driver = initialize_torch_driver(driver, device, model) |
|
|
|
|
|
|
|
assert isinstance(driver, TorchDDPDriver) |
|
|
|
dist.barrier() |
|
|
|
if dist.is_initialized(): |
|
|
|
dist.destroy_process_group() |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|