@@ -0,0 +1,10 @@ | |||||
# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import | |||||
set_env_on_import() | |||||
# 再设置 backend 相关 | |||||
from fastNLP.envs.set_backend import _set_backend | |||||
_set_backend() | |||||
from fastNLP.core import Trainer, Evaluator |
@@ -0,0 +1,22 @@ | |||||
__all__ = [ | |||||
"TorchSingleDriver", | |||||
"TorchDDPDriver", | |||||
"PaddleSingleDriver", | |||||
"PaddleFleetDriver", | |||||
"JittorSingleDriver", | |||||
"JittorMPIDriver", | |||||
"TorchPaddleDriver", | |||||
"paddle_to", | |||||
"get_paddle_gpu_str", | |||||
"get_paddle_device_id", | |||||
"paddle_move_data_to_device", | |||||
"torch_paddle_move_data_to_device", | |||||
] | |||||
# TODO:之后要优化一下这里的导入,应该是每一个 sub module 先import自己内部的类和函数,然后外层的 module 再直接从 submodule 中 import; | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.controllers.evaluator import Evaluator | |||||
from fastNLP.core.dataloaders.torch_dataloader import * | |||||
from .drivers import * | |||||
from .utils import * |
@@ -32,8 +32,8 @@ __all__ = [ | |||||
] | ] | ||||
from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler | from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler | ||||
from fastNLP.core.envs import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH | |||||
from fastNLP.core.envs import is_cur_env_distributed | |||||
from fastNLP.envs.env import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH | |||||
from fastNLP.envs.distributed import is_cur_env_distributed | |||||
ROOT_NAME = 'fastNLP' | ROOT_NAME = 'fastNLP' | ||||
@@ -10,7 +10,7 @@ __all__ = [ | |||||
'all_rank_call' | 'all_rank_call' | ||||
] | ] | ||||
from fastNLP.core.envs import FASTNLP_GLOBAL_RANK | |||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||||
def is_cur_env_distributed() -> bool: | def is_cur_env_distributed() -> bool: |
@@ -3,8 +3,8 @@ import os | |||||
import operator | import operator | ||||
from fastNLP.core.envs.env import FASTNLP_BACKEND | |||||
from fastNLP.core.envs.utils import _module_available, _compare_version | |||||
from fastNLP.envs.env import FASTNLP_BACKEND | |||||
from fastNLP.envs.utils import _module_available, _compare_version | |||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] |
@@ -8,9 +8,9 @@ import sys | |||||
from collections import defaultdict | from collections import defaultdict | ||||
from fastNLP.core.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | |||||
from fastNLP.core.envs import SUPPORT_BACKENDS | |||||
from fastNLP.core.envs.utils import _module_available | |||||
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | |||||
from fastNLP.envs.imports import SUPPORT_BACKENDS | |||||
from fastNLP.envs.utils import _module_available | |||||
def _set_backend(): | def _set_backend(): |
@@ -0,0 +1,17 @@ | |||||
import os | |||||
from fastNLP.envs.set_env import dump_fastnlp_backend | |||||
from tests.helpers.utils import Capturing | |||||
from fastNLP.core import synchronize_safe_rm | |||||
def test_dump_fastnlp_envs(): | |||||
filepath = None | |||||
try: | |||||
with Capturing() as output: | |||||
dump_fastnlp_backend() | |||||
filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') | |||||
assert filepath in output[0] | |||||
assert os.path.exists(filepath) | |||||
finally: | |||||
synchronize_safe_rm(filepath) |