@@ -57,7 +57,7 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog | |||||
class RichCallback(ProgressCallback): | class RichCallback(ProgressCallback): | ||||
""" | """ | ||||
在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | 在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | ||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
@@ -144,8 +144,10 @@ class RichCallback(ProgressCallback): | |||||
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " | ||||
f"Batch:{trainer.batch_idx_in_epoch}", | f"Batch:{trainer.batch_idx_in_epoch}", | ||||
style=rule_style, characters=characters) | style=rule_style, characters=characters) | ||||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||||
not key.startswith('_')} | |||||
if self.format_json: | if self.format_json: | ||||
self.progress_bar.console.print_json(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
self.progress_bar.console.print_json(json.dumps(results)) | |||||
else: | else: | ||||
self.progress_bar.print(results) | self.progress_bar.print(results) | ||||
@@ -165,7 +167,7 @@ class RawTextCallback(ProgressCallback): | |||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | ||||
format_json=True): | format_json=True): | ||||
""" | """ | ||||
通过向命令行打印进度的方式显示 | |||||
通过向命令行打印进度的方式显示。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
@@ -222,8 +224,10 @@ class RawTextCallback(ProgressCallback): | |||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | text = '-'*self.num_signs + base_text + '-'*self.num_signs | ||||
logger.info(text) | logger.info(text) | ||||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||||
not key.startswith('_')} | |||||
if self.format_json: | if self.format_json: | ||||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
logger.info(json.dumps(results)) | |||||
else: | else: | ||||
logger.info(results) | logger.info(results) | ||||
@@ -235,7 +239,7 @@ class RawTextCallback(ProgressCallback): | |||||
class TqdmCallback(ProgressCallback): | class TqdmCallback(ProgressCallback): | ||||
""" | """ | ||||
在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | 在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | ||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
@@ -309,8 +313,10 @@ class TqdmCallback(ProgressCallback): | |||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | text = '-'*self.num_signs + base_text + '-'*self.num_signs | ||||
logger.info(text) | logger.info(text) | ||||
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if | |||||
not key.startswith('_')} | |||||
if self.format_json: | if self.format_json: | ||||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
logger.info(json.dumps(results)) | |||||
else: | else: | ||||
logger.info(results) | logger.info(results) | ||||
@@ -630,7 +630,7 @@ def is_notebook(): | |||||
def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: | def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: | ||||
""" | """ | ||||
讲一个 nested 的 dict 转成 flat 的 dict,例如 | |||||
将一个 nested 的 dict 转成 flat 的 dict,例如 | |||||
ex:: | ex:: | ||||
d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} | d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} | ||||
@@ -245,8 +245,9 @@ class DataBundle: | |||||
""" | """ | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if len(_progress_desc) == 0: | |||||
_progress_desc = 'Processing' | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | ||||
progress_desc=progress_desc, progress_bar=progress_bar) | progress_desc=progress_desc, progress_bar=progress_bar) | ||||
@@ -284,8 +285,9 @@ class DataBundle: | |||||
res = {} | res = {} | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if len(_progress_desc) == 0: | |||||
_progress_desc = 'Processing' | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | ||||
modify_fields=modify_fields, | modify_fields=modify_fields, | ||||
@@ -317,8 +319,9 @@ class DataBundle: | |||||
""" | """ | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if len(_progress_desc) == 0: | |||||
_progress_desc = 'Processing' | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, | dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, | ||||
progress_desc=progress_desc) | progress_desc=progress_desc) | ||||
return self | return self | ||||
@@ -349,8 +352,9 @@ class DataBundle: | |||||
res = {} | res = {} | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
if len(_progress_desc) == 0: | |||||
_progress_desc = 'Processing' | |||||
progress_desc = _progress_desc + f' for `{name}`' | |||||
res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | ||||
progress_bar=progress_bar, progress_desc=progress_desc) | progress_bar=progress_bar, progress_desc=progress_desc) | ||||
return res | return res | ||||