@@ -9,6 +9,7 @@ import torch
from torch import distributed as dist
from torch import distributed as dist
from modelscope.metainfo import Hooks
from modelscope.metainfo import Hooks
from modelscope.outputs import OutputKeys
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.logger.base import LoggerHook
from modelscope.trainers.hooks.logger.base import LoggerHook
from modelscope.utils.constant import LogKeys, ModeKeys
from modelscope.utils.constant import LogKeys, ModeKeys
@@ -30,6 +31,8 @@ class TextLoggerHook(LoggerHook):
reset_flag (bool, optional): Whether to clear the output buffer after
reset_flag (bool, optional): Whether to clear the output buffer after
logging. Default: False.
logging. Default: False.
out_dir (str): The directory to save log. If is None, use `trainer.work_dir`
out_dir (str): The directory to save log. If is None, use `trainer.work_dir`
ignore_rounding_keys (`Union[str, List]`): The keys to ignore float rounding, default 'lr'
rounding_digits (`int`): The digits of rounding, exceeding parts will be ignored.
"""
"""
def __init__(self,
def __init__(self,
@@ -37,13 +40,20 @@ class TextLoggerHook(LoggerHook):
interval=10,
interval=10,
ignore_last=True,
ignore_last=True,
reset_flag=False,
reset_flag=False,
out_dir=None):
out_dir=None,
ignore_rounding_keys='lr',
rounding_digits=5):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
by_epoch)
self.by_epoch = by_epoch
self.by_epoch = by_epoch
self.time_sec_tot = 0
self.time_sec_tot = 0
self.out_dir = out_dir
self.out_dir = out_dir
self._logged_keys = [] # store the key has been logged
self._logged_keys = [] # store the key has been logged
if isinstance(ignore_rounding_keys,
str) or ignore_rounding_keys is None:
ignore_rounding_keys = [ignore_rounding_keys]
self.ignore_rounding_keys = ignore_rounding_keys
self.rounding_digits = rounding_digits
def before_run(self, trainer):
def before_run(self, trainer):
super(TextLoggerHook, self).before_run(trainer)
super(TextLoggerHook, self).before_run(trainer)
@@ -139,7 +149,9 @@ class TextLoggerHook(LoggerHook):
# dump log in json format
# dump log in json format
json_log = OrderedDict()
json_log = OrderedDict()
for k, v in log_dict.items():
for k, v in log_dict.items():
json_log[k] = self._round_float(v)
json_log[
k] = v if k in self.ignore_rounding_keys else self._round_float(
v, self.rounding_digits)
if is_master():
if is_master():
with open(self.json_log_path, 'a+') as f:
with open(self.json_log_path, 'a+') as f:
@@ -148,7 +160,7 @@ class TextLoggerHook(LoggerHook):
def _round_float(self, items, ndigits=5):
def _round_float(self, items, ndigits=5):
if isinstance(items, list):
if isinstance(items, list):
return [self._round_float(item) for item in items]
return [self._round_float(item, ndigits ) for item in items]
elif isinstance(items, float):
elif isinstance(items, float):
return round(items, ndigits)
return round(items, ndigits)
else:
else: