From fde86448833bd87ed08f12875ff4acd2d29e6c06 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Wed, 30 Nov 2022 21:59:02 +0800 Subject: [PATCH] Fix a bug where the log file cannot save the correct lr and records zero instead This bug is a result of float rounding when saving key-value pairs to log files, and was reported by a user. The solution now is to remove the rounding operation for all values rather than only for the lr value, which I think would be too specific. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10684029 --- .../trainers/hooks/logger/text_logger_hook.py | 18 +++++++++++++++--- tests/trainers/easycv/test_easycv_trainer.py | 1 + 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/modelscope/trainers/hooks/logger/text_logger_hook.py b/modelscope/trainers/hooks/logger/text_logger_hook.py index 95644783..b317a9c0 100644 --- a/modelscope/trainers/hooks/logger/text_logger_hook.py +++ b/modelscope/trainers/hooks/logger/text_logger_hook.py @@ -9,6 +9,7 @@ import torch from torch import distributed as dist from modelscope.metainfo import Hooks +from modelscope.outputs import OutputKeys from modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.logger.base import LoggerHook from modelscope.utils.constant import LogKeys, ModeKeys @@ -30,6 +31,8 @@ class TextLoggerHook(LoggerHook): reset_flag (bool, optional): Whether to clear the output buffer after logging. Default: False. out_dir (str): The directory to save log. If is None, use `trainer.work_dir` + ignore_rounding_keys (`Union[str, List]`): The keys to ignore float rounding, default 'lr' + rounding_digits (`int`): The digits of rounding, exceeding parts will be ignored. 
""" def __init__(self, @@ -37,13 +40,20 @@ class TextLoggerHook(LoggerHook): interval=10, ignore_last=True, reset_flag=False, - out_dir=None): + out_dir=None, + ignore_rounding_keys='lr', + rounding_digits=5): super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag, by_epoch) self.by_epoch = by_epoch self.time_sec_tot = 0 self.out_dir = out_dir self._logged_keys = [] # store the key has been logged + if isinstance(ignore_rounding_keys, + str) or ignore_rounding_keys is None: + ignore_rounding_keys = [ignore_rounding_keys] + self.ignore_rounding_keys = ignore_rounding_keys + self.rounding_digits = rounding_digits def before_run(self, trainer): super(TextLoggerHook, self).before_run(trainer) @@ -139,7 +149,9 @@ class TextLoggerHook(LoggerHook): # dump log in json format json_log = OrderedDict() for k, v in log_dict.items(): - json_log[k] = self._round_float(v) + json_log[ + k] = v if k in self.ignore_rounding_keys else self._round_float( + v, self.rounding_digits) if is_master(): with open(self.json_log_path, 'a+') as f: @@ -148,7 +160,7 @@ class TextLoggerHook(LoggerHook): def _round_float(self, items, ndigits=5): if isinstance(items, list): - return [self._round_float(item) for item in items] + return [self._round_float(item, ndigits) for item in items] elif isinstance(items, float): return round(items, ndigits) else: diff --git a/tests/trainers/easycv/test_easycv_trainer.py b/tests/trainers/easycv/test_easycv_trainer.py index 5d714097..40b43911 100644 --- a/tests/trainers/easycv/test_easycv_trainer.py +++ b/tests/trainers/easycv/test_easycv_trainer.py @@ -70,6 +70,7 @@ def train_func(work_dir, dist=False, log_interval=3, imgs_per_gpu=4): }, { 'type': 'TextLoggerHook', + 'ignore_rounding_keys': None, 'interval': log_interval }, ]