@@ -0,0 +1,18 @@ | |||||
r""" | |||||
""" | |||||
__all__ = [ | |||||
'dump_fastnlp_backend', | |||||
'is_cur_env_distributed', | |||||
'get_global_rank', | |||||
'rank_zero_call', | |||||
'all_rank_call' | |||||
] | |||||
from .env import * | |||||
from .set_env_on_import import set_env_on_import | |||||
from .set_backend import dump_fastnlp_backend | |||||
from .imports import * | |||||
from .utils import _module_available | |||||
from .distributed import * |
@@ -0,0 +1,75 @@ | |||||
import os | |||||
from functools import wraps | |||||
from typing import Callable, Any, Optional | |||||
from contextlib import contextmanager | |||||
__all__ = [ | |||||
'is_cur_env_distributed', | |||||
'get_global_rank', | |||||
'rank_zero_call', | |||||
'all_rank_call' | |||||
] | |||||
from fastNLP.core.envs import FASTNLP_GLOBAL_RANK | |||||
def is_cur_env_distributed() -> bool: | |||||
""" | |||||
单卡模式该函数一定返回 False; | |||||
注意进程 0 在多卡的训练模式下前后的值是不一样的,例如在开启多卡的 driver 之前,在进程 0 上的该函数返回 False;但是在开启后,在进程 0 上 | |||||
的该函数返回的值是 True; | |||||
多卡模式下除了进程 0 外的其它进程返回的值一定是 True; | |||||
""" | |||||
return FASTNLP_GLOBAL_RANK in os.environ | |||||
def get_global_rank(): | |||||
return int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||||
def rank_zero_call(fn: Callable): | |||||
""" | |||||
通过该函数包裹的函数,在单卡模式下该方法不影响任何东西,在多卡状态下仅会在 global rank 为 0 的进程下执行。使用方式有两种 | |||||
# 使用方式1 | |||||
@rank_zero_call | |||||
def save_model(): | |||||
do_something # will only run in global rank 0 | |||||
# 使用方式2 | |||||
def add(a, b): | |||||
return a+b | |||||
rank_zero_call(add)(1, 2) | |||||
:param fn: 需要包裹的可执行的函数。 | |||||
:return: | |||||
""" | |||||
@wraps(fn) | |||||
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
return fn(*args, **kwargs) | |||||
return None | |||||
return wrapped_fn | |||||
@contextmanager | |||||
def all_rank_call(): | |||||
""" | |||||
在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。 | |||||
# 使用方式 | |||||
with all_rank_run(): | |||||
do_something # all rank will do | |||||
:param fn: | |||||
:return: | |||||
""" | |||||
old_fastnlp_global_rank = os.environ[FASTNLP_GLOBAL_RANK] if FASTNLP_GLOBAL_RANK in os.environ else None | |||||
os.environ[FASTNLP_GLOBAL_RANK] = '0' | |||||
yield | |||||
if old_fastnlp_global_rank is not None: | |||||
os.environ[FASTNLP_GLOBAL_RANK] = old_fastnlp_global_rank | |||||
else: | |||||
os.environ.pop(FASTNLP_GLOBAL_RANK) |
@@ -0,0 +1,52 @@ | |||||
# TODO 应该保证在__init__.py中的第一时间被调用。 | |||||
# FASTNLP_DISTRIBUTED_CHECK 用于这样的使用场景:用户可能在同一个脚本中连续使用两个独立的 trainer 实例,然后这两个 trainer 使用的都是 TorchDDPDriver; | |||||
# 因为我们在训练完成后不会主动地去关闭 ddp 的通信进程(目前没有),因此第二个 trainer 的 TorchDDPDriver 不会去真正地初始化 ddp 环境,而是会沿用 | |||||
# 第一个 trainer 的 TorchDDPDriver 所开启的 ddp 环境; | |||||
# 但是注意当第二个 TorchDDPDriver 的机器数量和使用的显卡数量(对应进程数量)发生变化时,这一沿用会造成严重的错误;因此在 TorchDDPDriver 第一次启动 ddp | |||||
# 环境后,我们会将 FASTNLP_DISTRIBUTED_CHECK 注入到环境变量中;从而在第二个 TorchDDPDriver 启动的时候会检测到该值,然后去检验当前使用的机器数量和每个机器 | |||||
# 上的进程的数量是否相等; | |||||
FASTNLP_DISTRIBUTED_CHECK = "FASTNLP_DISTRIBUTED_CHECK" | |||||
# 每一个 分布式的 driver 都应当正确地设立该值; | |||||
# FASTNLP_GLOBAL_RANK 用于给 fastNLP.core.utils.distributed.rank_zero_call 进行正确的配置。这是因为 TorchDDPDriver 初始化 ddp 环境的 | |||||
# 方式是开启多个和主进程基本一样的子进程,然后将所有代码从前到后完整地运行一遍。而在运行到 TorchDDPDriver 中的设立一些变量的正确地值之前,就已经 | |||||
# 运行到了某些需要区分主进程和其它进程的代码。 | |||||
# 因为考虑到用户可能在 Trainer 实例化前调用该函数修饰器,因此我们需要通过环境变量的方式在每一个子进程开始开始运行到被修饰的函数前就将 | |||||
# rank_zero_call 的 rank 值设立正确; | |||||
FASTNLP_GLOBAL_RANK = "FASTNLP_GLOBAL_RANK" | |||||
# FASTNLP_LOG_LEVEL 的使用场景和 FASTNLP_GLOBAL_RANK 类似,即用户在使用我们 log 的时候是在 trainer.run 之前的,这时我们要提前通过 | |||||
# 环境变量将该值设立正确; | |||||
FASTNLP_LOG_LEVEL = "FASTNLP_LOG_LEVEL" | |||||
# todo 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp; | |||||
# FASTNLP_LAUNCH_TIME 记录了当前 fastNLP 脚本启动的时间。 | |||||
FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME" | |||||
# FASTNLP_GLOBAL_SEED 用于每个子进程随机数种子的正确设置; | |||||
FASTNLP_GLOBAL_SEED = "FASTNLP_GLOBAL_SEED" | |||||
# FASTNLP_SEED_WORKERS 用于 pytorch dataloader work_init_fn 的正确的设置; | |||||
FASTNLP_SEED_WORKERS = "FASTNLP_SEED_WORKERS" | |||||
# 用于设置 fastNLP 使用的 backend 框架 | |||||
FASTNLP_BACKEND = 'FASTNLP_BACKEND' | |||||
# 用于保存用户传入的 CUDA_VISIBLE_DEVICES,目前在paddle中有使用,用户不需要使用 | |||||
USER_CUDA_VISIBLE_DEVICES = 'USER_CUDA_VISIBLE_DEVICES' | |||||
# 用于在 torch.distributed.launch 时移除传入的 rank ,在 pytorch 中有使用。值的可选为 [0, 1] | |||||
FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | |||||
# todo 注释 | |||||
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | |||||
# todo 注释 直接使用的变量 | |||||
FASTNLP_MODEL_FILENAME = "fastnlp_model.pkl.tar" | |||||
FASTNLP_CHECKPOINT_FILENAME = "fastnlp_checkpoint.pkl.tar" | |||||
@@ -0,0 +1,26 @@ | |||||
import platform | |||||
import os | |||||
import operator | |||||
from fastNLP.core.envs.env import FASTNLP_BACKEND | |||||
from fastNLP.core.envs.utils import _module_available, _compare_version | |||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||||
backend = os.environ.get(FASTNLP_BACKEND, 'all') | |||||
if backend == 'all': | |||||
need_import = SUPPORT_BACKENDS | |||||
elif ',' in backend: | |||||
need_import = list(map(str.strip, backend.split(','))) | |||||
else: | |||||
need_import = [backend] | |||||
_IS_WINDOWS = platform.system() == "Windows" | |||||
_NEED_IMPORT_FAIRSCALE = not _IS_WINDOWS and _module_available("fairscale.nn") and 'torch' in need_import | |||||
_NEED_IMPORT_TORCH = _module_available("torch") and 'torch' in need_import | |||||
_NEED_IMPORT_JITTOR = _module_available("jittor") and 'jittor' in need_import | |||||
_NEED_IMPORT_PADDLE = _module_available("paddle") and 'paddle' in need_import | |||||
_TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0") |
@@ -0,0 +1,173 @@ | |||||
""" | |||||
这个文件用于自动以及手动设置某些环境变量的,该文件中的set_env()函数会在 fastNLP 被 import 的时候在set_env_on_import之后运行。可以 | |||||
用于设置某些必要的环境变量。同时用户在使用时set_env()修改环境变量时,也应该保证set_env()函数在所有其它代码之前被运行。 | |||||
""" | |||||
import os | |||||
import json | |||||
import sys | |||||
from collections import defaultdict | |||||
from fastNLP.core.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | |||||
from fastNLP.core.envs import SUPPORT_BACKENDS | |||||
from fastNLP.core.envs.utils import _module_available | |||||
def _set_backend(): | |||||
""" | |||||
根据环境变量或者默认配置文件设置 backend 。 | |||||
backend 为 paddle 时,我们还将设置部分环境变量以使得 paddle 能够在 fastNLP 中正确运行。 | |||||
backend 为 jittor 时,我们将设置 log_silent:1 | |||||
:return: | |||||
""" | |||||
backend = '' | |||||
if FASTNLP_BACKEND in os.environ: | |||||
backend = os.environ[FASTNLP_BACKEND] | |||||
else: | |||||
# 从文件中读取的 | |||||
conda_env = os.environ.get('CONDA_DEFAULT_ENV', None) | |||||
if conda_env is None: | |||||
conda_env = 'default' | |||||
env_folder = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs') | |||||
env_path = os.path.join(env_folder, conda_env + '.json') | |||||
if os.path.exists(env_path): | |||||
try: | |||||
with open(env_path, 'r', encoding='utf8') as f: | |||||
envs = json.load(f) | |||||
# print(json.dumps(envs)) | |||||
if FASTNLP_BACKEND in envs: | |||||
backend = envs[FASTNLP_BACKEND] | |||||
os.environ[FASTNLP_BACKEND] = backend | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
print(f"Set fastNLP backend as {backend} based on {env_path}.") | |||||
except BaseException as e: | |||||
raise e | |||||
if backend: | |||||
assert backend in SUPPORT_BACKENDS, f"Right now fastNLP only support the following backends:{SUPPORT_BACKENDS}, " \ | |||||
f"instead of `{backend}`" | |||||
if backend == 'paddle': | |||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||||
assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`." | |||||
if 'CUDA_VISIBLE_DEVICES' not in os.environ and 'PADDLE_RANK_IN_NODE' not in os.environ \ | |||||
and 'FLAGS_selected_gpus' not in os.environ: | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = '' | |||||
elif 'CUDA_VISIBLE_DEVICES' in os.environ: | |||||
CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'] | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] | |||||
elif 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: | |||||
# TODO 这里由于fastNLP需要hack CUDA_VISIBLE_DEVICES,因此需要相应滴修改FLAGS等paddle变量 @xsh | |||||
CUDA_VISIBLE_DEVICES = os.environ['FLAGS_selected_gpus'] | |||||
os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES | |||||
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] | |||||
os.environ['FLAGS_selected_gpus'] = "0" | |||||
os.environ['FLAGS_selected_accelerators'] = "0" | |||||
elif backend == 'jittor': | |||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||||
if "log_silent" not in os.environ: | |||||
os.environ["log_silent"] = "1" | |||||
if "CUDA_VISIBLE_DEVICES" in os.environ: | |||||
os.environ["use_cuda"] = "1" | |||||
elif backend == 'torch': | |||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||||
def set_env(global_seed=None): | |||||
""" | |||||
set_env 用于显式告知 fastNLP 将要使用的相关环境变量是什么,必须在代码最开端运行。以下的环境变量设置,优先级分别为:(1)在代码开始 | |||||
的位置显式调用设置;(2)通过环境变量注入的;(3)通过读取配置文件(如果有)。 | |||||
:param backend: 目前支持的 backend 有 torch, jittor, paddle 。设置特定的 backend 后,fastNLP 将不再加载其它 backend ,可以 | |||||
提高加载速度。该值对应环境变量中的 FASTNLP_BACKEND 。 | |||||
:param int global_seed: 对应环境变量为 FASTNLP_GLOBAL_SEED 。设置 fastNLP 的全局随机数。 | |||||
:param str log_level: 可选 ['INFO','WARNING', 'DEBUG', 'ERROR'] ,对应环境变量为 FASTNLP_LOG_LEVEL 。 | |||||
:return: | |||||
""" | |||||
_need_set_envs = [FASTNLP_GLOBAL_SEED] | |||||
_env_values = defaultdict(list) | |||||
if global_seed is not None: | |||||
assert isinstance(global_seed, int) | |||||
_env_values[FASTNLP_GLOBAL_SEED].append(global_seed) | |||||
# 直接读取环境变量的,这里应当是用户自己注入的环境变量 | |||||
for env_name in _need_set_envs: | |||||
if env_name in os.environ: | |||||
_env_values[env_name].append(os.environ.get(env_name)) | |||||
if FASTNLP_GLOBAL_SEED in _env_values: | |||||
os.environ[FASTNLP_GLOBAL_SEED] = _env_values[FASTNLP_GLOBAL_SEED][0] | |||||
# 针对不同的backend,做特定的设置 | |||||
backend = os.environ.get(FASTNLP_BACKEND, '') | |||||
if backend == 'paddle': | |||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||||
if os.environ.get(FASTNLP_GLOBAL_SEED, None) is not None: | |||||
seed_paddle_global_seed(int(os.environ.get(FASTNLP_GLOBAL_SEED))) | |||||
if backend == 'jittor': | |||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||||
if os.environ.get(FASTNLP_GLOBAL_SEED, None) is not None: | |||||
seed_jittor_global_seed(int(os.environ.get(FASTNLP_GLOBAL_SEED))) | |||||
if backend == 'torch': | |||||
assert _module_available(backend), f"You must have {backend} available to use {backend} backend." | |||||
if os.environ.get(FASTNLP_GLOBAL_SEED, None) is not None: | |||||
seed_torch_global_seed(int(os.environ.get(FASTNLP_GLOBAL_SEED))) | |||||
def seed_torch_global_seed(global_seed): | |||||
# @yxg | |||||
pass | |||||
def seed_paddle_global_seed(global_seed): | |||||
# @xsh | |||||
pass | |||||
def seed_jittor_global_seed(global_seed): | |||||
# @xsh | |||||
pass | |||||
def dump_fastnlp_backend(default:bool = False): | |||||
""" | |||||
将 fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下, | |||||
若 default 为 True,则保存的文件为 ~/.fastNLP/envs/default.json 。 | |||||
如 default 为 False,则保存的文件为 ~/.fastNLP/envs/{CONDA_DEFAULT_ENV}.json ,当CONDA_DEFAULT_ENV这个环境变量不存在时 | |||||
,报错。 | |||||
当 fastNLP 被 import 时,会默认尝试从 ~/.fastNLP/envs/{CONDA_DEFAULT_ENV}.json 读取配置文件,如果文件不存在,则尝试从 | |||||
~/.fastNLP/envs/default.json (如果有)读取环境变量。不过这些变量的优先级低于代码运行时的环境变量注入。 | |||||
会保存的环境变量为 FASTNLP_BACKEND 。 | |||||
:param default: | |||||
:return: | |||||
""" | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
if default: | |||||
env_path = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', 'default.json') | |||||
elif 'CONDA_DEFAULT_ENV' in os.environ: | |||||
env_path = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', | |||||
os.environ.get('CONDA_DEFAULT_ENV') + '.json') | |||||
else: | |||||
raise RuntimeError("Did not found `CONDA_DEFAULT_ENV` in your environment variable.") | |||||
os.makedirs(os.path.dirname(env_path), exist_ok=True) | |||||
envs = {} | |||||
if FASTNLP_BACKEND in os.environ: | |||||
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | |||||
if len(envs): | |||||
with open(env_path, 'w', encoding='utf8') as f: | |||||
json.dump(fp=f, obj=envs) | |||||
print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.") |
@@ -0,0 +1,66 @@ | |||||
# 本文件主要用于在分布式启动的情况下,各个backend应该可以提前确定 FASTNLP_GLOBAL_RANK(例如根据环境变量的中RANK值)。 | |||||
# 注意!仅有当确定我们的训练是分布式训练时,我们才会将 FASTNLP_GLOBAL_RANK 注入到环境变量中; | |||||
import os | |||||
import sys | |||||
from .env import * | |||||
import datetime | |||||
def remove_local_rank_in_argv(): | |||||
""" | |||||
通过 torch.distributed.launch 启动的时候,如果没有加入参数 --use_env ,pytorch 会默认通过 rank 注入 rank,这就 | |||||
要求代码中必须有能够 parse rank 的parser,这里将 rank 删除掉,防止后续报错。 | |||||
:return: | |||||
""" | |||||
index = -1 | |||||
for i, v in enumerate(sys.argv): | |||||
if v.startswith('--rank='): | |||||
os.environ['LOCAL_RANK'] = v.split('=')[1] | |||||
index = i | |||||
break | |||||
if index != -1: | |||||
sys.argv.pop(index) | |||||
def set_env_on_import_torch(): | |||||
if 'WORLD_SIZE' in os.environ and 'LOCAL_RANK' in os.environ and 'RANK' in os.environ: | |||||
os.environ[FASTNLP_GLOBAL_RANK] = os.environ['RANK'] | |||||
if int(os.environ.get(FASTNLP_REMOVE_LOCAL_RANK, 1)): | |||||
remove_local_rank_in_argv() | |||||
if 'WORLD_SIZE' in os.environ and 'LOCAL_RANK' in os.environ and 'RANK' in os.environ and \ | |||||
FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||||
os.environ[FASTNLP_BACKEND_LAUNCH] = '1' | |||||
# TODO paddle may need set this | |||||
def set_env_on_import_paddle(): | |||||
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH | |||||
pass | |||||
# TODO jittor may need set this | |||||
def set_env_on_import_jittor(): | |||||
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH | |||||
pass | |||||
def set_env_on_import(): | |||||
""" | |||||
设置环境变量 | |||||
:return: | |||||
""" | |||||
# 框架相关的变量设置 | |||||
set_env_on_import_torch() | |||||
set_env_on_import_paddle() | |||||
set_env_on_import_jittor() | |||||
# fastNLP 内部使用的一些变量 | |||||
if FASTNLP_LAUNCH_TIME not in os.environ: | |||||
cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}" | |||||
os.environ[FASTNLP_LAUNCH_TIME] = cur_time | |||||
# 设置对应的值 | |||||
if FASTNLP_LOG_LEVEL not in os.environ: | |||||
os.environ[FASTNLP_LOG_LEVEL] = 'AUTO' |
@@ -0,0 +1,48 @@ | |||||
from importlib.util import find_spec | |||||
from typing import Callable | |||||
import importlib | |||||
from pkg_resources import DistributionNotFound | |||||
from packaging.version import Version | |||||
import pkg_resources | |||||
def _module_available(module_path: str) -> bool: | |||||
"""Check if a path is available in your environment. | |||||
>>> _module_available('os') | |||||
True | |||||
>>> _module_available('bla.bla') | |||||
False | |||||
""" | |||||
try: | |||||
return find_spec(module_path) is not None | |||||
except AttributeError: | |||||
# Python 3.6 | |||||
return False | |||||
except ModuleNotFoundError: | |||||
# Python 3.7+ | |||||
return False | |||||
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool: | |||||
"""Compare package version with some requirements. | |||||
>>> _compare_version("torch", operator.ge, "0.1") | |||||
True | |||||
""" | |||||
try: | |||||
pkg = importlib.import_module(package) | |||||
except (ModuleNotFoundError, DistributionNotFound): | |||||
return False | |||||
try: | |||||
if hasattr(pkg, "__version__"): | |||||
pkg_version = Version(pkg.__version__) | |||||
else: | |||||
# try pkg_resources to infer version | |||||
pkg_version = Version(pkg_resources.get_distribution(package).version) | |||||
except TypeError: | |||||
# this is mocked by Sphinx, so it should return True to generate all summaries | |||||
return True | |||||
if use_base_version: | |||||
pkg_version = Version(pkg_version.base_version) | |||||
return op(pkg_version, Version(version)) |
@@ -0,0 +1,6 @@ | |||||
__all__ = [ | |||||
'logger' | |||||
] | |||||
from .logger import logger | |||||
@@ -0,0 +1,89 @@ | |||||
import logging | |||||
import sys | |||||
from logging import getLevelName | |||||
try: | |||||
from tqdm.auto import tqdm | |||||
except ImportError: | |||||
tqdm = None | |||||
if tqdm is not None: | |||||
class TqdmLoggingHandler(logging.Handler): | |||||
def __init__(self, level=logging.INFO): | |||||
super().__init__(level) | |||||
def emit(self, record): | |||||
try: | |||||
msg = self.format(record) | |||||
tqdm.write(msg) | |||||
self.flush() | |||||
except (KeyboardInterrupt, SystemExit): | |||||
raise | |||||
except: | |||||
self.handleError(record) | |||||
else: | |||||
class TqdmLoggingHandler(logging.StreamHandler): | |||||
def __init__(self, level=logging.INFO): | |||||
super().__init__(sys.stdout) | |||||
self.setLevel(level) | |||||
class StdoutStreamHandler(logging.StreamHandler): | |||||
""" | |||||
重载 StreamHandler 使得替换 sys.stdout 的时候能够生效。 | |||||
""" | |||||
def __init__(self): | |||||
super(StdoutStreamHandler, self).__init__() | |||||
def flush(self): | |||||
""" | |||||
Flushes the stream. | |||||
""" | |||||
self.acquire() | |||||
try: | |||||
sys.stdout.flush() | |||||
finally: | |||||
self.release() | |||||
def emit(self, record): | |||||
""" | |||||
Emit a record. | |||||
If a formatter is specified, it is used to format the record. | |||||
The record is then written to the stream with a trailing newline. If | |||||
exception information is present, it is formatted using | |||||
traceback.print_exception and appended to the stream. If the stream | |||||
has an 'encoding' attribute, it is used to determine how to do the | |||||
output to the stream. | |||||
""" | |||||
try: | |||||
msg = self.format(record) | |||||
stream = sys.stdout | |||||
# issue 35046: merged two stream.writes into one. | |||||
stream.write(msg + self.terminator) | |||||
self.flush() | |||||
except RecursionError: # See issue 36272 | |||||
raise | |||||
except Exception: | |||||
self.handleError(record) | |||||
def setStream(self, stream): | |||||
""" | |||||
Sets the StreamHandler's stream to the specified value, | |||||
if it is different. | |||||
Returns the old stream, if the stream was changed, or None | |||||
if it wasn't. | |||||
""" | |||||
raise RuntimeError("Cannot set the stream of FStreamHandler.") | |||||
def __repr__(self): | |||||
level = getLevelName(self.level) | |||||
name = getattr(sys.stdout, 'name', '') | |||||
# bpo-36015: name can be an int | |||||
name = str(name) | |||||
if name: | |||||
name += ' ' | |||||
return '<%s %s(%s)>' % (self.__class__.__name__, name, level) |
@@ -0,0 +1,9 @@ | |||||
from rich.highlighter import Highlighter | |||||
class ColorHighlighter(Highlighter): | |||||
def __init__(self, color='black'): | |||||
self.color = color | |||||
def highlight(self, text): | |||||
text.stylize(self.color) |
@@ -0,0 +1,314 @@ | |||||
r""" | |||||
Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger, | |||||
具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API | |||||
使用方式: | |||||
from fastNLP import _logger | |||||
# | |||||
# _logger 可以和 logging.Logger 一样使用 | |||||
_logger.info('your msg') | |||||
_logger.error('your msg') | |||||
# _logger 新增的API | |||||
# 将日志输出到文件,以及输出的日志等级 | |||||
_logger.add_file('/path/to/log', level='INFO') | |||||
# 定义在命令行中的显示格式和日志等级 | |||||
_logger.set_stdout('tqdm', level='WARN') | |||||
""" | |||||
import logging | |||||
import logging.config | |||||
from logging import DEBUG, ERROR, INFO, WARNING, CRITICAL, raiseExceptions | |||||
import os | |||||
import sys | |||||
import warnings | |||||
from pathlib import Path | |||||
from typing import Optional, Union | |||||
from rich.logging import RichHandler | |||||
__all__ = [ | |||||
'logger' | |||||
] | |||||
from fastNLP.core.log.handler import StdoutStreamHandler, TqdmLoggingHandler | |||||
from fastNLP.core.envs import FASTNLP_LOG_LEVEL, FASTNLP_GLOBAL_RANK, FASTNLP_LAUNCH_TIME, FASTNLP_BACKEND_LAUNCH | |||||
from fastNLP.core.envs import is_cur_env_distributed | |||||
ROOT_NAME = 'fastNLP' | |||||
class LoggerSingleton(type): | |||||
_instances = {} | |||||
def __call__(cls, *args, **kwargs): | |||||
if cls not in cls._instances: | |||||
cls._instances[cls] = super(LoggerSingleton, cls).__call__(*args, **kwargs) | |||||
return cls._instances[cls] | |||||
class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
def __init__(self, name): | |||||
super().__init__(name) | |||||
def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, | |||||
mode: str = "w"): | |||||
""" | |||||
将日志输出到 path 中。 | |||||
:param path: 若 path 为文件路径(通过 path 是否包含后缀判定 path 是否表示文件名,例如 output.log 会被认为是文件,而 | |||||
output 则认为是文件夹)则直接写入到给定文件中;如果判定为文件夹,则是在该文件夹下以 时间戳 创建一个日志文件。 | |||||
:param level: 可选 ['INFO', 'WARNING', 'DEBUG', 'ERROR', 'AUTO'], 其中AUTO表示根据环境变量"FASTNLP_LOG_LEVEL'进行 | |||||
设置。 | |||||
:param remove_other_handlers: 是否移除其它 handler ,如果移除,则terminal中将不会有 log 输出。 | |||||
:param mode: 可选为['w', 'a'],如果传入的 path 是存在的文件,'w' 会覆盖原有内容 'a' 则会在文件结尾处继续添加。 | |||||
:return: | |||||
""" | |||||
r"""添加日志输出文件和输出级别""" | |||||
if level == 'AUTO': | |||||
level = parse_level() | |||||
return _add_file_handler(self, path, level, remove_other_handlers, mode) | |||||
def set_stdout(self, stdout: str = 'raw', level: str = 'AUTO'): | |||||
""" | |||||
设置 log 的 terminal 输出形式。 | |||||
:param stdout: 可选['rich', 'naive', 'raw', 'none']。 | |||||
:param level: 可选 ['INFO', 'WARNING', 'DEBUG', 'ERROR', 'AUTO'], 其中AUTO表示根据环境变量"FASTNLP_LOG_LEVEL'进行 | |||||
设置。 | |||||
:return: | |||||
""" | |||||
r"""设置标准输出格式和输出级别""" | |||||
if level == 'AUTO': | |||||
level = parse_level() | |||||
return _set_stdout_handler(self, stdout, level) | |||||
def debug(self, msg, *args, **kwargs): | |||||
""" | |||||
Delegate a debug call to the underlying log. | |||||
""" | |||||
if self.isEnabledFor(DEBUG): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(DEBUG, msg, args, **kwargs) | |||||
def info(self, msg, *args, **kwargs): | |||||
""" | |||||
Delegate an info call to the underlying log. | |||||
""" | |||||
if self.isEnabledFor(INFO): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(INFO, msg, args, **kwargs) | |||||
def warning(self, msg, *args, **kwargs): | |||||
""" | |||||
Delegate a warning call to the underlying log. | |||||
""" | |||||
if self.isEnabledFor(WARNING): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(WARNING, msg, args, **kwargs) | |||||
def warn(self, msg, *args, **kwargs): | |||||
warnings.warn("The 'warn' method is deprecated, " | |||||
"use 'warning' instead", DeprecationWarning, 2) | |||||
self.warning(msg, *args, **kwargs) | |||||
def error(self, msg, *args, **kwargs): | |||||
""" | |||||
Delegate an error call to the underlying log. | |||||
""" | |||||
if self.isEnabledFor(ERROR): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(ERROR, msg, args, **kwargs) | |||||
def exception(self, msg, *args, exc_info=True, **kwargs): | |||||
""" | |||||
Delegate an exception call to the underlying log. | |||||
""" | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self.error(msg, *args, exc_info=exc_info, **kwargs) | |||||
def critical(self, msg, *args, **kwargs): | |||||
""" | |||||
Delegate a critical call to the underlying log. | |||||
""" | |||||
if self.isEnabledFor(CRITICAL): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(CRITICAL, msg, args, **kwargs) | |||||
def log(self, level, msg, *args, **kwargs): | |||||
""" | |||||
Delegate a log call to the underlying log, after adding | |||||
contextual information from this adapter instance. | |||||
""" | |||||
if not isinstance(level, int): | |||||
if raiseExceptions: | |||||
raise TypeError("level must be an integer") | |||||
else: | |||||
return | |||||
if self.isEnabledFor(level): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(level, msg, args, **kwargs) | |||||
def _add_rank_info(self, kwargs): | |||||
if is_cur_env_distributed(): | |||||
extra = kwargs.get('extra', {}) | |||||
extra.update({"rank": int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}) | |||||
kwargs["extra"] = extra | |||||
return kwargs | |||||
def _get_level(level): | |||||
if not isinstance(level, int): | |||||
level = level.lower() | |||||
level = {'info': logging.INFO, 'debug': logging.DEBUG, | |||||
'warn': logging.WARN, 'warning': logging.WARNING, | |||||
'error': logging.ERROR}[level] | |||||
return level | |||||
def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] = None, level: str = 'INFO', | |||||
remove_other_handlers: bool = False, mode: str = "w"): | |||||
if path is None: | |||||
path = Path.cwd() | |||||
if isinstance(path, str): | |||||
path = Path(path) | |||||
if not isinstance(path, Path): | |||||
raise TypeError("Parameter `path` can only be `str` or `pathlib.Path` type.") | |||||
if not path.exists(): | |||||
head, tail = os.path.splitext(path) | |||||
if tail == '': # 说明没有后缀,理解为是一个folder | |||||
path.mkdir(parents=True, exist_ok=True) | |||||
else: | |||||
# 主进程会帮助我们创建文件夹,但是由于主从进程几乎是同步的,因此到这里时子进程也会尝试创建文件夹,即使主进程会做这件事情; | |||||
dirname = os.path.dirname(path) | |||||
os.makedirs(dirname, exist_ok=True) | |||||
if path.is_dir(): | |||||
if os.environ.get(FASTNLP_BACKEND_LAUNCH, '0')== '1': | |||||
# 如果是通过 python -m xxx.launch 等启动的话,FASTNLP_LAUNCH_TIME这个名称可能是不一致的。 | |||||
path = path.joinpath(f"RANK-{os.environ.get(FASTNLP_GLOBAL_RANK, '0')}-" + | |||||
os.environ.get(FASTNLP_LAUNCH_TIME) + '.log') | |||||
else: | |||||
path = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + '.log') | |||||
if not isinstance(remove_other_handlers, bool): | |||||
raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.") | |||||
if not isinstance(mode, str): | |||||
raise TypeError("Parameter 'mode' can only be `str` type.") | |||||
if mode not in {"w", "a"}: | |||||
raise ValueError("Parameter `mode` can only be one of these values: ('w', 'a').") | |||||
for h in _logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
if os.path.abspath(path) == h.baseFilename: | |||||
# file path already added | |||||
return | |||||
# File Handler | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
if os.path.exists(path): | |||||
assert os.path.isfile(path) | |||||
warnings.warn('log already exists in {}'.format(path)) | |||||
dirname = os.path.abspath(os.path.dirname(path)) | |||||
os.makedirs(dirname, exist_ok=True) | |||||
# 这里只要检测到是分布式训练,我们就将 mode 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 | |||||
# 覆盖掉原文件,而是会接着上一次的 log 继续添加; | |||||
# 这样做主要是为了解决这样的情形所导致的问题:在分布式训练中,进程 1 比 进程 0 先运行到这里,然后使得进程 0 将进程 1 的 log 覆盖掉; | |||||
if is_cur_env_distributed():# and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0: | |||||
mode = "a" | |||||
file_handler = logging.FileHandler(path, mode=mode) | |||||
logger.info(f"Writing log to file:{os.path.abspath(path)}") | |||||
file_handler.setLevel(_get_level(level)) | |||||
if is_cur_env_distributed(): | |||||
file_formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - %(module)s - [%(levelname)s] - %(message)s', | |||||
datefmt='%Y/%m/%d %H:%M:%S') | |||||
else: | |||||
file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s', | |||||
datefmt='%Y/%m/%d %H:%M:%S') | |||||
file_handler.setFormatter(file_formatter) | |||||
_logger.addHandler(file_handler) | |||||
if remove_other_handlers: | |||||
_need_remove_handlers = [] | |||||
for i, h in enumerate(_logger.handlers): | |||||
if not isinstance(h, logging.FileHandler): | |||||
_need_remove_handlers.append(h) | |||||
for handler in _need_remove_handlers: | |||||
_logger.removeHandler(handler) | |||||
return file_handler | |||||
def _set_stdout_handler(_logger, stdout='raw', level='INFO'): | |||||
level = _get_level(level) | |||||
supported_stdout = ['none', 'raw', 'tqdm', 'naive', 'rich'] | |||||
if stdout not in supported_stdout: | |||||
raise ValueError('stdout must in one of {}'.format(supported_stdout)) | |||||
# make sure to initialize _logger only once | |||||
stream_handler = None | |||||
_handlers = (logging.StreamHandler, TqdmLoggingHandler, StdoutStreamHandler, RichHandler) | |||||
for i, h in enumerate(_logger.handlers): | |||||
if isinstance(h, _handlers): | |||||
stream_handler = h | |||||
break | |||||
if stream_handler is not None: | |||||
_logger.removeHandler(stream_handler) | |||||
# Stream Handler | |||||
if stdout == 'raw': | |||||
stream_handler = StdoutStreamHandler() | |||||
elif stdout == 'rich': | |||||
stream_handler = RichHandler(level=level, log_time_format="[%X]") | |||||
elif stdout == 'naive': | |||||
stream_handler = logging.StreamHandler(sys.stdout) | |||||
elif stdout == 'tqdm': | |||||
stream_handler = TqdmLoggingHandler(level) | |||||
else: | |||||
stream_handler = None | |||||
if stream_handler is not None: | |||||
if is_cur_env_distributed(): | |||||
stream_formatter = logging.Formatter('Rank: %(rank)s - %(message)s') | |||||
else: | |||||
stream_formatter = logging.Formatter('%(message)s') | |||||
stream_handler.setLevel(level) | |||||
stream_handler.setFormatter(stream_formatter) | |||||
_logger.addHandler(stream_handler) | |||||
return stream_handler | |||||
def _init_logger(path=None, stdout='rich', level='INFO'): | |||||
r"""initialize _logger""" | |||||
level = _get_level(level) | |||||
logger = FastNLPLogger(ROOT_NAME) | |||||
logger.propagate = False | |||||
_set_stdout_handler(logger, stdout, level) | |||||
# File Handler | |||||
if path is not None: | |||||
_add_file_handler(logger, path, level) | |||||
return logger | |||||
def parse_level(): | |||||
if os.environ[FASTNLP_LOG_LEVEL] == 'AUTO': | |||||
level = 'WARNING' if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0 else "INFO" | |||||
else: | |||||
level = os.environ[FASTNLP_LOG_LEVEL] | |||||
return level | |||||
logger = _init_logger(path=None, stdout='rich', level=parse_level()) | |||||
logger.debug("The environment variables are as following:") | |||||
logger.debug(os.environ) |
@@ -0,0 +1,300 @@ | |||||
import os | |||||
import tempfile | |||||
import datetime | |||||
from pathlib import Path | |||||
import logging | |||||
import re | |||||
from fastNLP.core.envs.env import FASTNLP_LAUNCH_TIME | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import synchronize_safe_rm | |||||
# 测试 TorchDDPDriver; | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_1(): | |||||
""" | |||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | |||||
多卡时根据时间创造文件名字有一个很大的 bug,就是不同的进程启动之间是有时差的,因此会导致他们各自输出到单独的 log 文件中; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
path = Path.cwd() | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath, mode="w") | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
synchronize_safe_rm(filepath) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
logger.removeHandler(handler) | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_2(): | |||||
""" | |||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
origin_path = Path.cwd() | |||||
try: | |||||
path = origin_path.joinpath("not_existed") | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath) | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_3(): | |||||
""" | |||||
path = None; | |||||
多卡时根据时间创造文件名字有一个很大的 bug,就是不同的进程启动之间是有时差的,因此会导致他们各自输出到单独的 log 文件中; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
handler = logger.add_file() | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME)+".log") | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
# print(f"\nrank: {driver.get_local_rank()} line, {line}\n") | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
synchronize_safe_rm(file) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
logger.removeHandler(handler) | |||||
@magic_argv_env_context | |||||
def test_add_file_ddp_4(): | |||||
""" | |||||
测试 path 是文件夹; | |||||
""" | |||||
import torch | |||||
import torch.distributed as dist | |||||
from fastNLP.core.log.logger import logger | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
driver = TorchDDPDriver( | |||||
model=model, | |||||
parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")], | |||||
output_from_new_proc="all" | |||||
) | |||||
driver.setup() | |||||
msg = 'some test log msg' | |||||
path = Path.cwd().joinpath("not_existed") | |||||
try: | |||||
handler = logger.add_file(path) | |||||
logger.info(msg) | |||||
logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n") | |||||
for h in logger.handlers: | |||||
if isinstance(h, logging.FileHandler): | |||||
h.flush() | |||||
dist.barrier() | |||||
file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log") | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert msg in line | |||||
assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line | |||||
pattern = re.compile(msg) | |||||
assert len(pattern.findall(line)) == 1 | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
dist.barrier() | |||||
dist.destroy_process_group() | |||||
class TestLogger: | |||||
msg = 'some test log msg' | |||||
def test_add_file_1(self): | |||||
""" | |||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
filepath = path.joinpath('log.txt') | |||||
handler = logger.add_file(filepath) | |||||
logger.info(self.msg) | |||||
with open(filepath, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
def test_add_file_2(self): | |||||
""" | |||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
origin_path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
path = origin_path.joinpath("not_existed") | |||||
path = path.joinpath('log.txt') | |||||
handler = logger.add_file(path) | |||||
logger.info(self.msg) | |||||
with open(path, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(origin_path) | |||||
logger.removeHandler(handler) | |||||
def test_add_file_3(self): | |||||
""" | |||||
测试 path 是 None; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
handler = logger.add_file() | |||||
logger.info(self.msg) | |||||
path = Path.cwd() | |||||
cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d')) | |||||
for file in path.iterdir(): | |||||
if file.name.startswith(cur_datetime): | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
file.unlink() | |||||
logger.removeHandler(handler) | |||||
def test_add_file_4(self): | |||||
""" | |||||
测试 path 是文件夹; | |||||
""" | |||||
from fastNLP.core.log.logger import logger | |||||
path = Path(tempfile.mkdtemp()) | |||||
try: | |||||
handler = logger.add_file(path) | |||||
logger.info(self.msg) | |||||
cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d')) | |||||
for file in path.iterdir(): | |||||
if file.name.startswith(cur_datetime): | |||||
with open(file, 'r') as f: | |||||
line = ''.join([l for l in f]) | |||||
assert self.msg in line | |||||
finally: | |||||
synchronize_safe_rm(path) | |||||
logger.removeHandler(handler) | |||||
def test_stdout(self, capsys): | |||||
from fastNLP.core.log.logger import logger | |||||
handler = logger.set_stdout(stdout="raw") | |||||
logger.info(self.msg) | |||||
logger.debug('aabbc') | |||||
captured = capsys.readouterr() | |||||
assert "some test log msg\n" == captured.out | |||||
logger.removeHandler(handler) | |||||