Browse Source

progress打印增加一种特殊yueding

tags/v1.0.0alpha
yhcc 3 years ago
parent
commit
e4a7e64600
3 changed files with 25 additions and 15 deletions
  1. +12
    -6
      fastNLP/core/callbacks/progress_callback.py
  2. +1
    -1
      fastNLP/core/utils/utils.py
  3. +12
    -8
      fastNLP/io/data_bundle.py

+ 12
- 6
fastNLP/core/callbacks/progress_callback.py View File

@@ -57,7 +57,7 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog
class RichCallback(ProgressCallback):
"""
在训练过程中打印 rich progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。

:param print_every: 多少个 batch 更新一次显示。
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字
@@ -144,8 +144,10 @@ class RichCallback(ProgressCallback):
self.progress_bar.console.rule(text_style+f"Eval. results on Epoch:{trainer.cur_epoch_idx}, "
f"Batch:{trainer.batch_idx_in_epoch}",
style=rule_style, characters=characters)
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
self.progress_bar.console.print_json(json.dumps(trainer.driver.tensor_to_numeric(results)))
self.progress_bar.console.print_json(json.dumps(results))
else:
self.progress_bar.print(results)

@@ -165,7 +167,7 @@ class RawTextCallback(ProgressCallback):
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
format_json=True):
"""
通过向命令行打印进度的方式显示
通过向命令行打印进度的方式显示。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。

:param print_every: 多少个 batch 更新一次显示。
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字
@@ -222,8 +224,10 @@ class RawTextCallback(ProgressCallback):
text = '-'*self.num_signs + base_text + '-'*self.num_signs

logger.info(text)
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results)))
logger.info(json.dumps(results))
else:
logger.info(results)

@@ -235,7 +239,7 @@ class RawTextCallback(ProgressCallback):
class TqdmCallback(ProgressCallback):
"""
在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。在打印 evaluate 的结果时,不会打印名称以 "_" 开头的内容。

:param print_every: 多少个 batch 更新一次显示。
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字
@@ -309,8 +313,10 @@ class TqdmCallback(ProgressCallback):
text = '-'*self.num_signs + base_text + '-'*self.num_signs

logger.info(text)
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results)))
logger.info(json.dumps(results))
else:
logger.info(results)



+ 1
- 1
fastNLP/core/utils/utils.py View File

@@ -630,7 +630,7 @@ def is_notebook():

def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict:
"""
一个 nested 的 dict 转成 flat 的 dict,例如
一个 nested 的 dict 转成 flat 的 dict,例如
ex::
d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1}



+ 12
- 8
fastNLP/io/data_bundle.py View File

@@ -245,8 +245,9 @@ class DataBundle:
"""
_progress_desc = progress_desc
for name, dataset in self.datasets.items():
if _progress_desc:
progress_desc = _progress_desc + f' for `{name}`'
if len(_progress_desc) == 0:
_progress_desc = 'Processing'
progress_desc = _progress_desc + f' for `{name}`'
if dataset.has_field(field_name=field_name):
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc,
progress_desc=progress_desc, progress_bar=progress_bar)
@@ -284,8 +285,9 @@ class DataBundle:
res = {}
_progress_desc = progress_desc
for name, dataset in self.datasets.items():
if _progress_desc:
progress_desc = _progress_desc + f' for `{name}`'
if len(_progress_desc) == 0:
_progress_desc = 'Processing'
progress_desc = _progress_desc + f' for `{name}`'
if dataset.has_field(field_name=field_name):
res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc,
modify_fields=modify_fields,
@@ -317,8 +319,9 @@ class DataBundle:
"""
_progress_desc = progress_desc
for name, dataset in self.datasets.items():
if _progress_desc:
progress_desc = _progress_desc + f' for `{name}`'
if len(_progress_desc) == 0:
_progress_desc = 'Processing'
progress_desc = _progress_desc + f' for `{name}`'
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar,
progress_desc=progress_desc)
return self
@@ -349,8 +352,9 @@ class DataBundle:
res = {}
_progress_desc = progress_desc
for name, dataset in self.datasets.items():
if _progress_desc:
progress_desc = _progress_desc + f' for `{name}`'
if len(_progress_desc) == 0:
_progress_desc = 'Processing'
progress_desc = _progress_desc + f' for `{name}`'
res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc,
progress_bar=progress_bar, progress_desc=progress_desc)
return res


Loading…
Cancel
Save