@@ -57,7 +57,7 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog | |||
class RichCallback(ProgressCallback): | |||
""" | |||
在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||
:param print_every: 多少个 batch 更新一次显示。 | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
@@ -144,8 +144,10 @@ class RichCallback(ProgressCallback): | |||
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | |||
f"Batch:{trainer.batch_idx_in_epoch}", | |||
style=rule_style, characters=characters) | |||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||
not key.startswith('_')} | |||
if self.format_json: | |||
self.progress_bar.console.print_json(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||
self.progress_bar.console.print_json(json.dumps(results)) | |||
else: | |||
self.progress_bar.print(results) | |||
@@ -165,7 +167,7 @@ class RawTextCallback(ProgressCallback): | |||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||
format_json=True): | |||
""" | |||
通过向命令行打印进度的方式显示 | |||
通过向命令行打印进度的方式显示。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||
:param print_every: 多少个 batch 更新一次显示。 | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
@@ -222,8 +224,10 @@ class RawTextCallback(ProgressCallback): | |||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | |||
logger.info(text) | |||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||
not key.startswith('_')} | |||
if self.format_json: | |||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||
logger.info(json.dumps(results)) | |||
else: | |||
logger.info(results) | |||
@@ -235,7 +239,7 @@ class RawTextCallback(ProgressCallback): | |||
class TqdmCallback(ProgressCallback): | |||
""" | |||
在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||
:param print_every: 多少个 batch 更新一次显示。 | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
@@ -309,8 +313,10 @@ class TqdmCallback(ProgressCallback): | |||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | |||
logger.info(text) | |||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||
not key.startswith('_')} | |||
if self.format_json: | |||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||
logger.info(json.dumps(results)) | |||
else: | |||
logger.info(results) | |||
@@ -630,7 +630,7 @@ def is_notebook(): | |||
def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: | |||
""" | |||
讲一个 nested 的 dict 转成 flat 的 dict,例如 | |||
将一个 nested 的 dict 转成 flat 的 dict,例如 | |||
ex:: | |||
d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} | |||
@@ -245,8 +245,9 @@ class DataBundle: | |||
""" | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if len(_progress_desc) == 0: | |||
_progress_desc = 'Processing' | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if dataset.has_field(field_name=field_name): | |||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | |||
progress_desc=progress_desc, progress_bar=progress_bar) | |||
@@ -284,8 +285,9 @@ class DataBundle: | |||
res = {} | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if len(_progress_desc) == 0: | |||
_progress_desc = 'Processing' | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if dataset.has_field(field_name=field_name): | |||
res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | |||
modify_fields=modify_fields, | |||
@@ -317,8 +319,9 @@ class DataBundle: | |||
""" | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if len(_progress_desc) == 0: | |||
_progress_desc = 'Processing' | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, | |||
progress_desc=progress_desc) | |||
return self | |||
@@ -349,8 +352,9 @@ class DataBundle: | |||
res = {} | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if len(_progress_desc) == 0: | |||
_progress_desc = 'Processing' | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | |||
progress_bar=progress_bar, progress_desc=progress_desc) | |||
return res | |||