From e85cbb067eb83a3c739cfd443f589f67806080fc Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sun, 24 Apr 2022 12:09:19 +0800 Subject: [PATCH] =?UTF-8?q?Rich=E6=94=AF=E6=8C=81jupyter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_manager.py | 13 +-------- fastNLP/core/callbacks/progress_callback.py | 32 +++++++++++++++------ fastNLP/core/controllers/evaluator.py | 2 +- fastNLP/core/controllers/trainer.py | 31 ++++++++++++++------ fastNLP/core/utils/rich_progress.py | 23 ++++++++++++++- fastNLP/core/utils/utils.py | 21 +++++++++++++- 6 files changed, 91 insertions(+), 31 deletions(-) diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 4aa822ad..c5b00e71 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -8,7 +8,6 @@ __all__ = [ from .callback_events import Events from .callback import Callback -from .progress_callback import ProgressCallback, choose_progress_callback from fastNLP.core.log import logger @@ -35,7 +34,7 @@ class CallbackManager: class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; callback_fns: dict - def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'): + def __init__(self, callbacks: Optional[List[Callback]]): r""" 注意 callback 的调用顺序: 1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; @@ -46,7 +45,6 @@ class CallbackManager: """ self._need_reproducible_sampler = False - _has_progress_callback = False _callbacks = [] if callbacks is not None: if isinstance(callbacks, Callback): @@ -57,16 +55,7 @@ class CallbackManager: for _callback in callbacks: if not isinstance(_callback, Callback): raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") - if isinstance(_callback, ProgressCallback): - _has_progress_callback = True _callbacks += callbacks - if not _has_progress_callback: - # 添加 progress callback - progress_callback = choose_progress_callback(progress_bar=progress_bar) - if progress_callback is None: - logger.info("There is no progress bar, Trainer will not output training progress.") - else: - _callbacks.append(progress_callback) self.callback_fns = defaultdict(list) # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 64d72bd0..a6f82896 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -1,6 +1,6 @@ import json import sys - +from typing import Union __all__ = [ 'choose_progress_callback', @@ -11,11 +11,22 @@ __all__ = [ from .has_monitor_callback import HasMonitorCallback from fastNLP.core.utils import f_rich_progress from fastNLP.core.log import logger +from fastNLP.core.utils.utils import is_notebook + + + +class ProgressCallback(HasMonitorCallback): + def on_train_end(self, trainer): + f_rich_progress.stop() + + @property + def name(self): # progress bar的名称 + return 'auto' -def choose_progress_callback(progress_bar:str): +def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: if progress_bar == 'auto': - if (sys.stdin and sys.stdin.isatty()): + if not f_rich_progress.dummy_rich: progress_bar = 'rich' else: progress_bar = 'raw' @@ -23,15 +34,12 @@ def choose_progress_callback(progress_bar:str): return RichCallback() elif progress_bar == 'raw': return RawTextCallback() + elif isinstance(progress_bar, ProgressCallback): + return progress_bar else: return None -class ProgressCallback(HasMonitorCallback): - def on_train_end(self, trainer): - f_rich_progress.stop() - - class RichCallback(ProgressCallback): def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, format_json=True): @@ -124,6 +132,10 @@ class RichCallback(ProgressCallback): self.task2id = {} self.loss = 0 + @property + def name(self): # progress bar的名称 + return 'rich' + class RawTextCallback(ProgressCallback): def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, @@ -184,3 +196,7 @@ class RawTextCallback(ProgressCallback): logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) else: logger.info(results) + + @property + def name(self): # progress bar的名称 + return 'raw' \ No newline at end of file diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 60703ef5..95379302 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -134,7 +134,7 @@ class Evaluator: self.progress_bar = kwargs.get('progress_bar', 'auto') if self.progress_bar == 'auto': - self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' + self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich' self.driver.barrier() diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index e4cd2817..7eb5bbac 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -20,6 +20,7 @@ from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _Truncated from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback_events import _SingleEventState +from fastNLP.core.callbacks.progress_callback import choose_progress_callback from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext @@ -125,14 +126,13 @@ class Trainer(TrainerEventTrigger): set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 - eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; + evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; - progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'],默认为 auto 。progress 的实现是通过 - callback 实现的,若在输入的 callback 中检测到了 ProgressCallback 类型的 callback ,则该参数对 Trainer 无效。 - auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。 - + progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, + 默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 + 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 """ self.model = model self.marker = marker @@ -195,8 +195,20 @@ class Trainer(TrainerEventTrigger): ) self.driver.set_optimizers(optimizers=optimizers) + # 根据 progress_bar 参数选择 ProgressBarCallback + progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto')) + if progress_bar_callback is not None: + if callbacks is None: + callbacks = [] + elif not isinstance(callbacks, Sequence): + callbacks = [callbacks] + + callbacks = list(callbacks) + [progress_bar_callback] + else: + rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " + "during training.") # 初始化 callback manager; - self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) + self.callback_manager = CallbackManager(callbacks) # 添加所有的函数式 callbacks; self._fetch_matched_fn_callbacks() # 添加所有的类 callbacks; @@ -237,6 +249,9 @@ class Trainer(TrainerEventTrigger): self.larger_better = larger_better if metrics is not None and evaluate_dataloaders is not None: check_evaluate_every(evaluate_every) + progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 + if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 + progress_bar = progress_bar.name self.evaluator = Evaluator( model=model, dataloaders=evaluate_dataloaders, @@ -249,8 +264,8 @@ class Trainer(TrainerEventTrigger): output_mapping=output_mapping, fp16=fp16, verbose=0, - use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), - progress_bar=kwargs.get('progress_bar', 'auto') + use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), + progress_bar=progress_bar ) if train_fn is not None and not isinstance(train_fn, str): diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index 82747a01..e7b95d9c 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -14,6 +14,7 @@ __all__ = [ ] from fastNLP.envs import get_global_rank +from .utils import is_notebook class Singleton(type): @@ -34,6 +35,14 @@ class DummyFRichProgress: # 防止用户通过 DummyFRichProgress.console.print() 这种调用 return None + @property + def dummy_rich(self)->bool: + """ + 当前对象是否是 dummy 的 rich 对象。 + + :return: + """ + return True class FRichProgress(Progress, metaclass=Singleton): """ @@ -147,6 +156,8 @@ class FRichProgress(Progress, metaclass=Singleton): super().stop_task(task_id) super().remove_task(task_id) self.refresh() # 使得bar不残留 + if len(self._tasks) == 0: + super().stop() def start(self) -> None: super().start() @@ -210,6 +221,15 @@ class FRichProgress(Progress, metaclass=Singleton): if refresh: self.refresh() + @property + def dummy_rich(self) -> bool: + """ + 当前对象是否是 dummy 的 rich 对象。 + + :return: + """ + return False + class SpeedColumn(ProgressColumn): """ @@ -226,7 +246,8 @@ class SpeedColumn(ProgressColumn): return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') -if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: +if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \ + get_global_rank() == 0: f_rich_progress = FRichProgress().new_progess( "[progress.description]{task.description}", "[progress.percentage]{task.percentage:>3.0f}%", diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index ff3386fe..c3f57bcf 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -696,4 +696,23 @@ def get_class_that_defined_method(method): None) if isinstance(cls, type): return cls - return getattr(method, '__objclass__', None) # handle special descriptor objects \ No newline at end of file + return getattr(method, '__objclass__', None) # handle special descriptor objects + + +def is_notebook(): + """ + 检查当前运行环境是否为 jupyter + + :return: + """ + try: + from IPython import get_ipython + + if "IPKernelApp" not in get_ipython().config: # pragma: no cover + raise ImportError("console") + if "VSCODE_PID" in os.environ: # pragma: no cover + raise ImportError("vscode") + except: + return False + else: # pragma: no cover + return True \ No newline at end of file