From 3b8bc469ba752873333c9fe15bc6b144efe3251d Mon Sep 17 00:00:00 2001
From: yunfan
Date: Mon, 19 Aug 2019 14:22:58 +0800
Subject: [PATCH] [update] logger: support importing the logger directly

---
 fastNLP/core/callback.py                |   4 +-
 fastNLP/core/dist_trainer.py            |   8 +-
 fastNLP/core/tester.py                  |   5 +-
 fastNLP/core/trainer.py                 |  10 +-
 fastNLP/io/__init__.py                  |   3 +
 fastNLP/io/{logger.py => _logger.py}    | 120 ++++++++++--------
 .../text_classification/train_dpcnn.py  |  17 ++-
 7 files changed, 93 insertions(+), 74 deletions(-)
 rename fastNLP/io/{logger.py => _logger.py} (62%)

diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 17ded171..53767011 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -86,7 +86,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
 from .tester import Tester
-import logging
+from ..io import logger

 try:
     import fitlog
@@ -178,7 +178,7 @@ class Callback(object):

     @property
     def logger(self):
-        return getattr(self._trainer, 'logger', logging.getLogger(__name__))
+        return getattr(self._trainer, 'logger', logger)

     def on_train_begin(self):
         """
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index e14e17c8..8ad282c9 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -9,7 +9,6 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 import os
 from tqdm import tqdm
-import logging
 import time
 from datetime import datetime, timedelta
 from functools import partial
@@ -22,7 +21,8 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
-from ..io.logger import init_logger
+from ..io import logger
+import logging
 from pkg_resources import parse_version

 __all__ = [
@@ -140,8 +140,8 @@ class DistTrainer():
             self.cp_save_path = None

         # use INFO in the master, WARN for others
-        init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
-        self.logger = logging.getLogger(__name__)
+        logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
+        self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
             os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 10696240..ab86fb62 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -56,7 +56,7 @@ from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
 from ._parallel_utils import _model_contains_inner_module
 from functools import partial
-from ..io.logger import init_logger, get_logger
+from ..io import logger

 __all__ = [
     "Tester"
 ]
@@ -104,8 +104,7 @@ class Tester(object):
         self.batch_size = batch_size
         self.verbose = verbose
         self.use_tqdm = use_tqdm
-        init_logger(stdout='tqdm' if use_tqdm else 'plain')
-        self.logger = get_logger(__name__)
+        self.logger = logger

         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index d71e23f5..783997a7 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -353,7 +353,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _model_contains_inner_module
-from ..io.logger import init_logger, get_logger
+from ..io import logger


 class Trainer(object):
@@ -548,11 +548,7 @@ class Trainer(object):
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

-        log_path = None
-        if save_path is not None:
-            log_path = os.path.join(os.path.dirname(save_path), 'log')
-        init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain')
-        self.logger = get_logger(__name__)
+        self.logger = logger

         self.use_tqdm = use_tqdm
         self.pbar = None
@@ -701,7 +697,7 @@ class Trainer(object):
                                                                     self.n_steps) + \
                                    self.tester._format_eval_results(eval_res)
                         # pbar.write(eval_str + '\n')
-                        self.logger.info(eval_str)
+                        self.logger.info(eval_str + '\n')
             # ================= mini-batch end ==================== #

             # lr decay; early stopping
diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py
index f8c55bf5..a19428d3 100644
--- a/fastNLP/io/__init__.py
+++ b/fastNLP/io/__init__.py
@@ -72,6 +72,8 @@ __all__ = [

     'ModelLoader',
     'ModelSaver',
+
+    'logger',
 ]

 from .embed_loader import EmbedLoader
@@ -81,3 +83,4 @@ from .model_io import ModelLoader, ModelSaver

 from .loader import *
 from .pipe import *
+from ._logger import *
diff --git a/fastNLP/io/logger.py b/fastNLP/io/_logger.py
similarity index 62%
rename from fastNLP/io/logger.py
rename to fastNLP/io/_logger.py
index 6bdf693d..73c47d42 100644
--- a/fastNLP/io/logger.py
+++ b/fastNLP/io/_logger.py
@@ -6,8 +6,11 @@ import os
 import sys
 import warnings

+__all__ = [
+    'logger',
+]

-__all__ = ['logger']
+ROOT_NAME = 'fastNLP'

 try:
     import fitlog
@@ -39,7 +42,7 @@ else:
         self.setLevel(level)


-def get_level(level):
+def _get_level(level):
     if isinstance(level, int):
         pass
     else:
@@ -50,22 +53,45 @@
     return level


-def init_logger(path=None, stdout='tqdm', level='INFO'):
-    """initialize logger"""
-    if stdout not in ['none', 'plain', 'tqdm']:
-        raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
+def _add_file_handler(logger, path, level='INFO'):
+    for h in logger.handlers:
+        if isinstance(h, logging.FileHandler):
+            if os.path.abspath(path) == h.baseFilename:
+                # file path already added
+                return

-    level = get_level(level)
+    # File Handler
+    if os.path.exists(path):
+        assert os.path.isfile(path)
+        warnings.warn('log already exists in {}'.format(path))
+    dirname = os.path.abspath(os.path.dirname(path))
+    os.makedirs(dirname, exist_ok=True)

-    logger = logging.getLogger('fastNLP')
-    logger.setLevel(level)
-    handlers_type = set([type(h) for h in logger.handlers])
+    file_handler = logging.FileHandler(path, mode='a')
+    file_handler.setLevel(_get_level(level))
+    file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(message)s',
+                                       datefmt='%Y/%m/%d %H:%M:%S')
+    file_handler.setFormatter(file_formatter)
+    logger.addHandler(file_handler)
+
+def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
+    level = _get_level(level)
+    if stdout not in ['none', 'plain', 'tqdm']:
+        raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
     # make sure to initialize logger only once
+    stream_handler = None
+    for i, h in enumerate(logger.handlers):
+        if isinstance(h, (logging.StreamHandler, TqdmLoggingHandler)):
+            stream_handler = h
+            break
+    if stream_handler is not None:
+        logger.removeHandler(stream_handler)
+
     # Stream Handler
-    if stdout == 'plain' and (logging.StreamHandler not in handlers_type):
+    if stdout == 'plain':
         stream_handler = logging.StreamHandler(sys.stdout)
-    elif stdout == 'tqdm' and (TqdmLoggingHandler not in handlers_type):
+    elif stdout == 'tqdm':
         stream_handler = TqdmLoggingHandler(level)
     else:
         stream_handler = None
@@ -76,52 +102,44 @@
         stream_handler.setFormatter(stream_formatter)
         logger.addHandler(stream_handler)

-    # File Handler
-    if path is not None and (logging.FileHandler not in handlers_type):
-        if os.path.exists(path):
-            assert os.path.isfile(path)
-            warnings.warn('log already exists in {}'.format(path))
-        dirname = os.path.abspath(os.path.dirname(path))
-        os.makedirs(dirname, exist_ok=True)
-
-        file_handler = logging.FileHandler(path, mode='a')
-        file_handler.setLevel(level)
-        file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
-                                           datefmt='%Y/%m/%d %H:%M:%S')
-        file_handler.setFormatter(file_formatter)
-        logger.addHandler(file_handler)
-
-    return logger
+def _init_logger(path=None, stdout='tqdm', level='INFO'):
+    """initialize logger"""
+    level = _get_level(level)

+    # logger = logging.getLogger(ROOT_NAME)
+    logger = logging.getLogger()
+    logger.setLevel(level)

-# init logger when import
-logger = init_logger()
+    _set_stdout_handler(logger, stdout, level)

+    # File Handler
+    if path is not None:
+        _add_file_handler(logger, path, level)

-def get_logger(name=None):
-    if name is None:
-        return logging.getLogger('fastNLP')
-    return logging.getLogger(name)
+    return logger


-def set_file(path, level='INFO'):
-    for h in logger.handlers:
-        if isinstance(h, logging.FileHandler):
-            if os.path.abspath(path) == h.baseFilename:
-                # file path already added
-                return
+def _get_logger(name=None, level='INFO'):
+    level = _get_level(level)
+    if name is None:
+        name = ROOT_NAME
+    assert isinstance(name, str)
+    if not name.startswith(ROOT_NAME):
+        name = '{}.{}'.format(ROOT_NAME, name)
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    return logger


-    # File Handler
-    if os.path.exists(path):
-        assert os.path.isfile(path)
-        warnings.warn('log already exists in {}'.format(path))
-    dirname = os.path.abspath(os.path.dirname(path))
-    os.makedirs(dirname, exist_ok=True)
-
-    file_handler = logging.FileHandler(path, mode='a')
-    file_handler.setLevel(get_level(level))
-    file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
-                                       datefmt='%Y/%m/%d %H:%M:%S')
-    file_handler.setFormatter(file_formatter)
-    logger.addHandler(file_handler)
+class FastNLPLogger(logging.Logger):
+    def add_file(self, path, level):
+        _add_file_handler(self, path, level)
+
+    def set_stdout(self, stdout, level):
+        _set_stdout_handler(self, stdout, level)
+
+_logger = _init_logger(path=None)
+logger = FastNLPLogger(ROOT_NAME)
+logger.__dict__.update(_logger.__dict__)
+del _logger
diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index 99e27640..704b9f43 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -15,13 +15,14 @@ from fastNLP.core.const import Const as C
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.core.dist_trainer import DistTrainer
 from utils.util_init import set_rng_seeds
+from fastNLP.io import logger
 import os

 # os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
 # os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-
 # hyper
+logger.add_file('log', 'INFO')

 class Config():
     seed = 12345
@@ -46,11 +47,11 @@ class Config():
         self.datapath = {k: os.path.join(self.datadir, v)
                          for k, v in self.datafile.items()}

-
 ops = Config()
 set_rng_seeds(ops.seed)
-print('RNG SEED: {}'.format(ops.seed))
+# print('RNG SEED: {}'.format(ops.seed))
+logger.info('RNG SEED %d'%ops.seed)


 # 1. Task info: load dataInfo with the dataloader
@@ -81,8 +82,9 @@ print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.std())
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
 datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
 datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
-print(datainfo)
-print(datainfo.datasets['train'][0])
+# print(datainfo)
+# print(datainfo.datasets['train'][0])
+logger.info(datainfo)

 model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
               embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
@@ -108,12 +110,13 @@ callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))

 device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

-print(device)
+# print(device)
+logger.info(device)

 # 4. Define the train method
 trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                   sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
-                  metrics=[metric], use_tqdm=False,
+                  metrics=[metric], use_tqdm=False, save_path='save',
                   dev_data=datainfo.datasets['test'], device=device,
                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
                   n_epochs=ops.train_epoch, num_workers=4)
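
Usage sketch (illustrative only; the 'my.log' path and the messages below are placeholders, not part of the patched sources). After this change the shared logger is imported straight from fastNLP.io and its handlers are managed through the new FastNLPLogger methods:

    from fastNLP.io import logger

    # add_file(path, level) attaches a FileHandler so records also go to a file
    logger.add_file('my.log', 'INFO')
    # pick the console handler: 'tqdm' (tqdm-friendly), 'plain' (sys.stdout) or 'none'
    logger.set_stdout('tqdm', 'INFO')

    logger.info('training started')
    logger.warning('this message goes to both stdout and my.log')

This mirrors what reproduction/text_classification/train_dpcnn.py now does with logger.add_file('log', 'INFO') and logger.info(...).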