Browse Source

移动env的位置为顶级目录

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
40c0a712dd
12 changed files with 57 additions and 8 deletions
  1. +10
    -0
      fastNLP/__init__.py
  2. +22
    -0
      fastNLP/core/__init__.py
  3. +2
    -2
      fastNLP/core/log/logger.py
  4. +0
    -0
      fastNLP/envs/__init__.py
  5. +1
    -1
      fastNLP/envs/distributed.py
  6. +0
    -0
      fastNLP/envs/env.py
  7. +2
    -2
      fastNLP/envs/imports.py
  8. +3
    -3
      fastNLP/envs/set_backend.py
  9. +0
    -0
      fastNLP/envs/set_env_on_import.py
  10. +0
    -0
      fastNLP/envs/utils.py
  11. +0
    -0
      tests/envs/__init__.py
  12. +17
    -0
      tests/envs/test_set_backend.py

+ 10
- 0
fastNLP/__init__.py View File

@@ -0,0 +1,10 @@

# 首先保证 FASTNLP_GLOBAL_RANK 正确设置
from fastNLP.envs.set_env_on_import import set_env_on_import
set_env_on_import()

# 再设置 backend 相关
from fastNLP.envs.set_backend import _set_backend
_set_backend()

from fastNLP.core import Trainer, Evaluator

+ 22
- 0
fastNLP/core/__init__.py View File

@@ -0,0 +1,22 @@
__all__ = [
"TorchSingleDriver",
"TorchDDPDriver",
"PaddleSingleDriver",
"PaddleFleetDriver",
"JittorSingleDriver",
"JittorMPIDriver",
"TorchPaddleDriver",

"paddle_to",
"get_paddle_gpu_str",
"get_paddle_device_id",
"paddle_move_data_to_device",
"torch_paddle_move_data_to_device",
]
# TODO:之后要优化一下这里的导入,应该是每一个 sub module 先import自己内部的类和函数,然后外层的 module 再直接从 submodule 中 import;
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.controllers.evaluator import Evaluator
from fastNLP.core.dataloaders.torch_dataloader import *

from .drivers import *
from .utils import *

+ 2
- 2
fastNLP/core/log/logger.py View File

@@ -32,8 +32,8 @@ __all__ = [
] ]


from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler
from fastNLP.core.envs import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH
from fastNLP.core.envs import is_cur_env_distributed
from fastNLP.envs.env import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH
from fastNLP.envs.distributed import is_cur_env_distributed




ROOT_NAME = 'fastNLP' ROOT_NAME = 'fastNLP'


fastNLP/core/envs/__init__.py → fastNLP/envs/__init__.py View File


fastNLP/core/envs/distributed.py → fastNLP/envs/distributed.py View File

@@ -10,7 +10,7 @@ __all__ = [
'all_rank_call' 'all_rank_call'
] ]


from fastNLP.core.envs import FASTNLP_GLOBAL_RANK
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK




def is_cur_env_distributed() -> bool: def is_cur_env_distributed() -> bool:

fastNLP/core/envs/env.py → fastNLP/envs/env.py View File


fastNLP/core/envs/imports.py → fastNLP/envs/imports.py View File

@@ -3,8 +3,8 @@ import os
import operator import operator




from fastNLP.core.envs.env import FASTNLP_BACKEND
from fastNLP.core.envs.utils import _module_available, _compare_version
from fastNLP.envs.env import FASTNLP_BACKEND
from fastNLP.envs.utils import _module_available, _compare_version




SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor']

fastNLP/core/envs/set_backend.py → fastNLP/envs/set_backend.py View File

@@ -8,9 +8,9 @@ import sys
from collections import defaultdict from collections import defaultdict




from fastNLP.core.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED
from fastNLP.core.envs import SUPPORT_BACKENDS
from fastNLP.core.envs.utils import _module_available
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED
from fastNLP.envs.imports import SUPPORT_BACKENDS
from fastNLP.envs.utils import _module_available




def _set_backend(): def _set_backend():

fastNLP/core/envs/set_env_on_import.py → fastNLP/envs/set_env_on_import.py View File


fastNLP/core/envs/utils.py → fastNLP/envs/utils.py View File


+ 0
- 0
tests/envs/__init__.py View File


+ 17
- 0
tests/envs/test_set_backend.py View File

@@ -0,0 +1,17 @@
import os

from fastNLP.envs.set_env import dump_fastnlp_backend
from tests.helpers.utils import Capturing
from fastNLP.core import synchronize_safe_rm


def test_dump_fastnlp_envs():
filepath = None
try:
with Capturing() as output:
dump_fastnlp_backend()
filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json')
assert filepath in output[0]
assert os.path.exists(filepath)
finally:
synchronize_safe_rm(filepath)

Loading…
Cancel
Save