@@ -656,10 +656,13 @@ class EvaluateCallback(Callback):
         for key, tester in self.testers.items():
             try:
                 eval_result = tester.test()
-                self.pbar.write("Evaluation on {}:".format(key))
-                self.pbar.write(tester._format_eval_results(eval_result))
+                # self.pbar.write("Evaluation on {}:".format(key))
+                self.logger.info("Evaluation on {}:".format(key))
+                # self.pbar.write(tester._format_eval_results(eval_result))
+                self.logger.info(tester._format_eval_results(eval_result))
             except Exception:
-                self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
+                # self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
+                self.logger.info("Exception happened when evaluating on DataSet named `{}`.".format(key))


 class LRScheduler(Callback):
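These messages can move from pbar.write onto the logger because the logger configured in the new io/logger.py (further down) routes stdout through tqdm.write, so records still print cleanly above an active progress bar while also reaching the log file. A minimal standalone sketch of that mechanism, not fastNLP code, assuming only that tqdm is installed:

import logging
import time
from tqdm.auto import tqdm

class DemoTqdmHandler(logging.Handler):
    # Hypothetical demo handler: route records through tqdm.write so an
    # active progress bar redraws below the message instead of being garbled.
    def emit(self, record):
        tqdm.write(self.format(record))

logger = logging.getLogger('demo')
logger.setLevel(logging.INFO)
logger.addHandler(DemoTqdmHandler())

for step in tqdm(range(5)):
    time.sleep(0.1)
    logger.info('step %d done', step)  # appears above the bar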
@@ -22,7 +22,7 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
-from ..io.logger import initLogger
+from ..io.logger import init_logger
 from pkg_resources import parse_version

 __all__ = [
@@ -140,7 +140,7 @@ class DistTrainer():
         self.cp_save_path = None

         # use INFO in the master, WARN for others
-        initLogger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
+        init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
         self.logger = logging.getLogger(__name__)
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
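Using INFO on the master and WARNING everywhere else keeps multi-process output to one readable stream. A sketch of the same rank-aware pattern written directly against torch.distributed (hypothetical standalone helper, assuming the process group is already initialized):

import logging
import torch.distributed as dist

def set_rank_aware_level(logger_name='fastNLP'):
    # Hypothetical helper: rank 0 stays at INFO, all other ranks at WARNING.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    level = logging.INFO if rank == 0 else logging.WARNING
    logging.getLogger(logger_name).setLevel(level)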
@@ -56,6 +56,7 @@ from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
 from ._parallel_utils import _model_contains_inner_module
 from functools import partial
+from ..io.logger import init_logger, get_logger

 __all__ = [
     "Tester"
@@ -103,6 +104,8 @@ class Tester(object):
         self.batch_size = batch_size
         self.verbose = verbose
         self.use_tqdm = use_tqdm
+        init_logger(stdout='tqdm' if use_tqdm else 'plain')
+        self.logger = get_logger(__name__)

         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(
@@ -181,7 +184,8 @@ class Tester(object):
                 end_time = time.time()
                 test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
-                pbar.write(test_str)
+                # pbar.write(test_str)
+                self.logger.info(test_str)
                 pbar.close()
         except _CheckError as e:
             prev_func_signature = _get_func_signature(self._predict_func)
@@ -353,8 +353,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _model_contains_inner_module
-from ..io.logger import initLogger
-import logging
+from ..io.logger import init_logger, get_logger


 class Trainer(object):
@@ -552,8 +551,8 @@ class Trainer(object):
         log_path = None
         if save_path is not None:
             log_path = os.path.join(os.path.dirname(save_path), 'log')
-        initLogger(log_path)
-        self.logger = logging.getLogger(__name__)
+        init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain')
+        self.logger = get_logger(__name__)

         self.use_tqdm = use_tqdm
         self.pbar = None
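For orientation, the log file lands next to the checkpoints. A tiny illustration with a hypothetical save_path:

import os

save_path = 'checkpoints/model.pt'  # hypothetical example
log_path = os.path.join(os.path.dirname(save_path), 'log')
print(log_path)  # checkpoints/log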
@@ -701,8 +700,8 @@ class Trainer(object):
                         eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                            self.n_steps) + \
                                    self.tester._format_eval_results(eval_res)
-                        pbar.write(eval_str + '\n')
+                        # pbar.write(eval_str + '\n')
+                        self.logger.info(eval_str)

                 # ================= mini-batch end ==================== #

             # lr decay; early stopping
@@ -661,7 +661,7 @@ class _pseudo_tqdm:
     Used to print data when tqdm cannot be imported, or when use_tqdm is set to False in Trainer.
     """
     def __init__(self, **kwargs):
-        self.logger = logging.getLogger()
+        self.logger = logging.getLogger(__name__)

     def write(self, info):
         self.logger.info(info)
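Passing __name__ instead of grabbing the root logger matters because init_logger attaches its handlers to the 'fastNLP' logger, and any logger named under that package propagates records up to it. A small standalone illustration of the propagation:

import logging
import sys

pkg = logging.getLogger('fastNLP')
pkg.setLevel(logging.INFO)
pkg.addHandler(logging.StreamHandler(sys.stdout))

child = logging.getLogger('fastNLP.core.trainer')  # what __name__ resolves to inside trainer.py
child.info('reaches the fastNLP handler via propagation')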
@@ -0,0 +1,88 @@
+import logging
+import logging.config
+import torch
+import _pickle as pickle
+import os
+import sys
+import warnings
+
+try:
+    import fitlog
+except ImportError:
+    fitlog = None
+try:
+    from tqdm.auto import tqdm
+except ImportError:
+    tqdm = None
+
+if tqdm is not None:
+    class TqdmLoggingHandler(logging.Handler):
+        """Emits log records via tqdm.write() so they print above a live progress bar."""
+
+        def __init__(self, level=logging.INFO):
+            super().__init__(level)
+
+        def emit(self, record):
+            try:
+                msg = self.format(record)
+                tqdm.write(msg)
+                self.flush()
+            except (KeyboardInterrupt, SystemExit):
+                raise
+            except Exception:
+                self.handleError(record)
+else:
+    # Without tqdm, fall back to a plain stdout stream handler.
+    class TqdmLoggingHandler(logging.StreamHandler):
+        def __init__(self, level=logging.INFO):
+            super().__init__(sys.stdout)
+            self.setLevel(level)
+
+
+def init_logger(path=None, stdout='tqdm', level='INFO'):
+    """Initialize the 'fastNLP' logger with an optional log file and a tqdm-aware stdout handler."""
+    if stdout not in ['none', 'plain', 'tqdm']:
+        raise ValueError('stdout must be one of {}'.format(['none', 'plain', 'tqdm']))
+
+    if not isinstance(level, int):
+        level = {'info': logging.INFO, 'debug': logging.DEBUG,
+                 'warn': logging.WARN, 'warning': logging.WARN,
+                 'error': logging.ERROR}[level.lower()]
+
+    logger = logging.getLogger('fastNLP')
+    logger.setLevel(level)
+    # Record the handler types already attached, so calling init_logger
+    # more than once never adds duplicate handlers.
+    handlers_type = {type(h) for h in logger.handlers}
+
+    # Stream handler
+    if stdout == 'plain' and logging.StreamHandler not in handlers_type:
+        stream_handler = logging.StreamHandler(sys.stdout)
+    elif stdout == 'tqdm' and TqdmLoggingHandler not in handlers_type:
+        stream_handler = TqdmLoggingHandler(level)
+    else:
+        stream_handler = None
+
+    if stream_handler is not None:
+        stream_formatter = logging.Formatter('[%(levelname)s] %(message)s')
+        stream_handler.setLevel(level)
+        stream_handler.setFormatter(stream_formatter)
+        logger.addHandler(stream_handler)
+
+    # File handler
+    if path is not None and logging.FileHandler not in handlers_type:
+        if os.path.exists(path):
+            assert os.path.isfile(path)
+            warnings.warn('log already exists in {}'.format(path))
+        dirname = os.path.abspath(os.path.dirname(path))
+        os.makedirs(dirname, exist_ok=True)
+
+        file_handler = logging.FileHandler(path, mode='a')
+        file_handler.setLevel(level)
+        file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
+                                           datefmt='%Y/%m/%d %H:%M:%S')
+        file_handler.setFormatter(file_formatter)
+        logger.addHandler(file_handler)
+
+    return logger
+
+
+# Alias so callers fetch module loggers that propagate up to 'fastNLP'.
+get_logger = logging.getLogger
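Hypothetical usage of the module above, showing why the handler-type bookkeeping makes repeated initialization safe (both Trainer and Tester call init_logger); the path 'outputs/log' is an assumed example:

from fastNLP.io.logger import init_logger, get_logger

init_logger(path='outputs/log', stdout='tqdm', level='info')
init_logger(stdout='tqdm')  # no-op: a TqdmLoggingHandler is already attached

logger = get_logger('fastNLP.demo')  # child of 'fastNLP', so it inherits the handlers
logger.info('goes above any live tqdm bar and is appended to outputs/log')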
@@ -111,17 +111,17 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 print(device)

 # 4. Define the train method
-# trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-#                   sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
-#                   metrics=[metric],
-#                   dev_data=datainfo.datasets['test'], device=device,
-#                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
-#                   n_epochs=ops.train_epoch, num_workers=4)
-trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                      metrics=[metric],
-                      dev_data=datainfo.datasets['test'], device='cuda',
-                      batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
-                      n_epochs=ops.train_epoch, num_workers=4)
+trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
+                  metrics=[metric], use_tqdm=False,
+                  dev_data=datainfo.datasets['test'], device=device,
+                  check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+                  n_epochs=ops.train_epoch, num_workers=4)
+# trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+#                       metrics=[metric],
+#                       dev_data=datainfo.datasets['test'], device='cuda',
+#                       batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
+#                       n_epochs=ops.train_epoch, num_workers=4)

 if __name__ == "__main__":