From 007c047ae7cb0cdc80857ce9ebded3143af231a1 Mon Sep 17 00:00:00 2001 From: yunfan Date: Sun, 18 Aug 2019 13:39:56 +0800 Subject: [PATCH] [update] logger in trainer & tester --- fastNLP/core/callback.py | 9 +- fastNLP/core/dist_trainer.py | 4 +- fastNLP/core/tester.py | 6 +- fastNLP/core/trainer.py | 11 ++- fastNLP/core/utils.py | 2 +- fastNLP/io/logger.py | 88 +++++++++++++++++++ .../text_classification/train_dpcnn.py | 22 ++--- 7 files changed, 118 insertions(+), 24 deletions(-) create mode 100644 fastNLP/io/logger.py diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 1a20f861..447186ca 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -656,10 +656,13 @@ class EvaluateCallback(Callback): for key, tester in self.testers.items(): try: eval_result = tester.test() - self.pbar.write("Evaluation on {}:".format(key)) - self.pbar.write(tester._format_eval_results(eval_result)) + # self.pbar.write("Evaluation on {}:".format(key)) + self.logger.info("Evaluation on {}:".format(key)) + # self.pbar.write(tester._format_eval_results(eval_result)) + self.logger.info(tester._format_eval_results(eval_result)) except Exception: - self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) + # self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) + self.logger.info("Exception happens when evaluate on DataSet named `{}`.".format(key)) class LRScheduler(Callback): diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py index bfd0e70b..e14e17c8 100644 --- a/fastNLP/core/dist_trainer.py +++ b/fastNLP/core/dist_trainer.py @@ -22,7 +22,7 @@ from .optimizer import Optimizer from .utils import _build_args from .utils import _move_dict_value_to_device from .utils import _get_func_signature -from ..io.logger import initLogger +from ..io.logger import init_logger from pkg_resources import parse_version __all__ = [ @@ -140,7 +140,7 @@ class DistTrainer(): self.cp_save_path = None # use INFO in the master, WARN for others - initLogger(log_path, level=logging.INFO if self.is_master else logging.WARNING) + init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING) self.logger = logging.getLogger(__name__) self.logger.info("Setup Distributed Trainer") self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 691bf2ae..10696240 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -56,6 +56,7 @@ from .utils import _move_model_to_device from ._parallel_utils import _data_parallel_wrapper from ._parallel_utils import _model_contains_inner_module from functools import partial +from ..io.logger import init_logger, get_logger __all__ = [ "Tester" @@ -103,6 +104,8 @@ class Tester(object): self.batch_size = batch_size self.verbose = verbose self.use_tqdm = use_tqdm + init_logger(stdout='tqdm' if use_tqdm else 'plain') + self.logger = get_logger(__name__) if isinstance(data, DataSet): self.data_iterator = DataSetIter( @@ -181,7 +184,8 @@ class Tester(object): end_time = time.time() test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!' - pbar.write(test_str) + # pbar.write(test_str) + self.logger.info(test_str) pbar.close() except _CheckError as e: prev_func_signature = _get_func_signature(self._predict_func) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 83882df0..d71e23f5 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -353,8 +353,7 @@ from .utils import _get_func_signature from .utils import _get_model_device from .utils import _move_model_to_device from ._parallel_utils import _model_contains_inner_module -from ..io.logger import initLogger -import logging +from ..io.logger import init_logger, get_logger class Trainer(object): @@ -552,8 +551,8 @@ class Trainer(object): log_path = None if save_path is not None: log_path = os.path.join(os.path.dirname(save_path), 'log') - initLogger(log_path) - self.logger = logging.getLogger(__name__) + init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain') + self.logger = get_logger(__name__) self.use_tqdm = use_tqdm self.pbar = None @@ -701,8 +700,8 @@ class Trainer(object): eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, self.n_steps) + \ self.tester._format_eval_results(eval_res) - pbar.write(eval_str + '\n') - + # pbar.write(eval_str + '\n') + self.logger.info(eval_str) # ================= mini-batch end ==================== # # lr decay; early stopping diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index f2826421..a49d203d 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -661,7 +661,7 @@ class _pseudo_tqdm: 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 """ def __init__(self, **kwargs): - self.logger = logging.getLogger() + self.logger = logging.getLogger(__name__) def write(self, info): self.logger.info(info) diff --git a/fastNLP/io/logger.py b/fastNLP/io/logger.py new file mode 100644 index 00000000..287bdbc9 --- /dev/null +++ b/fastNLP/io/logger.py @@ -0,0 +1,88 @@ +import logging +import logging.config +import torch +import _pickle as pickle +import os +import sys +import warnings + +try: + import fitlog +except ImportError: + fitlog = None +try: + from tqdm.auto import tqdm +except ImportError: + tqdm = None + +if tqdm is not None: + class TqdmLoggingHandler(logging.Handler): + def __init__(self, level=logging.INFO): + super().__init__(level) + + def emit(self, record): + try: + msg = self.format(record) + tqdm.write(msg) + self.flush() + except (KeyboardInterrupt, SystemExit): + raise + except: + self.handleError(record) +else: + class TqdmLoggingHandler(logging.StreamHandler): + def __init__(self, level=logging.INFO): + super().__init__(sys.stdout) + self.setLevel(level) + + +def init_logger(path=None, stdout='tqdm', level='INFO'): + """initialize logger""" + if stdout not in ['none', 'plain', 'tqdm']: + raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm'])) + + if isinstance(level, int): + pass + else: + level = level.lower() + level = {'info': logging.INFO, 'debug': logging.DEBUG, + 'warn': logging.WARN, 'warning': logging.WARN, + 'error': logging.ERROR}[level] + + logger = logging.getLogger('fastNLP') + logger.setLevel(level) + handlers_type = set([type(h) for h in logger.handlers]) + + # make sure to initialize logger only once + # Stream Handler + if stdout == 'plain' and (logging.StreamHandler not in handlers_type): + stream_handler = logging.StreamHandler(sys.stdout) + elif stdout == 'tqdm' and (TqdmLoggingHandler not in handlers_type): + stream_handler = TqdmLoggingHandler(level) + else: + stream_handler = None + + if stream_handler is not None: + stream_formatter = logging.Formatter('[%(levelname)s] %(message)s') + stream_handler.setLevel(level) + stream_handler.setFormatter(stream_formatter) + logger.addHandler(stream_handler) + + # File Handler + if path is not None and (logging.FileHandler not in handlers_type): + if os.path.exists(path): + assert os.path.isfile(path) + warnings.warn('log already exists in {}'.format(path)) + dirname = os.path.abspath(os.path.dirname(path)) + os.makedirs(dirname, exist_ok=True) + + file_handler = logging.FileHandler(path, mode='a') + file_handler.setLevel(level) + file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + return logger + +get_logger = logging.getLogger diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py index e4df00bf..99e27640 100644 --- a/reproduction/text_classification/train_dpcnn.py +++ b/reproduction/text_classification/train_dpcnn.py @@ -111,17 +111,17 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu' print(device) # 4.定义train方法 -# trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, -# sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), -# metrics=[metric], -# dev_data=datainfo.datasets['test'], device=device, -# check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, -# n_epochs=ops.train_epoch, num_workers=4) -trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, - metrics=[metric], - dev_data=datainfo.datasets['test'], device='cuda', - batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks, - n_epochs=ops.train_epoch, num_workers=4) +trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, + sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), + metrics=[metric], use_tqdm=False, + dev_data=datainfo.datasets['test'], device=device, + check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, + n_epochs=ops.train_epoch, num_workers=4) +# trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, +# metrics=[metric], +# dev_data=datainfo.datasets['test'], device='cuda', +# batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks, +# n_epochs=ops.train_epoch, num_workers=4) if __name__ == "__main__":