From 1168b9dc243619232963eef11a16068099e9c0e4 Mon Sep 17 00:00:00 2001 From: yunfan Date: Sun, 18 Aug 2019 17:55:28 +0800 Subject: [PATCH] [update] logger in trainer & tester --- fastNLP/io/logger.py | 51 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/fastNLP/io/logger.py b/fastNLP/io/logger.py index 287bdbc9..6bdf693d 100644 --- a/fastNLP/io/logger.py +++ b/fastNLP/io/logger.py @@ -6,6 +6,9 @@ import os import sys import warnings + +__all__ = ['logger'] + try: import fitlog except ImportError: @@ -36,11 +39,7 @@ else: self.setLevel(level) -def init_logger(path=None, stdout='tqdm', level='INFO'): - """initialize logger""" - if stdout not in ['none', 'plain', 'tqdm']: - raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm'])) - +def get_level(level): if isinstance(level, int): pass else: @@ -48,6 +47,15 @@ def init_logger(path=None, stdout='tqdm', level='INFO'): level = {'info': logging.INFO, 'debug': logging.DEBUG, 'warn': logging.WARN, 'warning': logging.WARN, 'error': logging.ERROR}[level] + return level + + +def init_logger(path=None, stdout='tqdm', level='INFO'): + """initialize logger""" + if stdout not in ['none', 'plain', 'tqdm']: + raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm'])) + + level = get_level(level) logger = logging.getLogger('fastNLP') logger.setLevel(level) @@ -85,4 +93,35 @@ def init_logger(path=None, stdout='tqdm', level='INFO'): return logger -get_logger = logging.getLogger + +# init logger when import +logger = init_logger() + + +def get_logger(name=None): + if name is None: + return logging.getLogger('fastNLP') + return logging.getLogger(name) + + +def set_file(path, level='INFO'): + for h in logger.handlers: + if isinstance(h, logging.FileHandler): + if os.path.abspath(path) == h.baseFilename: + # file path already added + return + + # File Handler + if os.path.exists(path): + assert os.path.isfile(path) + warnings.warn('log already exists in {}'.format(path)) + dirname = os.path.abspath(os.path.dirname(path)) + os.makedirs(dirname, exist_ok=True) + + file_handler = logging.FileHandler(path, mode='a') + file_handler.setLevel(get_level(level)) + file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) +