@@ -8,7 +8,6 @@ __all__ = [ | |||||
from .callback_events import Events | from .callback_events import Events | ||||
from .callback import Callback | from .callback import Callback | ||||
from .progress_callback import ProgressCallback, choose_progress_callback | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -35,7 +34,7 @@ class CallbackManager: | |||||
class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; | class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; | ||||
callback_fns: dict | callback_fns: dict | ||||
def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'): | |||||
def __init__(self, callbacks: Optional[List[Callback]]): | |||||
r""" | r""" | ||||
注意 callback 的调用顺序: | 注意 callback 的调用顺序: | ||||
1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; | 1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; | ||||
@@ -46,7 +45,6 @@ class CallbackManager: | |||||
""" | """ | ||||
self._need_reproducible_sampler = False | self._need_reproducible_sampler = False | ||||
_has_progress_callback = False | |||||
_callbacks = [] | _callbacks = [] | ||||
if callbacks is not None: | if callbacks is not None: | ||||
if isinstance(callbacks, Callback): | if isinstance(callbacks, Callback): | ||||
@@ -57,16 +55,7 @@ class CallbackManager: | |||||
for _callback in callbacks: | for _callback in callbacks: | ||||
if not isinstance(_callback, Callback): | if not isinstance(_callback, Callback): | ||||
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | ||||
if isinstance(_callback, ProgressCallback): | |||||
_has_progress_callback = True | |||||
_callbacks += callbacks | _callbacks += callbacks | ||||
if not _has_progress_callback: | |||||
# 添加 progress callback | |||||
progress_callback = choose_progress_callback(progress_bar=progress_bar) | |||||
if progress_callback is None: | |||||
logger.info("There is no progress bar, Trainer will not output training progress.") | |||||
else: | |||||
_callbacks.append(progress_callback) | |||||
self.callback_fns = defaultdict(list) | self.callback_fns = defaultdict(list) | ||||
# 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 | # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 | ||||
# 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; | # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; | ||||
@@ -1,6 +1,6 @@ | |||||
import json | import json | ||||
import sys | import sys | ||||
from typing import Union | |||||
__all__ = [ | __all__ = [ | ||||
'choose_progress_callback', | 'choose_progress_callback', | ||||
@@ -11,11 +11,22 @@ __all__ = [ | |||||
from .has_monitor_callback import HasMonitorCallback | from .has_monitor_callback import HasMonitorCallback | ||||
from fastNLP.core.utils import f_rich_progress | from fastNLP.core.utils import f_rich_progress | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils.utils import is_notebook | |||||
class ProgressCallback(HasMonitorCallback):
    """Common base class for the progress-reporting callbacks."""

    @property
    def name(self):
        """Name of the progress bar; the base class reports ``'auto'``."""
        return 'auto'

    def on_train_end(self, trainer):
        """Stop the module-level ``f_rich_progress`` object once training ends."""
        f_rich_progress.stop()
def choose_progress_callback(progress_bar:str): | |||||
def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: | |||||
if progress_bar == 'auto': | if progress_bar == 'auto': | ||||
if (sys.stdin and sys.stdin.isatty()): | |||||
if not f_rich_progress.dummy_rich: | |||||
progress_bar = 'rich' | progress_bar = 'rich' | ||||
else: | else: | ||||
progress_bar = 'raw' | progress_bar = 'raw' | ||||
@@ -23,15 +34,12 @@ def choose_progress_callback(progress_bar:str): | |||||
return RichCallback() | return RichCallback() | ||||
elif progress_bar == 'raw': | elif progress_bar == 'raw': | ||||
return RawTextCallback() | return RawTextCallback() | ||||
elif isinstance(progress_bar, ProgressCallback): | |||||
return progress_bar | |||||
else: | else: | ||||
return None | return None | ||||
class ProgressCallback(HasMonitorCallback): | |||||
def on_train_end(self, trainer): | |||||
f_rich_progress.stop() | |||||
class RichCallback(ProgressCallback): | class RichCallback(ProgressCallback): | ||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | ||||
format_json=True): | format_json=True): | ||||
@@ -124,6 +132,10 @@ class RichCallback(ProgressCallback): | |||||
self.task2id = {} | self.task2id = {} | ||||
self.loss = 0 | self.loss = 0 | ||||
@property
def name(self):
    """Name of this progress bar: ``'rich'``."""
    return 'rich'
class RawTextCallback(ProgressCallback): | class RawTextCallback(ProgressCallback): | ||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | ||||
@@ -184,3 +196,7 @@ class RawTextCallback(ProgressCallback): | |||||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | ||||
else: | else: | ||||
logger.info(results) | logger.info(results) | ||||
@property
def name(self):
    """Name of this progress bar: ``'raw'``."""
    return 'raw'
@@ -134,7 +134,7 @@ class Evaluator: | |||||
self.progress_bar = kwargs.get('progress_bar', 'auto') | self.progress_bar = kwargs.get('progress_bar', 'auto') | ||||
if self.progress_bar == 'auto': | if self.progress_bar == 'auto': | ||||
self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' | |||||
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich' | |||||
self.driver.barrier() | self.driver.barrier() | ||||
@@ -20,6 +20,7 @@ from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _Truncated | |||||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | ||||
from fastNLP.core.callbacks.callback import _CallbackWrapper | from fastNLP.core.callbacks.callback import _CallbackWrapper | ||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | from fastNLP.core.callbacks.callback_events import _SingleEventState | ||||
from fastNLP.core.callbacks.progress_callback import choose_progress_callback | |||||
from fastNLP.core.drivers import Driver | from fastNLP.core.drivers import Driver | ||||
from fastNLP.core.drivers.utils import choose_driver | from fastNLP.core.drivers.utils import choose_driver | ||||
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | ||||
@@ -125,14 +126,13 @@ class Trainer(TrainerEventTrigger): | |||||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | ||||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | ||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | ||||
eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | ||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | ||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'],默认为 auto 。progress 的实现是通过 | |||||
callback 实现的,若在输入的 callback 中检测到了 ProgressCallback 类型的 callback ,则该参数对 Trainer 无效。 | |||||
auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。 | |||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | |||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 | |||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | |||||
""" | """ | ||||
self.model = model | self.model = model | ||||
self.marker = marker | self.marker = marker | ||||
@@ -195,8 +195,20 @@ class Trainer(TrainerEventTrigger): | |||||
) | ) | ||||
self.driver.set_optimizers(optimizers=optimizers) | self.driver.set_optimizers(optimizers=optimizers) | ||||
# 根据 progress_bar 参数选择 ProgressBarCallback | |||||
progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto')) | |||||
if progress_bar_callback is not None: | |||||
if callbacks is None: | |||||
callbacks = [] | |||||
elif not isinstance(callbacks, Sequence): | |||||
callbacks = [callbacks] | |||||
callbacks = list(callbacks) + [progress_bar_callback] | |||||
else: | |||||
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " | |||||
"during training.") | |||||
# 初始化 callback manager; | # 初始化 callback manager; | ||||
self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) | |||||
self.callback_manager = CallbackManager(callbacks) | |||||
# 添加所有的函数式 callbacks; | # 添加所有的函数式 callbacks; | ||||
self._fetch_matched_fn_callbacks() | self._fetch_matched_fn_callbacks() | ||||
# 添加所有的类 callbacks; | # 添加所有的类 callbacks; | ||||
@@ -237,6 +249,9 @@ class Trainer(TrainerEventTrigger): | |||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
if metrics is not None and evaluate_dataloaders is not None: | if metrics is not None and evaluate_dataloaders is not None: | ||||
check_evaluate_every(evaluate_every) | check_evaluate_every(evaluate_every) | ||||
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 | |||||
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 | |||||
progress_bar = progress_bar.name | |||||
self.evaluator = Evaluator( | self.evaluator = Evaluator( | ||||
model=model, | model=model, | ||||
dataloaders=evaluate_dataloaders, | dataloaders=evaluate_dataloaders, | ||||
@@ -249,8 +264,8 @@ class Trainer(TrainerEventTrigger): | |||||
output_mapping=output_mapping, | output_mapping=output_mapping, | ||||
fp16=fp16, | fp16=fp16, | ||||
verbose=0, | verbose=0, | ||||
use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), | |||||
progress_bar=kwargs.get('progress_bar', 'auto') | |||||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), | |||||
progress_bar=progress_bar | |||||
) | ) | ||||
if train_fn is not None and not isinstance(train_fn, str): | if train_fn is not None and not isinstance(train_fn, str): | ||||
@@ -14,6 +14,7 @@ __all__ = [ | |||||
] | ] | ||||
from fastNLP.envs import get_global_rank | from fastNLP.envs import get_global_rank | ||||
from .utils import is_notebook | |||||
class Singleton(type): | class Singleton(type): | ||||
@@ -34,6 +35,14 @@ class DummyFRichProgress: | |||||
# 防止用户通过 DummyFRichProgress.console.print() 这种调用 | # 防止用户通过 DummyFRichProgress.console.print() 这种调用 | ||||
return None | return None | ||||
@property
def dummy_rich(self) -> bool:
    """
    Whether this object is the dummy stand-in for the rich progress object.

    :return: always ``True`` for this class.
    """
    return True
