@@ -656,10 +656,13 @@ class EvaluateCallback(Callback): | |||||
for key, tester in self.testers.items(): | for key, tester in self.testers.items(): | ||||
try: | try: | ||||
eval_result = tester.test() | eval_result = tester.test() | ||||
self.pbar.write("Evaluation on {}:".format(key)) | |||||
self.pbar.write(tester._format_eval_results(eval_result)) | |||||
# self.pbar.write("Evaluation on {}:".format(key)) | |||||
self.logger.info("Evaluation on {}:".format(key)) | |||||
# self.pbar.write(tester._format_eval_results(eval_result)) | |||||
self.logger.info(tester._format_eval_results(eval_result)) | |||||
except Exception: | except Exception: | ||||
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||||
# self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||||
self.logger.info("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||||
class LRScheduler(Callback): | class LRScheduler(Callback): | ||||
@@ -22,7 +22,7 @@ from .optimizer import Optimizer | |||||
from .utils import _build_args | from .utils import _build_args | ||||
from .utils import _move_dict_value_to_device | from .utils import _move_dict_value_to_device | ||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from ..io.logger import initLogger | |||||
from ..io.logger import init_logger | |||||
from pkg_resources import parse_version | from pkg_resources import parse_version | ||||
__all__ = [ | __all__ = [ | ||||
@@ -140,7 +140,7 @@ class DistTrainer(): | |||||
self.cp_save_path = None | self.cp_save_path = None | ||||
# use INFO in the master, WARN for others | # use INFO in the master, WARN for others | ||||
initLogger(log_path, level=logging.INFO if self.is_master else logging.WARNING) | |||||
init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING) | |||||
self.logger = logging.getLogger(__name__) | self.logger = logging.getLogger(__name__) | ||||
self.logger.info("Setup Distributed Trainer") | self.logger.info("Setup Distributed Trainer") | ||||
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( | self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( | ||||
@@ -56,6 +56,7 @@ from .utils import _move_model_to_device | |||||
from ._parallel_utils import _data_parallel_wrapper | from ._parallel_utils import _data_parallel_wrapper | ||||
from ._parallel_utils import _model_contains_inner_module | from ._parallel_utils import _model_contains_inner_module | ||||
from functools import partial | from functools import partial | ||||
from ..io.logger import init_logger, get_logger | |||||
__all__ = [ | __all__ = [ | ||||
"Tester" | "Tester" | ||||
@@ -103,6 +104,8 @@ class Tester(object): | |||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.verbose = verbose | self.verbose = verbose | ||||
self.use_tqdm = use_tqdm | self.use_tqdm = use_tqdm | ||||
init_logger(stdout='tqdm' if use_tqdm else 'plain') | |||||
self.logger = get_logger(__name__) | |||||
if isinstance(data, DataSet): | if isinstance(data, DataSet): | ||||
self.data_iterator = DataSetIter( | self.data_iterator = DataSetIter( | ||||
@@ -181,7 +184,8 @@ class Tester(object): | |||||
end_time = time.time() | end_time = time.time() | ||||
test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!' | test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!' | ||||
pbar.write(test_str) | |||||
# pbar.write(test_str) | |||||
self.logger.info(test_str) | |||||
pbar.close() | pbar.close() | ||||
except _CheckError as e: | except _CheckError as e: | ||||
prev_func_signature = _get_func_signature(self._predict_func) | prev_func_signature = _get_func_signature(self._predict_func) | ||||
@@ -353,8 +353,7 @@ from .utils import _get_func_signature | |||||
from .utils import _get_model_device | from .utils import _get_model_device | ||||
from .utils import _move_model_to_device | from .utils import _move_model_to_device | ||||
from ._parallel_utils import _model_contains_inner_module | from ._parallel_utils import _model_contains_inner_module | ||||
from ..io.logger import initLogger | |||||
import logging | |||||
from ..io.logger import init_logger, get_logger | |||||
class Trainer(object): | class Trainer(object): | ||||
@@ -552,8 +551,8 @@ class Trainer(object): | |||||
log_path = None | log_path = None | ||||
if save_path is not None: | if save_path is not None: | ||||
log_path = os.path.join(os.path.dirname(save_path), 'log') | log_path = os.path.join(os.path.dirname(save_path), 'log') | ||||
initLogger(log_path) | |||||
self.logger = logging.getLogger(__name__) | |||||
init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain') | |||||
self.logger = get_logger(__name__) | |||||
self.use_tqdm = use_tqdm | self.use_tqdm = use_tqdm | ||||
self.pbar = None | self.pbar = None | ||||
@@ -701,8 +700,8 @@ class Trainer(object): | |||||
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | ||||
self.n_steps) + \ | self.n_steps) + \ | ||||
self.tester._format_eval_results(eval_res) | self.tester._format_eval_results(eval_res) | ||||
pbar.write(eval_str + '\n') | |||||
# pbar.write(eval_str + '\n') | |||||
self.logger.info(eval_str) | |||||
# ================= mini-batch end ==================== # | # ================= mini-batch end ==================== # | ||||
# lr decay; early stopping | # lr decay; early stopping | ||||
@@ -661,7 +661,7 @@ class _pseudo_tqdm: | |||||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | ||||
""" | """ | ||||
def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
self.logger = logging.getLogger() | |||||
self.logger = logging.getLogger(__name__) | |||||
def write(self, info): | def write(self, info): | ||||
self.logger.info(info) | self.logger.info(info) | ||||
@@ -0,0 +1,88 @@ | |||||
import logging | |||||
import logging.config | |||||
import torch | |||||
import _pickle as pickle | |||||
import os | |||||
import sys | |||||
import warnings | |||||
try: | |||||
import fitlog | |||||
except ImportError: | |||||
fitlog = None | |||||
try:
    from tqdm.auto import tqdm
except ImportError:
    # tqdm is optional; the fallback handler below is used when it is missing.
    tqdm = None

if tqdm is not None:
    class TqdmLoggingHandler(logging.Handler):
        """Logging handler that prints each record through ``tqdm.write``.

        This variant is defined only when ``tqdm`` could be imported.

        :param level: minimum level this handler emits, ``logging.INFO``
            by default.
        """

        def __init__(self, level=logging.INFO):
            super().__init__(level)

        def emit(self, record):
            """Format ``record`` and print it with ``tqdm.write``.

            Ordinary errors are passed to ``handleError``.
            ``KeyboardInterrupt``, ``SystemExit`` and other non-``Exception``
            ``BaseException`` subclasses are not caught and propagate to
            the caller.
            """
            try:
                msg = self.format(record)
                tqdm.write(msg)
                self.flush()
            except Exception:
                self.handleError(record)
else:
    class TqdmLoggingHandler(logging.StreamHandler):
        """Fallback used when ``tqdm`` is not installed.

        A ``logging.StreamHandler`` bound to ``sys.stdout`` that keeps the
        same constructor signature as the tqdm-backed variant.

        :param level: minimum level this handler emits, ``logging.INFO``
            by default.
        """

        def __init__(self, level=logging.INFO):
            super().__init__(sys.stdout)
            self.setLevel(level)
def init_logger(path=None, stdout='tqdm', level='INFO'):
    """Configure and return the shared ``'fastNLP'`` logger.

    Handlers are only added when the logger does not already hold a handler
    of exactly the same class, so calling this function repeatedly with the
    same arguments does not duplicate output.

    NOTE(review): the de-duplication is per handler class. Calling once with
    ``stdout='tqdm'`` and later with ``stdout='plain'`` attaches both
    handlers -- confirm whether that is intended.

    :param path: file to append log records to, or ``None`` for no file
        logging. Missing parent directories are created.
    :param stdout: one of ``'none'``, ``'plain'`` (``StreamHandler`` on
        ``sys.stdout``) or ``'tqdm'`` (``TqdmLoggingHandler``).
    :param level: an ``int`` logging level, or a case-insensitive name:
        ``'debug'``, ``'info'``, ``'warn'``/``'warning'``, ``'error'``,
        ``'critical'``/``'fatal'``.
    :return: the ``logging.Logger`` named ``'fastNLP'``.
    :raises ValueError: if ``stdout`` is not a supported value, or ``path``
        exists but is not a regular file.
    :raises KeyError: if ``level`` is a string that names no known level.
    """
    stdout_choices = ['none', 'plain', 'tqdm']
    if stdout not in stdout_choices:
        raise ValueError('stdout must in one of {}'.format(stdout_choices))

    if not isinstance(level, int):
        level_names = {
            'debug': logging.DEBUG,
            'info': logging.INFO,
            'warn': logging.WARNING,
            'warning': logging.WARNING,
            'error': logging.ERROR,
            'critical': logging.CRITICAL,
            'fatal': logging.CRITICAL,
        }
        level = level_names[level.lower()]

    logger = logging.getLogger('fastNLP')
    logger.setLevel(level)

    # Exact classes of the handlers already attached; used so that each kind
    # of handler is installed at most once.
    handlers_type = {type(h) for h in logger.handlers}

    # Stream handler
    stream_handler = None
    if stdout == 'plain' and logging.StreamHandler not in handlers_type:
        stream_handler = logging.StreamHandler(sys.stdout)
    elif stdout == 'tqdm' and TqdmLoggingHandler not in handlers_type:
        stream_handler = TqdmLoggingHandler(level)

    if stream_handler is not None:
        stream_formatter = logging.Formatter('[%(levelname)s] %(message)s')
        stream_handler.setLevel(level)
        stream_handler.setFormatter(stream_formatter)
        logger.addHandler(stream_handler)

    # File handler
    if path is not None and logging.FileHandler not in handlers_type:
        if os.path.exists(path):
            # Raise explicitly instead of asserting: asserts are stripped
            # when Python runs with -O.
            if not os.path.isfile(path):
                raise ValueError(
                    'log path {} exists but is not a regular file'.format(path))
            warnings.warn('log already exists in {}'.format(path))
        dirname = os.path.abspath(os.path.dirname(path))
        os.makedirs(dirname, exist_ok=True)

        file_handler = logging.FileHandler(path, mode='a')
        file_handler.setLevel(level)
        file_formatter = logging.Formatter(
            fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
            datefmt='%Y/%m/%d %H:%M:%S')
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    return logger


# Alias of ``logging.getLogger`` exposed under this module's name.
get_logger = logging.getLogger
@@ -111,17 +111,17 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||||
print(device) | print(device) | ||||
# 4.定义train方法 | # 4.定义train方法 | ||||
# trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
# sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), | |||||
# metrics=[metric], | |||||
# dev_data=datainfo.datasets['test'], device=device, | |||||
# check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, | |||||
# n_epochs=ops.train_epoch, num_workers=4) | |||||
trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
metrics=[metric], | |||||
dev_data=datainfo.datasets['test'], device='cuda', | |||||
batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks, | |||||
n_epochs=ops.train_epoch, num_workers=4) | |||||
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), | |||||
metrics=[metric], use_tqdm=False, | |||||
dev_data=datainfo.datasets['test'], device=device, | |||||
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, | |||||
n_epochs=ops.train_epoch, num_workers=4) | |||||
# trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||||
# metrics=[metric], | |||||
# dev_data=datainfo.datasets['test'], device='cuda', | |||||
# batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks, | |||||
# n_epochs=ops.train_epoch, num_workers=4) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||