You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tl_logging.py 6.9 kB

4 years ago

  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import logging as _logging
  4. import os as _os
  5. import sys as _sys
  6. import threading
  7. import time as _time
  8. from logging import DEBUG, ERROR, FATAL, INFO, WARN
  9. import six
  10. from tensorlayer.decorators import deprecated
  11. __all__ = [
  12. 'DEBUG',
  13. 'debug',
  14. 'ERROR',
  15. 'error',
  16. 'FATAL',
  17. 'fatal',
  18. 'INFO',
  19. 'info',
  20. 'WARN',
  21. 'warning',
  22. 'warn', # Deprecated
  23. 'set_verbosity',
  24. 'get_verbosity'
  25. ]
  26. # Don't use this directly. Use _get_logger() instead.
  27. _logger = None
  28. _logger_lock = threading.Lock()
  29. _level_names = {
  30. FATAL: 'FATAL',
  31. ERROR: 'ERROR',
  32. WARN: 'WARN',
  33. INFO: 'INFO',
  34. DEBUG: 'DEBUG',
  35. }
  36. def _get_logger():
  37. global _logger
  38. # Use double-checked locking to avoid taking lock unnecessarily.
  39. if _logger is not None:
  40. return _logger
  41. _logger_lock.acquire()
  42. try:
  43. if _logger:
  44. return _logger
  45. # Scope the TensorFlow logger to not conflict with users' loggers.
  46. logger = _logging.getLogger('tensorlayer')
  47. # Don't further configure the TensorFlow logger if the root logger is
  48. # already configured. This prevents double logging in those cases.
  49. if not _logging.getLogger().handlers:
  50. # Determine whether we are in an interactive environment
  51. # This is only defined in interactive shells.
  52. if hasattr(_sys, "ps1"):
  53. _interactive = True
  54. else:
  55. _interactive = _sys.flags.interactive
  56. # If we are in an interactive environment (like Jupyter), set loglevel
  57. # to INFO and pipe the output to stdout.
  58. if _interactive:
  59. logger.setLevel(INFO)
  60. _logging_target = _sys.stdout
  61. else:
  62. _logging_target = _sys.stderr
  63. # Add the output handler.
  64. _handler = _logging.StreamHandler(_logging_target)
  65. _handler.setFormatter(_logging.Formatter('[TL] %(message)s'))
  66. logger.addHandler(_handler)
  67. _logger = logger
  68. return _logger
  69. finally:
  70. _logger_lock.release()
  71. def log(level, msg, *args, **kwargs):
  72. _get_logger().log(level, msg, *args, **kwargs)
  73. def debug(msg, *args, **kwargs):
  74. _get_logger().debug(msg, *args, **kwargs)
  75. def info(msg, *args, **kwargs):
  76. _get_logger().info(msg, *args, **kwargs)
  77. def error(msg, *args, **kwargs):
  78. _get_logger().error("ERROR: %s" % msg, *args, **kwargs)
  79. def fatal(msg, *args, **kwargs):
  80. _get_logger().fatal("FATAL: %s" % msg, *args, **kwargs)
  81. @deprecated(date="2018-09-30", instructions="This API is deprecated. Please use as `tl.logging.warning`")
  82. def warn(msg, *args, **kwargs):
  83. warning(msg, *args, **kwargs)
  84. def warning(msg, *args, **kwargs):
  85. _get_logger().warning("WARNING: %s" % msg, *args, **kwargs)
  86. # Mask to convert integer thread ids to unsigned quantities for logging
  87. # purposes
  88. _THREAD_ID_MASK = 2 * _sys.maxsize + 1
  89. _log_prefix = None # later set to google2_log_prefix
  90. # Counter to keep track of number of log entries per token.
  91. _log_counter_per_token = {}
  92. def TaskLevelStatusMessage(msg):
  93. error(msg)
  94. def flush():
  95. raise NotImplementedError()
  96. def vlog(level, msg, *args, **kwargs):
  97. _get_logger().log(level, msg, *args, **kwargs)
  98. def _GetNextLogCountPerToken(token):
  99. """Wrapper for _log_counter_per_token.
  100. Args:
  101. token: The token for which to look up the count.
  102. Returns:
  103. The number of times this function has been called with
  104. *token* as an argument (starting at 0)
  105. """
  106. global _log_counter_per_token # pylint: disable=global-variable-not-assigned
  107. _log_counter_per_token[token] = 1 + _log_counter_per_token.get(token, -1)
  108. return _log_counter_per_token[token]
  109. def log_every_n(level, msg, n, *args):
  110. """Log 'msg % args' at level 'level' once per 'n' times.
  111. Logs the 1st call, (N+1)st call, (2N+1)st call, etc.
  112. Not threadsafe.
  113. Args:
  114. level: The level at which to log.
  115. msg: The message to be logged.
  116. n: The number of times this should be called before it is logged.
  117. *args: The args to be substituted into the msg.
  118. """
  119. count = _GetNextLogCountPerToken(_GetFileAndLine())
  120. log_if(level, msg, not (count % n), *args)
  121. def log_first_n(level, msg, n, *args): # pylint: disable=g-bad-name
  122. """Log 'msg % args' at level 'level' only first 'n' times.
  123. Not threadsafe.
  124. Args:
  125. level: The level at which to log.
  126. msg: The message to be logged.
  127. n: The number of times this should be called before it is logged.
  128. *args: The args to be substituted into the msg.
  129. """
  130. count = _GetNextLogCountPerToken(_GetFileAndLine())
  131. log_if(level, msg, count < n, *args)
  132. def log_if(level, msg, condition, *args):
  133. """Log 'msg % args' at level 'level' only if condition is fulfilled."""
  134. if condition:
  135. vlog(level, msg, *args)
  136. def _GetFileAndLine():
  137. """Returns (filename, linenumber) for the stack frame."""
  138. # Use sys._getframe(). This avoids creating a traceback object.
  139. # pylint: disable=protected-access
  140. f = _sys._getframe()
  141. # pylint: enable=protected-access
  142. our_file = f.f_code.co_filename
  143. f = f.f_back
  144. while f:
  145. code = f.f_code
  146. if code.co_filename != our_file:
  147. return (code.co_filename, f.f_lineno)
  148. f = f.f_back
  149. return ('<unknown>', 0)
  150. def google2_log_prefix(level, timestamp=None, file_and_line=None):
  151. """Assemble a logline prefix using the google2 format."""
  152. # pylint: disable=global-variable-not-assigned
  153. global _level_names
  154. # pylint: enable=global-variable-not-assigned
  155. # Record current time
  156. now = timestamp or _time.time()
  157. now_tuple = _time.localtime(now)
  158. now_microsecond = int(1e6 * (now % 1.0))
  159. (filename, line) = file_and_line or _GetFileAndLine()
  160. basename = _os.path.basename(filename)
  161. # Severity string
  162. severity = 'I'
  163. if level in _level_names:
  164. severity = _level_names[level][0]
  165. s = '%c%02d%02d %02d: %02d: %02d.%06d %5d %s: %d] ' % (
  166. severity,
  167. now_tuple[1], # month
  168. now_tuple[2], # day
  169. now_tuple[3], # hour
  170. now_tuple[4], # min
  171. now_tuple[5], # sec
  172. now_microsecond,
  173. _get_thread_id(),
  174. basename,
  175. line
  176. )
  177. return s
  178. def get_verbosity():
  179. """Return how much logging output will be produced."""
  180. return _get_logger().getEffectiveLevel()
  181. def set_verbosity(v):
  182. """Sets the threshold for what messages will be logged."""
  183. _get_logger().setLevel(v)
  184. def _get_thread_id():
  185. """Get id of current thread, suitable for logging as an unsigned quantity."""
  186. # pylint: disable=protected-access
  187. thread_id = six.moves._thread.get_ident()
  188. # pylint:enable=protected-access
  189. return thread_id & _THREAD_ID_MASK
  190. _log_prefix = google2_log_prefix

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。