Browse Source

Rich支持jupyter

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
e85cbb067e
6 changed files with 91 additions and 31 deletions
  1. +1
    -12
      fastNLP/core/callbacks/callback_manager.py
  2. +24
    -8
      fastNLP/core/callbacks/progress_callback.py
  3. +1
    -1
      fastNLP/core/controllers/evaluator.py
  4. +23
    -8
      fastNLP/core/controllers/trainer.py
  5. +22
    -1
      fastNLP/core/utils/rich_progress.py
  6. +20
    -1
      fastNLP/core/utils/utils.py

+ 1
- 12
fastNLP/core/callbacks/callback_manager.py View File

@@ -8,7 +8,6 @@ __all__ = [

from .callback_events import Events
from .callback import Callback
from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.core.log import logger


@@ -35,7 +34,7 @@ class CallbackManager:
class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback;
callback_fns: dict

def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'):
def __init__(self, callbacks: Optional[List[Callback]]):
r"""
注意 callback 的调用顺序:
1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数;
@@ -46,7 +45,6 @@ class CallbackManager:
"""
self._need_reproducible_sampler = False

_has_progress_callback = False
_callbacks = []
if callbacks is not None:
if isinstance(callbacks, Callback):
@@ -57,16 +55,7 @@ class CallbackManager:
for _callback in callbacks:
if not isinstance(_callback, Callback):
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
if isinstance(_callback, ProgressCallback):
_has_progress_callback = True
_callbacks += callbacks
if not _has_progress_callback:
# 添加 progress callback
progress_callback = choose_progress_callback(progress_bar=progress_bar)
if progress_callback is None:
logger.info("There is no progress bar, Trainer will not output training progress.")
else:
_callbacks.append(progress_callback)
self.callback_fns = defaultdict(list)
# 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法
# 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可;


+ 24
- 8
fastNLP/core/callbacks/progress_callback.py View File

@@ -1,6 +1,6 @@
import json
import sys
from typing import Union

__all__ = [
'choose_progress_callback',
@@ -11,11 +11,22 @@ __all__ = [
from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger
from fastNLP.core.utils.utils import is_notebook



class ProgressCallback(HasMonitorCallback):
def on_train_end(self, trainer):
f_rich_progress.stop()

@property
def name(self): # progress bar的名称
return 'auto'


def choose_progress_callback(progress_bar:str):
def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback:
if progress_bar == 'auto':
if (sys.stdin and sys.stdin.isatty()):
if not f_rich_progress.dummy_rich:
progress_bar = 'rich'
else:
progress_bar = 'raw'
@@ -23,15 +34,12 @@ def choose_progress_callback(progress_bar:str):
return RichCallback()
elif progress_bar == 'raw':
return RawTextCallback()
elif isinstance(progress_bar, ProgressCallback):
return progress_bar
else:
return None


class ProgressCallback(HasMonitorCallback):
def on_train_end(self, trainer):
f_rich_progress.stop()


class RichCallback(ProgressCallback):
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
format_json=True):
@@ -124,6 +132,10 @@ class RichCallback(ProgressCallback):
self.task2id = {}
self.loss = 0

@property
def name(self): # progress bar的名称
return 'rich'


class RawTextCallback(ProgressCallback):
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
@@ -184,3 +196,7 @@ class RawTextCallback(ProgressCallback):
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results)))
else:
logger.info(results)

@property
def name(self): # progress bar的名称
return 'raw'

+ 1
- 1
fastNLP/core/controllers/evaluator.py View File

@@ -134,7 +134,7 @@ class Evaluator:

self.progress_bar = kwargs.get('progress_bar', 'auto')
if self.progress_bar == 'auto':
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich'

self.driver.barrier()



+ 23
- 8
fastNLP/core/controllers/trainer.py View File

@@ -20,6 +20,7 @@ from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _Truncated
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.callbacks.progress_callback import choose_progress_callback
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
@@ -125,14 +126,13 @@ class Trainer(TrainerEventTrigger):
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'],默认为 auto 。progress 的实现是通过
callback 实现的,若在输入的 callback 中检测到了 ProgressCallback 类型的 callback ,则该参数对 Trainer 无效。
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。

progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
"""
self.model = model
self.marker = marker
@@ -195,8 +195,20 @@ class Trainer(TrainerEventTrigger):
)
self.driver.set_optimizers(optimizers=optimizers)

# 根据 progress_bar 参数选择 ProgressBarCallback
progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto'))
if progress_bar_callback is not None:
if callbacks is None:
callbacks = []
elif not isinstance(callbacks, Sequence):
callbacks = [callbacks]

callbacks = list(callbacks) + [progress_bar_callback]
else:
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
"during training.")
# 初始化 callback manager;
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto'))
self.callback_manager = CallbackManager(callbacks)
# 添加所有的函数式 callbacks;
self._fetch_matched_fn_callbacks()
# 添加所有的类 callbacks;
@@ -237,6 +249,9 @@ class Trainer(TrainerEventTrigger):
self.larger_better = larger_better
if metrics is not None and evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name
self.evaluator = Evaluator(
model=model,
dataloaders=evaluate_dataloaders,
@@ -249,8 +264,8 @@ class Trainer(TrainerEventTrigger):
output_mapping=output_mapping,
fp16=fp16,
verbose=0,
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None),
progress_bar=kwargs.get('progress_bar', 'auto')
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
progress_bar=progress_bar
)

if train_fn is not None and not isinstance(train_fn, str):


+ 22
- 1
fastNLP/core/utils/rich_progress.py View File

@@ -14,6 +14,7 @@ __all__ = [
]

from fastNLP.envs import get_global_rank
from .utils import is_notebook


class Singleton(type):
@@ -34,6 +35,14 @@ class DummyFRichProgress:
# 防止用户通过 DummyFRichProgress.console.print() 这种调用
return None

@property
def dummy_rich(self)->bool:
"""
当前对象是否是 dummy 的 rich 对象。

:return:
"""
return True

class FRichProgress(Progress, metaclass=Singleton):
"""
@@ -147,6 +156,8 @@ class FRichProgress(Progress, metaclass=Singleton):
super().stop_task(task_id)
super().remove_task(task_id)
self.refresh() # 使得bar不残留
if len(self._tasks) == 0:
super().stop()

def start(self) -> None:
super().start()
@@ -210,6 +221,15 @@ class FRichProgress(Progress, metaclass=Singleton):
if refresh:
self.refresh()

@property
def dummy_rich(self) -> bool:
"""
当前对象是否是 dummy 的 rich 对象。

:return:
"""
return False


class SpeedColumn(ProgressColumn):
"""
@@ -226,7 +246,8 @@ class SpeedColumn(ProgressColumn):
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed')


if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \
get_global_rank() == 0:
f_rich_progress = FRichProgress().new_progess(
"[progress.description]{task.description}",
"[progress.percentage]{task.percentage:>3.0f}%",


+ 20
- 1
fastNLP/core/utils/utils.py View File

@@ -696,4 +696,23 @@ def get_class_that_defined_method(method):
None)
if isinstance(cls, type):
return cls
return getattr(method, '__objclass__', None) # handle special descriptor objects
return getattr(method, '__objclass__', None) # handle special descriptor objects


def is_notebook():
"""
检查当前运行环境是否为 jupyter

:return:
"""
try:
from IPython import get_ipython

if "IPKernelApp" not in get_ipython().config: # pragma: no cover
raise ImportError("console")
if "VSCODE_PID" in os.environ: # pragma: no cover
raise ImportError("vscode")
except:
return False
else: # pragma: no cover
return True

Loading…
Cancel
Save