Browse Source

[ to #43850241] fix json dump numpy

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9703595

    * fix json dump numpy

* Merge remote-tracking branch 'origin' into fix/josn_dump
master
jiangnana.jnn 3 years ago
parent
commit
dbacead74e
3 changed files with 21 additions and 3 deletions
  1. +1
    -1
      modelscope/trainers/hooks/logger/base.py
  2. +3
    -2
      modelscope/trainers/hooks/logger/text_logger_hook.py
  3. +17
    -0
      modelscope/utils/json_utils.py

+ 1
- 1
modelscope/trainers/hooks/logger/base.py View File

@@ -15,7 +15,7 @@ class LoggerHook(Hook):
"""Base class for logger hooks.

Args:
interval (int): Logging interval (every k iterations).
interval (int): Logging interval (every k iterations). It is interval of iterations even by_epoch is true.
ignore_last (bool): Ignore the log of last iterations in each epoch
if less than `interval`.
reset_flag (bool): Whether to clear the output buffer after logging.


+ 3
- 2
modelscope/trainers/hooks/logger/text_logger_hook.py View File

@@ -12,6 +12,7 @@ from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.logger.base import LoggerHook
from modelscope.utils.constant import LogKeys, ModeKeys
from modelscope.utils.json_utils import EnhancedEncoder
from modelscope.utils.torch_utils import get_dist_info, is_master


@@ -23,7 +24,7 @@ class TextLoggerHook(LoggerHook):
by_epoch (bool, optional): Whether EpochBasedtrainer is used.
Default: True.
interval (int, optional): Logging interval (every k iterations).
Default: 10.
It is interval of iterations even by_epoch is true. Default: 10.
ignore_last (bool, optional): Ignore the log of last iterations in each
epoch if less than :attr:`interval`. Default: True.
reset_flag (bool, optional): Whether to clear the output buffer after
@@ -142,7 +143,7 @@ class TextLoggerHook(LoggerHook):

if is_master():
with open(self.json_log_path, 'a+') as f:
json.dump(json_log, f)
json.dump(json_log, f, cls=EnhancedEncoder)
f.write('\n')

def _round_float(self, items, ndigits=5):


+ 17
- 0
modelscope/utils/json_utils.py View File

@@ -0,0 +1,17 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import json
import numpy as np


class EnhancedEncoder(json.JSONEncoder):
""" Enhanced json encoder for not supported types """

def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)

Loading…
Cancel
Save