Browse Source

Fix a bug where the log file records the learning rate (lr) as zero instead of its actual value

This bug, reported by a user, is caused by float rounding when key-value pairs are saved to the log file.
The solution is to remove the rounding operation for all values rather than for the lr value only, since special-casing lr alone would, I think, be too specific.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10684029
master^2
yuze.zyz wenmeng.zwm 2 years ago
parent
commit
fde8644883
2 changed files with 16 additions and 3 deletions
  1. +15
    -3
      modelscope/trainers/hooks/logger/text_logger_hook.py
  2. +1
    -0
      tests/trainers/easycv/test_easycv_trainer.py

+ 15
- 3
modelscope/trainers/hooks/logger/text_logger_hook.py View File

@@ -9,6 +9,7 @@ import torch
from torch import distributed as dist from torch import distributed as dist


from modelscope.metainfo import Hooks from modelscope.metainfo import Hooks
from modelscope.outputs import OutputKeys
from modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.logger.base import LoggerHook from modelscope.trainers.hooks.logger.base import LoggerHook
from modelscope.utils.constant import LogKeys, ModeKeys from modelscope.utils.constant import LogKeys, ModeKeys
@@ -30,6 +31,8 @@ class TextLoggerHook(LoggerHook):
reset_flag (bool, optional): Whether to clear the output buffer after reset_flag (bool, optional): Whether to clear the output buffer after
logging. Default: False. logging. Default: False.
out_dir (str): The directory to save log. If is None, use `trainer.work_dir` out_dir (str): The directory to save log. If is None, use `trainer.work_dir`
ignore_rounding_keys (`Union[str, List]`): The keys to ignore float rounding, default 'lr'
rounding_digits (`int`): The digits of rounding, exceeding parts will be ignored.
""" """


def __init__(self, def __init__(self,
@@ -37,13 +40,20 @@ class TextLoggerHook(LoggerHook):
interval=10, interval=10,
ignore_last=True, ignore_last=True,
reset_flag=False, reset_flag=False,
out_dir=None):
out_dir=None,
ignore_rounding_keys='lr',
rounding_digits=5):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag, super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch) by_epoch)
self.by_epoch = by_epoch self.by_epoch = by_epoch
self.time_sec_tot = 0 self.time_sec_tot = 0
self.out_dir = out_dir self.out_dir = out_dir
self._logged_keys = [] # store the key has been logged self._logged_keys = [] # store the key has been logged
if isinstance(ignore_rounding_keys,
str) or ignore_rounding_keys is None:
ignore_rounding_keys = [ignore_rounding_keys]
self.ignore_rounding_keys = ignore_rounding_keys
self.rounding_digits = rounding_digits


def before_run(self, trainer): def before_run(self, trainer):
super(TextLoggerHook, self).before_run(trainer) super(TextLoggerHook, self).before_run(trainer)
@@ -139,7 +149,9 @@ class TextLoggerHook(LoggerHook):
# dump log in json format # dump log in json format
json_log = OrderedDict() json_log = OrderedDict()
for k, v in log_dict.items(): for k, v in log_dict.items():
json_log[k] = self._round_float(v)
json_log[
k] = v if k in self.ignore_rounding_keys else self._round_float(
v, self.rounding_digits)


if is_master(): if is_master():
with open(self.json_log_path, 'a+') as f: with open(self.json_log_path, 'a+') as f:
@@ -148,7 +160,7 @@ class TextLoggerHook(LoggerHook):


def _round_float(self, items, ndigits=5): def _round_float(self, items, ndigits=5):
if isinstance(items, list): if isinstance(items, list):
return [self._round_float(item) for item in items]
return [self._round_float(item, ndigits) for item in items]
elif isinstance(items, float): elif isinstance(items, float):
return round(items, ndigits) return round(items, ndigits)
else: else:


+ 1
- 0
tests/trainers/easycv/test_easycv_trainer.py View File

@@ -70,6 +70,7 @@ def train_func(work_dir, dist=False, log_interval=3, imgs_per_gpu=4):
}, },
{ {
'type': 'TextLoggerHook', 'type': 'TextLoggerHook',
'ignore_rounding_keys': None,
'interval': log_interval 'interval': log_interval
}, },
] ]


Loading…
Cancel
Save