diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 2fa62c87..890864ec 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -57,7 +57,7 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog class RichCallback(ProgressCallback): """ 在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 - 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 + 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 @@ -144,8 +144,10 @@ class RichCallback(ProgressCallback): self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, " f"Batch:{trainer.batch_idx_in_epoch}", style=rule_style, characters=characters) + results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if + not key.startswith('_')} if self.format_json: - self.progress_bar.console.print_json(json.dumps(trainer.driver.tensor_to_numeric(results))) + self.progress_bar.console.print_json(json.dumps(results)) else: self.progress_bar.print(results) @@ -165,7 +167,7 @@ class RawTextCallback(ProgressCallback): def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, format_json=True): """ - 通过向命令行打印进度的方式显示 + 通过向命令行打印进度的方式显示。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 @@ -222,8 +224,10 @@ class RawTextCallback(ProgressCallback): text = '-'*self.num_signs + base_text + '-'*self.num_signs logger.info(text) + results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if + not key.startswith('_')} if self.format_json: - logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) + logger.info(json.dumps(results)) else: logger.info(results) @@ -235,7 +239,7 @@ class RawTextCallback(ProgressCallback): class TqdmCallback(ProgressCallback): """ 在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 - 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 + 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。 :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 @@ -309,8 +313,10 @@ class TqdmCallback(ProgressCallback): text = '-'*self.num_signs + base_text + '-'*self.num_signs logger.info(text) + results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if + not key.startswith('_')} if self.format_json: - logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) + logger.info(json.dumps(results)) else: logger.info(results) diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 11256d45..33a7ee7e 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -630,7 +630,7 @@ def is_notebook(): def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict: """ - 讲一个 nested 的 dict 转成 flat 的 dict,例如 + 将一个 nested 的 dict 转成 flat 的 dict,例如 ex:: d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1} diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 58538d61..4029e092 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -245,8 +245,9 @@ class DataBundle: """ _progress_desc = progress_desc for name, dataset in self.datasets.items(): - if _progress_desc: - progress_desc = _progress_desc + f' for `{name}`' + if len(_progress_desc) == 0: + _progress_desc = 'Processing' + progress_desc = _progress_desc + f' for `{name}`' if dataset.has_field(field_name=field_name): dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, progress_desc=progress_desc, progress_bar=progress_bar) @@ -284,8 +285,9 @@ class DataBundle: res = {} _progress_desc = progress_desc for name, dataset in self.datasets.items(): - if _progress_desc: - progress_desc = _progress_desc + f' for `{name}`' + if len(_progress_desc) == 0: + _progress_desc = 'Processing' + progress_desc = _progress_desc + f' for `{name}`' if dataset.has_field(field_name=field_name): res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, modify_fields=modify_fields, @@ -317,8 +319,9 @@ class DataBundle: """ _progress_desc = progress_desc for name, dataset in self.datasets.items(): - if _progress_desc: - progress_desc = _progress_desc + f' for `{name}`' + if len(_progress_desc) == 0: + _progress_desc = 'Processing' + progress_desc = _progress_desc + f' for `{name}`' dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, progress_desc=progress_desc) return self @@ -349,8 +352,9 @@ class DataBundle: res = {} _progress_desc = progress_desc for name, dataset in self.datasets.items(): - if _progress_desc: - progress_desc = _progress_desc + f' for `{name}`' + if len(_progress_desc) == 0: + _progress_desc = 'Processing' + progress_desc = _progress_desc + f' for `{name}`' res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, progress_bar=progress_bar, progress_desc=progress_desc) return res