Browse Source

Fix a bug where the log file records the learning rate (lr) as zero instead of its actual value

This bug, which was reported by a user, is caused by float rounding when key-value pairs are saved to the log files.
The solution is to remove the rounding operation for all values rather than for the lr value alone, since I think special-casing lr may be too specific.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10684029
master^2
yuze.zyz wenmeng.zwm 2 years ago
parent
commit
fde8644883
2 changed files with 16 additions and 3 deletions
  1. +15
    -3
      modelscope/trainers/hooks/logger/text_logger_hook.py
  2. +1
    -0
      tests/trainers/easycv/test_easycv_trainer.py

+ 15
- 3
modelscope/trainers/hooks/logger/text_logger_hook.py View File

@@ -9,6 +9,7 @@ import torch
from torch import distributed as dist

from modelscope.metainfo import Hooks
from modelscope.outputs import OutputKeys
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.logger.base import LoggerHook
from modelscope.utils.constant import LogKeys, ModeKeys
@@ -30,6 +31,8 @@ class TextLoggerHook(LoggerHook):
reset_flag (bool, optional): Whether to clear the output buffer after
logging. Default: False.
out_dir (str): The directory to save log. If is None, use `trainer.work_dir`
ignore_rounding_keys (`Union[str, List]`): The keys to ignore float rounding, default 'lr'
rounding_digits (`int`): The digits of rounding, exceeding parts will be ignored.
"""

def __init__(self,
@@ -37,13 +40,20 @@ class TextLoggerHook(LoggerHook):
interval=10,
ignore_last=True,
reset_flag=False,
out_dir=None):
out_dir=None,
ignore_rounding_keys='lr',
rounding_digits=5):
super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
by_epoch)
self.by_epoch = by_epoch
self.time_sec_tot = 0
self.out_dir = out_dir
self._logged_keys = [] # store the key has been logged
if isinstance(ignore_rounding_keys,
str) or ignore_rounding_keys is None:
ignore_rounding_keys = [ignore_rounding_keys]
self.ignore_rounding_keys = ignore_rounding_keys
self.rounding_digits = rounding_digits

def before_run(self, trainer):
super(TextLoggerHook, self).before_run(trainer)
@@ -139,7 +149,9 @@ class TextLoggerHook(LoggerHook):
# dump log in json format
json_log = OrderedDict()
for k, v in log_dict.items():
json_log[k] = self._round_float(v)
json_log[
k] = v if k in self.ignore_rounding_keys else self._round_float(
v, self.rounding_digits)

if is_master():
with open(self.json_log_path, 'a+') as f:
@@ -148,7 +160,7 @@ class TextLoggerHook(LoggerHook):

def _round_float(self, items, ndigits=5):
if isinstance(items, list):
return [self._round_float(item) for item in items]
return [self._round_float(item, ndigits) for item in items]
elif isinstance(items, float):
return round(items, ndigits)
else:


+ 1
- 0
tests/trainers/easycv/test_easycv_trainer.py View File

@@ -70,6 +70,7 @@ def train_func(work_dir, dist=False, log_interval=3, imgs_per_gpu=4):
},
{
'type': 'TextLoggerHook',
'ignore_rounding_keys': None,
'interval': log_interval
},
]


Loading…
Cancel
Save