Browse Source

Rich支持jupyter

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
e85cbb067e
6 changed files with 91 additions and 31 deletions
  1. +1
    -12
      fastNLP/core/callbacks/callback_manager.py
  2. +24
    -8
      fastNLP/core/callbacks/progress_callback.py
  3. +1
    -1
      fastNLP/core/controllers/evaluator.py
  4. +23
    -8
      fastNLP/core/controllers/trainer.py
  5. +22
    -1
      fastNLP/core/utils/rich_progress.py
  6. +20
    -1
      fastNLP/core/utils/utils.py

+ 1
- 12
fastNLP/core/callbacks/callback_manager.py View File

@@ -8,7 +8,6 @@ __all__ = [


from .callback_events import Events from .callback_events import Events
from .callback import Callback from .callback import Callback
from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.core.log import logger from fastNLP.core.log import logger




@@ -35,7 +34,7 @@ class CallbackManager:
class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback;
callback_fns: dict callback_fns: dict


def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'):
def __init__(self, callbacks: Optional[List[Callback]]):
r""" r"""
注意 callback 的调用顺序: 注意 callback 的调用顺序:
1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; 1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数;
@@ -46,7 +45,6 @@ class CallbackManager:
""" """
self._need_reproducible_sampler = False self._need_reproducible_sampler = False


_has_progress_callback = False
_callbacks = [] _callbacks = []
if callbacks is not None: if callbacks is not None:
if isinstance(callbacks, Callback): if isinstance(callbacks, Callback):
@@ -57,16 +55,7 @@ class CallbackManager:
for _callback in callbacks: for _callback in callbacks:
if not isinstance(_callback, Callback): if not isinstance(_callback, Callback):
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
if isinstance(_callback, ProgressCallback):
_has_progress_callback = True
_callbacks += callbacks _callbacks += callbacks
if not _has_progress_callback:
# 添加 progress callback
progress_callback = choose_progress_callback(progress_bar=progress_bar)
if progress_callback is None:
logger.info("There is no progress bar, Trainer will not output training progress.")
else:
_callbacks.append(progress_callback)
self.callback_fns = defaultdict(list) self.callback_fns = defaultdict(list)
# 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法
# 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可;


+ 24
- 8
fastNLP/core/callbacks/progress_callback.py View File

@@ -1,6 +1,6 @@
import json import json
import sys import sys
from typing import Union


__all__ = [ __all__ = [
'choose_progress_callback', 'choose_progress_callback',
@@ -11,11 +11,22 @@ __all__ = [
from .has_monitor_callback import HasMonitorCallback from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.utils import f_rich_progress from fastNLP.core.utils import f_rich_progress
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.core.utils.utils import is_notebook



class ProgressCallback(HasMonitorCallback):
    # Base class of the progress callbacks defined in this module
    # (RichCallback and RawTextCallback both derive from it).
    def on_train_end(self, trainer):
        # Stop the module-level f_rich_progress object once training has ended.
        f_rich_progress.stop()

    @property
    def name(self): # name of the progress bar
        return 'auto'




def choose_progress_callback(progress_bar:str):
def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback:
if progress_bar == 'auto': if progress_bar == 'auto':
if (sys.stdin and sys.stdin.isatty()):
if not f_rich_progress.dummy_rich:
progress_bar = 'rich' progress_bar = 'rich'
else: else:
progress_bar = 'raw' progress_bar = 'raw'
@@ -23,15 +34,12 @@ def choose_progress_callback(progress_bar:str):
return RichCallback() return RichCallback()
elif progress_bar == 'raw': elif progress_bar == 'raw':
return RawTextCallback() return RawTextCallback()
elif isinstance(progress_bar, ProgressCallback):
return progress_bar
else: else:
return None return None




class ProgressCallback(HasMonitorCallback):
def on_train_end(self, trainer):
f_rich_progress.stop()


class RichCallback(ProgressCallback): class RichCallback(ProgressCallback):
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
format_json=True): format_json=True):
@@ -124,6 +132,10 @@ class RichCallback(ProgressCallback):
self.task2id = {} self.task2id = {}
self.loss = 0 self.loss = 0


@property
def name(self): # name of the progress bar
    return 'rich'



class RawTextCallback(ProgressCallback): class RawTextCallback(ProgressCallback):
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
@@ -184,3 +196,7 @@ class RawTextCallback(ProgressCallback):
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) logger.info(json.dumps(trainer.driver.tensor_to_numeric(results)))
else: else:
logger.info(results) logger.info(results)

@property
def name(self): # name of the progress bar
    return 'raw'

+ 1
- 1
fastNLP/core/controllers/evaluator.py View File

@@ -134,7 +134,7 @@ class Evaluator:


self.progress_bar = kwargs.get('progress_bar', 'auto') self.progress_bar = kwargs.get('progress_bar', 'auto')
if self.progress_bar == 'auto': if self.progress_bar == 'auto':
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich'


self.driver.barrier() self.driver.barrier()




+ 23
- 8
fastNLP/core/controllers/trainer.py View File

@@ -20,6 +20,7 @@ from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _Truncated
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_events import _SingleEventState from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.callbacks.progress_callback import choose_progress_callback
from fastNLP.core.drivers import Driver from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
@@ -125,14 +126,13 @@ class Trainer(TrainerEventTrigger):
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'],默认为 auto 。progress 的实现是通过
callback 实现的,若在输入的 callback 中检测到了 ProgressCallback 类型的 callback ,则该参数对 Trainer 无效。
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。

progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
""" """
self.model = model self.model = model
self.marker = marker self.marker = marker
@@ -195,8 +195,20 @@ class Trainer(TrainerEventTrigger):
) )
self.driver.set_optimizers(optimizers=optimizers) self.driver.set_optimizers(optimizers=optimizers)


