Browse Source

[ to #43850241] fix json dump numpy

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9703595

    * fix json dump numpy

* Merge remote-tracking branch 'origin' into fix/josn_dump
master
jiangnana.jnn 3 years ago
parent
commit
dbacead74e
3 changed files with 21 additions and 3 deletions
  1. +1
    -1
      modelscope/trainers/hooks/logger/base.py
  2. +3
    -2
      modelscope/trainers/hooks/logger/text_logger_hook.py
  3. +17
    -0
      modelscope/utils/json_utils.py

+ 1
- 1
modelscope/trainers/hooks/logger/base.py View File

@@ -15,7 +15,7 @@ class LoggerHook(Hook):
"""Base class for logger hooks. """Base class for logger hooks.


Args: Args:
interval (int): Logging interval (every k iterations).
interval (int): Logging interval (every k iterations). It is interval of iterations even by_epoch is true.
ignore_last (bool): Ignore the log of last iterations in each epoch ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`. if less than `interval`.
reset_flag (bool): Whether to clear the output buffer after logging. reset_flag (bool): Whether to clear the output buffer after logging.


+ 3
- 2
modelscope/trainers/hooks/logger/text_logger_hook.py View File

@@ -12,6 +12,7 @@ from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.logger.base import LoggerHook from modelscope.trainers.hooks.logger.base import LoggerHook
from modelscope.utils.constant import LogKeys, ModeKeys from modelscope.utils.constant import LogKeys, ModeKeys
from modelscope.utils.json_utils import EnhancedEncoder
from modelscope.utils.torch_utils import get_dist_info, is_master from modelscope.utils.torch_utils import get_dist_info, is_master




@@ -23,7 +24,7 @@ class TextLoggerHook(LoggerHook):
by_epoch (bool, optional): Whether EpochBasedtrainer is used. by_epoch (bool, optional): Whether EpochBasedtrainer is used.
Default: True. Default: True.
interval (int, optional): Logging interval (every k iterations). interval (int, optional): Logging interval (every k iterations).
Default: 10.
It is interval of iterations even by_epoch is true. Default: 10.
ignore_last (bool, optional): Ignore the log of last iterations in each ignore_last (bool, optional): Ignore the log of last iterations in each
epoch if less than :attr:`interval`. Default: True. epoch if less than :attr:`interval`. Default: True.
reset_flag (bool, optional): Whether to clear the output buffer after reset_flag (bool, optional): Whether to clear the output buffer after
@@ -142,7 +143,7 @@ class TextLoggerHook(LoggerHook):


if is_master(): if is_master():
with open(self.json_log_path, 'a+') as f: with open(self.json_log_path, 'a+') as f:
json.dump(json_log, f)
json.dump(json_log, f, cls=EnhancedEncoder)
f.write('\n') f.write('\n')


def _round_float(self, items, ndigits=5): def _round_float(self, items, ndigits=5):


+ 17
- 0
modelscope/utils/json_utils.py View File

@@ -0,0 +1,17 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import json
import numpy as np


class EnhancedEncoder(json.JSONEncoder):
""" Enhanced json encoder for not supported types """

def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

Loading…
Cancel
Save