From 40c0a712dd653d6bae79aa8eb39d431cd1fca9cc Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 8 Apr 2022 22:29:22 +0800 Subject: [PATCH] =?UTF-8?q?=E7=A7=BB=E5=8A=A8env=E7=9A=84=E4=BD=8D?= =?UTF-8?q?=E7=BD=AE=E4=B8=BA=E9=A1=B6=E7=BA=A7=E7=9B=AE=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/__init__.py | 10 +++++++++ fastNLP/core/__init__.py | 22 ++++++++++++++++++++ fastNLP/core/log/logger.py | 4 ++-- fastNLP/{core => }/envs/__init__.py | 0 fastNLP/{core => }/envs/distributed.py | 2 +- fastNLP/{core => }/envs/env.py | 0 fastNLP/{core => }/envs/imports.py | 4 ++-- fastNLP/{core => }/envs/set_backend.py | 6 +++--- fastNLP/{core => }/envs/set_env_on_import.py | 0 fastNLP/{core => }/envs/utils.py | 0 tests/envs/__init__.py | 0 tests/envs/test_set_backend.py | 17 +++++++++++++++ 12 files changed, 57 insertions(+), 8 deletions(-) rename fastNLP/{core => }/envs/__init__.py (100%) rename fastNLP/{core => }/envs/distributed.py (97%) rename fastNLP/{core => }/envs/env.py (100%) rename fastNLP/{core => }/envs/imports.py (86%) rename fastNLP/{core => }/envs/set_backend.py (97%) rename fastNLP/{core => }/envs/set_env_on_import.py (100%) rename fastNLP/{core => }/envs/utils.py (100%) create mode 100644 tests/envs/__init__.py create mode 100644 tests/envs/test_set_backend.py diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index e69de29b..6007ed07 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -0,0 +1,10 @@ + +# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 +from fastNLP.envs.set_env_on_import import set_env_on_import +set_env_on_import() + +# 再设置 backend 相关 +from fastNLP.envs.set_backend import _set_backend +_set_backend() + +from fastNLP.core import Trainer, Evaluator \ No newline at end of file diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index e69de29b..5cc765b9 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -0,0 +1,22 @@ +__all__ = [ + "TorchSingleDriver", + "TorchDDPDriver", + "PaddleSingleDriver", + "PaddleFleetDriver", + "JittorSingleDriver", + "JittorMPIDriver", + "TorchPaddleDriver", + + "paddle_to", + "get_paddle_gpu_str", + "get_paddle_device_id", + "paddle_move_data_to_device", + "torch_paddle_move_data_to_device", +] +# TODO:之后要优化一下这里的导入,应该是每一个 sub module 先import自己内部的类和函数,然后外层的 module 再直接从 submodule 中 import; +from fastNLP.core.controllers.trainer import Trainer +from fastNLP.core.controllers.evaluator import Evaluator +from fastNLP.core.dataloaders.torch_dataloader import * + +from .drivers import * +from .utils import * \ No newline at end of file diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index f9e9bdae..ae89ad3f 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -32,8 +32,8 @@ __all__ = [ ] from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler -from fastNLP.core.envs import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH -from fastNLP.core.envs import is_cur_env_distributed +from fastNLP.envs.env import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH +from fastNLP.envs.distributed import is_cur_env_distributed ROOT_NAME = 'fastNLP' diff --git a/fastNLP/core/envs/__init__.py b/fastNLP/envs/__init__.py similarity index 100% rename from fastNLP/core/envs/__init__.py rename to fastNLP/envs/__init__.py diff --git a/fastNLP/core/envs/distributed.py b/fastNLP/envs/distributed.py similarity index 97% rename from fastNLP/core/envs/distributed.py rename to fastNLP/envs/distributed.py index dabc9f0a..f608272b 100644 --- a/fastNLP/core/envs/distributed.py +++ b/fastNLP/envs/distributed.py @@ -10,7 +10,7 @@ __all__ = [ 'all_rank_call' ] -from fastNLP.core.envs import FASTNLP_GLOBAL_RANK +from fastNLP.envs.env import FASTNLP_GLOBAL_RANK def is_cur_env_distributed() -> bool: diff --git a/fastNLP/core/envs/env.py b/fastNLP/envs/env.py similarity index 100% rename from fastNLP/core/envs/env.py rename to fastNLP/envs/env.py diff --git a/fastNLP/core/envs/imports.py b/fastNLP/envs/imports.py similarity index 86% rename from fastNLP/core/envs/imports.py rename to fastNLP/envs/imports.py index 3fed96f3..3d49afe4 100644 --- a/fastNLP/core/envs/imports.py +++ b/fastNLP/envs/imports.py @@ -3,8 +3,8 @@ import os import operator -from fastNLP.core.envs.env import FASTNLP_BACKEND -from fastNLP.core.envs.utils import _module_available, _compare_version +from fastNLP.envs.env import FASTNLP_BACKEND +from fastNLP.envs.utils import _module_available, _compare_version SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] diff --git a/fastNLP/core/envs/set_backend.py b/fastNLP/envs/set_backend.py similarity index 97% rename from fastNLP/core/envs/set_backend.py rename to fastNLP/envs/set_backend.py index ac8dc33b..a1ac5efb 100644 --- a/fastNLP/core/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -8,9 +8,9 @@ import sys from collections import defaultdict -from fastNLP.core.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED -from fastNLP.core.envs import SUPPORT_BACKENDS -from fastNLP.core.envs.utils import _module_available +from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED +from fastNLP.envs.imports import SUPPORT_BACKENDS +from fastNLP.envs.utils import _module_available def _set_backend(): diff --git a/fastNLP/core/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py similarity index 100% rename from fastNLP/core/envs/set_env_on_import.py rename to fastNLP/envs/set_env_on_import.py diff --git a/fastNLP/core/envs/utils.py b/fastNLP/envs/utils.py similarity index 100% rename from fastNLP/core/envs/utils.py rename to fastNLP/envs/utils.py diff --git a/tests/envs/__init__.py b/tests/envs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py new file mode 100644 index 00000000..2c8fbadf --- /dev/null +++ b/tests/envs/test_set_backend.py @@ -0,0 +1,17 @@ +import os + +from fastNLP.envs.set_env import dump_fastnlp_backend +from tests.helpers.utils import Capturing +from fastNLP.core import synchronize_safe_rm + + +def test_dump_fastnlp_envs(): + filepath = None + try: + with Capturing() as output: + dump_fastnlp_backend() + filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') + assert filepath in output[0] + assert os.path.exists(filepath) + finally: + synchronize_safe_rm(filepath)