# 根据 progress_bar 参数选择 ProgressBarCallback
progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto'))
if progress_bar_callback is not None:
if callbacks is None:
callbacks = []
elif not isinstance(callbacks, Sequence):
callbacks = [callbacks]

callbacks = list(callbacks) + [progress_bar_callback]
else:
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
"during training.")
# 初始化 callback manager; # 初始化 callback manager;
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto'))
self.callback_manager = CallbackManager(callbacks)
# 添加所有的函数式 callbacks; # 添加所有的函数式 callbacks;
self._fetch_matched_fn_callbacks() self._fetch_matched_fn_callbacks()
# 添加所有的类 callbacks; # 添加所有的类 callbacks;
@@ -237,6 +249,9 @@ class Trainer(TrainerEventTrigger):
self.larger_better = larger_better self.larger_better = larger_better
if metrics is not None and evaluate_dataloaders is not None: if metrics is not None and evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every) check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name
self.evaluator = Evaluator( self.evaluator = Evaluator(
model=model, model=model,
dataloaders=evaluate_dataloaders, dataloaders=evaluate_dataloaders,
@@ -249,8 +264,8 @@ class Trainer(TrainerEventTrigger):
output_mapping=output_mapping, output_mapping=output_mapping,
fp16=fp16, fp16=fp16,
verbose=0, verbose=0,
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None),
progress_bar=kwargs.get('progress_bar', 'auto')
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
progress_bar=progress_bar
) )


if train_fn is not None and not isinstance(train_fn, str): if train_fn is not None and not isinstance(train_fn, str):


+ 22
- 1
fastNLP/core/utils/rich_progress.py View File

@@ -14,6 +14,7 @@ __all__ = [
] ]


from fastNLP.envs import get_global_rank from fastNLP.envs import get_global_rank
from .utils import is_notebook




class Singleton(type): class Singleton(type):
@@ -34,6 +35,14 @@ class DummyFRichProgress:
# 防止用户通过 DummyFRichProgress.console.print() 这种调用 # 防止用户通过 DummyFRichProgress.console.print() 这种调用
return None return None


@property
def dummy_rich(self)->bool:
    """
    Whether the current object is a dummy rich progress object.

    :return: always ``True`` for this class.
    """
    return True


class FRichProgress(Progress, metaclass=Singleton): class FRichProgress(Progress, metaclass=Singleton):
""" """
@@ -147,6 +156,8 @@ class FRichProgress(Progress, metaclass=Singleton):
super().stop_task(task_id) super().stop_task(task_id)
super().remove_task(task_id) super().remove_task(task_id)
self.refresh() # 使得bar不残留 self.refresh() # 使得bar不残留
if len(self._tasks) == 0:
super().stop()


def start(self) -> None: def start(self) -> None:
super().start() super().start()
@@ -210,6 +221,15 @@ class FRichProgress(Progress, metaclass=Singleton):
if refresh: if refresh:
self.refresh() self.refresh()


@property
def dummy_rich(self) -> bool:
    """
    Whether the current object is a dummy rich progress object.

    :return: always ``False`` for this class.
    """
    return False



class SpeedColumn(ProgressColumn): class SpeedColumn(ProgressColumn):
""" """
@@ -226,7 +246,8 @@ class SpeedColumn(ProgressColumn):
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed')




if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \
get_global_rank() == 0:
f_rich_progress = FRichProgress().new_progess( f_rich_progress = FRichProgress().new_progess(
"[progress.description]{task.description}", "[progress.description]{task.description}",
"[progress.percentage]{task.percentage:>3.0f}%", "[progress.percentage]{task.percentage:>3.0f}%",


+ 20
- 1
fastNLP/core/utils/utils.py View File

@@ -696,4 +696,23 @@ def get_class_that_defined_method(method):
None) None)
if isinstance(cls, type): if isinstance(cls, type):
return cls return cls
return getattr(method, '__objclass__', None) # handle special descriptor objects
return getattr(method, '__objclass__', None) # handle special descriptor objects


def is_notebook():
    """
    Check whether the current process is running inside a Jupyter notebook kernel.

    :return: ``True`` only when IPython can be imported, the active IPython shell's
        config contains an ``IPKernelApp`` entry, and the ``VSCODE_PID`` environment
        variable is not set; ``False`` in every other case, including any ordinary
        error raised while probing.
    """
    # Imported locally so the check never depends on a module-level ``import os``.
    import os

    try:
        from IPython import get_ipython

        shell = get_ipython()
        if shell is None:
            # No interactive IPython shell is active in this process.
            return False
        if "IPKernelApp" not in shell.config:  # pragma: no cover
            # An IPython shell without a kernel app is treated as a plain console.
            return False
        if "VSCODE_PID" in os.environ:  # pragma: no cover
            # Sessions with VSCODE_PID set are deliberately not treated as notebooks.
            return False
    except Exception:
        # Missing IPython or an unexpected shell object means "not a notebook".
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        return False
    return True  # pragma: no cover

Loading…
Cancel
Save