@@ -8,7 +8,6 @@ __all__ = [ | |||
from .callback_events import Events | |||
from .callback import Callback | |||
from .progress_callback import ProgressCallback, choose_progress_callback | |||
from fastNLP.core.log import logger | |||
@@ -35,7 +34,7 @@ class CallbackManager: | |||
class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; | |||
callback_fns: dict | |||
def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'): | |||
def __init__(self, callbacks: Optional[List[Callback]]): | |||
r""" | |||
注意 callback 的调用顺序: | |||
1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; | |||
@@ -46,7 +45,6 @@ class CallbackManager: | |||
""" | |||
self._need_reproducible_sampler = False | |||
_has_progress_callback = False | |||
_callbacks = [] | |||
if callbacks is not None: | |||
if isinstance(callbacks, Callback): | |||
@@ -57,16 +55,7 @@ class CallbackManager: | |||
for _callback in callbacks: | |||
if not isinstance(_callback, Callback): | |||
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | |||
if isinstance(_callback, ProgressCallback): | |||
_has_progress_callback = True | |||
_callbacks += callbacks | |||
if not _has_progress_callback: | |||
# 添加 progress callback | |||
progress_callback = choose_progress_callback(progress_bar=progress_bar) | |||
if progress_callback is None: | |||
logger.info("There is no progress bar, Trainer will not output training progress.") | |||
else: | |||
_callbacks.append(progress_callback) | |||
self.callback_fns = defaultdict(list) | |||
# 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 | |||
# 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; | |||
@@ -1,6 +1,6 @@ | |||
import json | |||
import sys | |||
from typing import Union | |||
__all__ = [ | |||
'choose_progress_callback', | |||
@@ -11,11 +11,22 @@ __all__ = [ | |||
from .has_monitor_callback import HasMonitorCallback | |||
from fastNLP.core.utils import f_rich_progress | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.utils.utils import is_notebook | |||
class ProgressCallback(HasMonitorCallback): | |||
def on_train_end(self, trainer): | |||
f_rich_progress.stop() | |||
@property | |||
def name(self): # progress bar的名称 | |||
return 'auto' | |||
def choose_progress_callback(progress_bar:str): | |||
def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: | |||
if progress_bar == 'auto': | |||
if (sys.stdin and sys.stdin.isatty()): | |||
if not f_rich_progress.dummy_rich: | |||
progress_bar = 'rich' | |||
else: | |||
progress_bar = 'raw' | |||
@@ -23,15 +34,12 @@ def choose_progress_callback(progress_bar:str): | |||
return RichCallback() | |||
elif progress_bar == 'raw': | |||
return RawTextCallback() | |||
elif isinstance(progress_bar, ProgressCallback): | |||
return progress_bar | |||
else: | |||
return None | |||
class ProgressCallback(HasMonitorCallback): | |||
def on_train_end(self, trainer): | |||
f_rich_progress.stop() | |||
class RichCallback(ProgressCallback): | |||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||
format_json=True): | |||
@@ -124,6 +132,10 @@ class RichCallback(ProgressCallback): | |||
self.task2id = {} | |||
self.loss = 0 | |||
@property | |||
def name(self): # progress bar的名称 | |||
return 'rich' | |||
class RawTextCallback(ProgressCallback): | |||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||
@@ -184,3 +196,7 @@ class RawTextCallback(ProgressCallback): | |||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||
else: | |||
logger.info(results) | |||
@property | |||
def name(self): # progress bar的名称 | |||
return 'raw' |
@@ -134,7 +134,7 @@ class Evaluator: | |||
self.progress_bar = kwargs.get('progress_bar', 'auto') | |||
if self.progress_bar == 'auto': | |||
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | |||
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich' | |||
self.driver.barrier() | |||
@@ -20,6 +20,7 @@ from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _Truncated | |||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | |||
from fastNLP.core.callbacks.callback import _CallbackWrapper | |||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||
from fastNLP.core.callbacks.progress_callback import choose_progress_callback | |||
from fastNLP.core.drivers import Driver | |||
from fastNLP.core.drivers.utils import choose_driver | |||
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | |||
@@ -125,14 +126,13 @@ class Trainer(TrainerEventTrigger): | |||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | |||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | |||
eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'],默认为 auto 。progress 的实现是通过 | |||
callback 实现的,若在输入的 callback 中检测到了 ProgressCallback 类型的 callback ,则该参数对 Trainer 无效。 | |||
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。 | |||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | |||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 | |||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | |||
""" | |||
self.model = model | |||
self.marker = marker | |||
@@ -195,8 +195,20 @@ class Trainer(TrainerEventTrigger): | |||
) | |||
self.driver.set_optimizers(optimizers=optimizers) | |||
# 根据 progress_bar 参数选择 ProgressBarCallback | |||
progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto')) | |||
if progress_bar_callback is not None: | |||
if callbacks is None: | |||
callbacks = [] | |||
elif not isinstance(callbacks, Sequence): | |||
callbacks = [callbacks] | |||
callbacks = list(callbacks) + [progress_bar_callback] | |||
else: | |||
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " | |||
"during training.") | |||
# 初始化 callback manager; | |||
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) | |||
self.callback_manager = CallbackManager(callbacks) | |||
# 添加所有的函数式 callbacks; | |||
self._fetch_matched_fn_callbacks() | |||
# 添加所有的类 callbacks; | |||
@@ -237,6 +249,9 @@ class Trainer(TrainerEventTrigger): | |||
self.larger_better = larger_better | |||
if metrics is not None and evaluate_dataloaders is not None: | |||
check_evaluate_every(evaluate_every) | |||
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 | |||
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 | |||
progress_bar = progress_bar.name | |||
self.evaluator = Evaluator( | |||
model=model, | |||
dataloaders=evaluate_dataloaders, | |||
@@ -249,8 +264,8 @@ class Trainer(TrainerEventTrigger): | |||
output_mapping=output_mapping, | |||
fp16=fp16, | |||
verbose=0, | |||
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), | |||
progress_bar=kwargs.get('progress_bar', 'auto') | |||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), | |||
progress_bar=progress_bar | |||
) | |||
if train_fn is not None and not isinstance(train_fn, str): | |||
@@ -14,6 +14,7 @@ __all__ = [ | |||
] | |||
from fastNLP.envs import get_global_rank | |||
from .utils import is_notebook | |||
class Singleton(type): | |||
@@ -34,6 +35,14 @@ class DummyFRichProgress: | |||
# 防止用户通过 DummyFRichProgress.console.print() 这种调用 | |||
return None | |||
@property | |||
def dummy_rich(self)->bool: | |||
""" | |||
当前对象是否是 dummy 的 rich 对象。 | |||
:return: | |||
""" | |||
return True | |||
class FRichProgress(Progress, metaclass=Singleton): | |||
""" | |||
@@ -147,6 +156,8 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
super().stop_task(task_id) | |||
super().remove_task(task_id) | |||
self.refresh() # 使得bar不残留 | |||
if len(self._tasks) == 0: | |||
super().stop() | |||
def start(self) -> None: | |||
super().start() | |||
@@ -210,6 +221,15 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
if refresh: | |||
self.refresh() | |||
@property | |||
def dummy_rich(self) -> bool: | |||
""" | |||
当前对象是否是 dummy 的 rich 对象。 | |||
:return: | |||
""" | |||
return False | |||
class SpeedColumn(ProgressColumn): | |||
""" | |||
@@ -226,7 +246,8 @@ class SpeedColumn(ProgressColumn): | |||
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') | |||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \ | |||
get_global_rank() == 0: | |||
f_rich_progress = FRichProgress().new_progess( | |||
"[progress.description]{task.description}", | |||
"[progress.percentage]{task.percentage:>3.0f}%", | |||
@@ -696,4 +696,23 @@ def get_class_that_defined_method(method): | |||
None) | |||
if isinstance(cls, type): | |||
return cls | |||
return getattr(method, '__objclass__', None) # handle special descriptor objects | |||
return getattr(method, '__objclass__', None) # handle special descriptor objects | |||
def is_notebook(): | |||
""" | |||
检查当前运行环境是否为 jupyter | |||
:return: | |||
""" | |||
try: | |||
from IPython import get_ipython | |||
if "IPKernelApp" not in get_ipython().config: # pragma: no cover | |||
raise ImportError("console") | |||
if "VSCODE_PID" in os.environ: # pragma: no cover | |||
raise ImportError("vscode") | |||
except: | |||
return False | |||
else: # pragma: no cover | |||
return True |