Browse Source

[update] logger in trainer & tester

tags/v0.4.10
yunfan 5 years ago
parent
commit
1168b9dc24
1 changed files with 45 additions and 6 deletions
  1. +45
    -6
      fastNLP/io/logger.py

+ 45
- 6
fastNLP/io/logger.py View File

@@ -6,6 +6,9 @@ import os
import sys import sys
import warnings import warnings



__all__ = ['logger']

try: try:
import fitlog import fitlog
except ImportError: except ImportError:
@@ -36,11 +39,7 @@ else:
self.setLevel(level) self.setLevel(level)




def init_logger(path=None, stdout='tqdm', level='INFO'):
"""initialize logger"""
if stdout not in ['none', 'plain', 'tqdm']:
raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))

def get_level(level):
if isinstance(level, int): if isinstance(level, int):
pass pass
else: else:
@@ -48,6 +47,15 @@ def init_logger(path=None, stdout='tqdm', level='INFO'):
level = {'info': logging.INFO, 'debug': logging.DEBUG, level = {'info': logging.INFO, 'debug': logging.DEBUG,
'warn': logging.WARN, 'warning': logging.WARN, 'warn': logging.WARN, 'warning': logging.WARN,
'error': logging.ERROR}[level] 'error': logging.ERROR}[level]
return level


def init_logger(path=None, stdout='tqdm', level='INFO'):
"""initialize logger"""
if stdout not in ['none', 'plain', 'tqdm']:
raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))

level = get_level(level)


logger = logging.getLogger('fastNLP') logger = logging.getLogger('fastNLP')
logger.setLevel(level) logger.setLevel(level)
@@ -85,4 +93,35 @@ def init_logger(path=None, stdout='tqdm', level='INFO'):


return logger return logger


get_logger = logging.getLogger

# init logger when import
logger = init_logger()


def get_logger(name=None):
if name is None:
return logging.getLogger('fastNLP')
return logging.getLogger(name)


def set_file(path, level='INFO'):
for h in logger.handlers:
if isinstance(h, logging.FileHandler):
if os.path.abspath(path) == h.baseFilename:
# file path already added
return

# File Handler
if os.path.exists(path):
assert os.path.isfile(path)
warnings.warn('log already exists in {}'.format(path))
dirname = os.path.abspath(os.path.dirname(path))
os.makedirs(dirname, exist_ok=True)

file_handler = logging.FileHandler(path, mode='a')
file_handler.setLevel(get_level(level))
file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
datefmt='%Y/%m/%d %H:%M:%S')
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)


Loading…
Cancel
Save