
[update] logger in trainer & tester

tags/v0.4.10
yunfan 5 years ago
commit 007c047ae7
7 changed files with 118 additions and 24 deletions
  1. fastNLP/core/callback.py (+6 -3)
  2. fastNLP/core/dist_trainer.py (+2 -2)
  3. fastNLP/core/tester.py (+5 -1)
  4. fastNLP/core/trainer.py (+5 -6)
  5. fastNLP/core/utils.py (+1 -1)
  6. fastNLP/io/logger.py (+88 -0)
  7. reproduction/text_classification/train_dpcnn.py (+11 -11)

+ 6
- 3
fastNLP/core/callback.py

@@ -656,10 +656,13 @@ class EvaluateCallback(Callback):
         for key, tester in self.testers.items():
             try:
                 eval_result = tester.test()
-                self.pbar.write("Evaluation on {}:".format(key))
-                self.pbar.write(tester._format_eval_results(eval_result))
+                # self.pbar.write("Evaluation on {}:".format(key))
+                self.logger.info("Evaluation on {}:".format(key))
+                # self.pbar.write(tester._format_eval_results(eval_result))
+                self.logger.info(tester._format_eval_results(eval_result))
             except Exception:
-                self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
+                # self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
+                self.logger.info("Exception happens when evaluate on DataSet named `{}`.".format(key))


 class LRScheduler(Callback):


+ 2
- 2
fastNLP/core/dist_trainer.py

@@ -22,7 +22,7 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
-from ..io.logger import initLogger
+from ..io.logger import init_logger
 from pkg_resources import parse_version

 __all__ = [
@@ -140,7 +140,7 @@ class DistTrainer():
         self.cp_save_path = None

         # use INFO in the master, WARN for others
-        initLogger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
+        init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
         self.logger = logging.getLogger(__name__)
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(


+ 5
- 1
fastNLP/core/tester.py

@@ -56,6 +56,7 @@ from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
 from ._parallel_utils import _model_contains_inner_module
 from functools import partial
+from ..io.logger import init_logger, get_logger

 __all__ = [
     "Tester"
@@ -103,6 +104,8 @@ class Tester(object):
         self.batch_size = batch_size
         self.verbose = verbose
         self.use_tqdm = use_tqdm
+        init_logger(stdout='tqdm' if use_tqdm else 'plain')
+        self.logger = get_logger(__name__)

         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(
@@ -181,7 +184,8 @@ class Tester(object):

             end_time = time.time()
             test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
-            pbar.write(test_str)
+            # pbar.write(test_str)
+            self.logger.info(test_str)
             pbar.close()
         except _CheckError as e:
             prev_func_signature = _get_func_signature(self._predict_func)


+ 5
- 6
fastNLP/core/trainer.py

@@ -353,8 +353,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _model_contains_inner_module
-from ..io.logger import initLogger
-import logging
+from ..io.logger import init_logger, get_logger


 class Trainer(object):
@@ -552,8 +551,8 @@ class Trainer(object):
         log_path = None
         if save_path is not None:
             log_path = os.path.join(os.path.dirname(save_path), 'log')
-        initLogger(log_path)
-        self.logger = logging.getLogger(__name__)
+        init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain')
+        self.logger = get_logger(__name__)

         self.use_tqdm = use_tqdm
         self.pbar = None
@@ -701,8 +700,8 @@ class Trainer(object):
                     eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                        self.n_steps) + \
                                self.tester._format_eval_results(eval_res)
-                    pbar.write(eval_str + '\n')
+                    # pbar.write(eval_str + '\n')
+                    self.logger.info(eval_str)
                 # ================= mini-batch end ==================== #

                 # lr decay; early stopping


+ 1
- 1
fastNLP/core/utils.py

@@ -661,7 +661,7 @@ class _pseudo_tqdm:
     当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
     """
     def __init__(self, **kwargs):
-        self.logger = logging.getLogger()
+        self.logger = logging.getLogger(__name__)

     def write(self, info):
         self.logger.info(info)


+ 88
- 0
fastNLP/io/logger.py

@@ -0,0 +1,88 @@
import logging
import logging.config
import torch
import _pickle as pickle
import os
import sys
import warnings

try:
    import fitlog
except ImportError:
    fitlog = None
try:
    from tqdm.auto import tqdm
except ImportError:
    tqdm = None

if tqdm is not None:
    class TqdmLoggingHandler(logging.Handler):
        def __init__(self, level=logging.INFO):
            super().__init__(level)

        def emit(self, record):
            try:
                msg = self.format(record)
                tqdm.write(msg)
                self.flush()
            except (KeyboardInterrupt, SystemExit):
                raise
            except:
                self.handleError(record)
else:
    class TqdmLoggingHandler(logging.StreamHandler):
        def __init__(self, level=logging.INFO):
            super().__init__(sys.stdout)
            self.setLevel(level)


def init_logger(path=None, stdout='tqdm', level='INFO'):
    """initialize logger"""
    if stdout not in ['none', 'plain', 'tqdm']:
        raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))

    if isinstance(level, int):
        pass
    else:
        level = level.lower()
        level = {'info': logging.INFO, 'debug': logging.DEBUG,
                 'warn': logging.WARN, 'warning': logging.WARN,
                 'error': logging.ERROR}[level]

    logger = logging.getLogger('fastNLP')
    logger.setLevel(level)
    handlers_type = set([type(h) for h in logger.handlers])

    # make sure to initialize logger only once
    # Stream Handler
    if stdout == 'plain' and (logging.StreamHandler not in handlers_type):
        stream_handler = logging.StreamHandler(sys.stdout)
    elif stdout == 'tqdm' and (TqdmLoggingHandler not in handlers_type):
        stream_handler = TqdmLoggingHandler(level)
    else:
        stream_handler = None

    if stream_handler is not None:
        stream_formatter = logging.Formatter('[%(levelname)s] %(message)s')
        stream_handler.setLevel(level)
        stream_handler.setFormatter(stream_formatter)
        logger.addHandler(stream_handler)

    # File Handler
    if path is not None and (logging.FileHandler not in handlers_type):
        if os.path.exists(path):
            assert os.path.isfile(path)
            warnings.warn('log already exists in {}'.format(path))
        dirname = os.path.abspath(os.path.dirname(path))
        os.makedirs(dirname, exist_ok=True)

        file_handler = logging.FileHandler(path, mode='a')
        file_handler.setLevel(level)
        file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
                                           datefmt='%Y/%m/%d %H:%M:%S')
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    return logger

get_logger = logging.getLogger
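
For orientation, a minimal usage sketch of the two helpers (not part of the commit), following the pattern Trainer and Tester adopt above; the path 'train_log' and the logger name 'fastNLP.my_script' are assumed values for illustration:

from fastNLP.io.logger import init_logger, get_logger

use_tqdm = True  # Trainer/Tester pass stdout='tqdm' if use_tqdm else 'plain'
init_logger(path='train_log', stdout='tqdm' if use_tqdm else 'plain', level='INFO')

# init_logger configures the 'fastNLP' logger, so records reach its handlers only
# from loggers under that namespace (inside the library, __name__ already is).
logger = get_logger('fastNLP.my_script')  # assumed name, stands in for __name__
logger.info('Evaluate data in 1.23 seconds!')  # echoed via tqdm.write (or plain stdout) and appended to train_log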

+ 11
- 11
reproduction/text_classification/train_dpcnn.py

@@ -111,17 +111,17 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 print(device)

 # 4.定义train方法
-# trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-#                   sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
-#                   metrics=[metric],
-#                   dev_data=datainfo.datasets['test'], device=device,
-#                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
-#                   n_epochs=ops.train_epoch, num_workers=4)
-trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                      metrics=[metric],
-                      dev_data=datainfo.datasets['test'], device='cuda',
-                      batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
-                      n_epochs=ops.train_epoch, num_workers=4)
+trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
+                  metrics=[metric], use_tqdm=False,
+                  dev_data=datainfo.datasets['test'], device=device,
+                  check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+                  n_epochs=ops.train_epoch, num_workers=4)
+# trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+#                       metrics=[metric],
+#                       dev_data=datainfo.datasets['test'], device='cuda',
+#                       batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
+#                       n_epochs=ops.train_epoch, num_workers=4)


 if __name__ == "__main__":