class FRichProgress(Progress, metaclass=Singleton): | class FRichProgress(Progress, metaclass=Singleton): | ||||
""" | """ | ||||
@@ -147,6 +156,8 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
super().stop_task(task_id) | super().stop_task(task_id) | ||||
super().remove_task(task_id) | super().remove_task(task_id) | ||||
self.refresh() # 使得bar不残留 | self.refresh() # 使得bar不残留 | ||||
if len(self._tasks) == 0: | |||||
super().stop() | |||||
def start(self) -> None: | def start(self) -> None: | ||||
super().start() | super().start() | ||||
@@ -210,6 +221,15 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
if refresh: | if refresh: | ||||
self.refresh() | self.refresh() | ||||
@property
def dummy_rich(self) -> bool:
    """
    Whether this object is the dummy stand-in for the rich progress object.

    :return: always ``False`` for this class.
    """
    return False
class SpeedColumn(ProgressColumn): | class SpeedColumn(ProgressColumn): | ||||
""" | """ | ||||
@@ -226,7 +246,8 @@ class SpeedColumn(ProgressColumn): | |||||
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') | return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') | ||||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||||
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \ | |||||
get_global_rank() == 0: | |||||
f_rich_progress = FRichProgress().new_progess( | f_rich_progress = FRichProgress().new_progess( | ||||
"[progress.description]{task.description}", | "[progress.description]{task.description}", | ||||
"[progress.percentage]{task.percentage:>3.0f}%", | "[progress.percentage]{task.percentage:>3.0f}%", | ||||
@@ -696,4 +696,23 @@ def get_class_that_defined_method(method): | |||||
None) | None) | ||||
if isinstance(cls, type): | if isinstance(cls, type): | ||||
return cls | return cls | ||||
return getattr(method, '__objclass__', None) # handle special descriptor objects | |||||
return getattr(method, '__objclass__', None) # handle special descriptor objects | |||||
def is_notebook():
    """
    Check whether the current process is running inside a Jupyter notebook.

    :return: ``True`` only when IPython can be imported, an IPython shell is
        active, its config contains ``IPKernelApp`` and the ``VSCODE_PID``
        environment variable is not set; ``False`` in every other case.
    """
    # Imported locally so the check does not rely on a module-level ``import os``;
    # previously a missing import would have been silently turned into ``False``.
    import os

    try:
        from IPython import get_ipython
    except Exception:
        # This probe must stay best-effort, so any failure to import IPython
        # means "not a notebook". ``Exception`` (rather than a bare ``except``)
        # lets KeyboardInterrupt / SystemExit propagate.
        return False

    shell = get_ipython()
    if shell is None:
        # IPython is installed, but no interactive shell is running.
        return False

    config = getattr(shell, "config", None)
    if not config or "IPKernelApp" not in config:  # pragma: no cover
        # An IPython shell that is not backed by a kernel (e.g. a console).
        return False
    if "VSCODE_PID" in os.environ:  # pragma: no cover
        # NOTE(review): presumably excludes VS Code's notebook frontend -- confirm.
        return False
    return True  # pragma: no cover