@@ -86,7 +86,7 @@ except: | |||
from ..io.model_io import ModelSaver, ModelLoader | |||
from .dataset import DataSet | |||
from .tester import Tester | |||
import logging | |||
from ..io import logger | |||
try: | |||
import fitlog | |||
@@ -178,7 +178,7 @@ class Callback(object): | |||
@property | |||
def logger(self): | |||
return getattr(self._trainer, 'logger', logging.getLogger(__name__)) | |||
return getattr(self._trainer, 'logger', logger) | |||
def on_train_begin(self): | |||
""" | |||
@@ -9,7 +9,6 @@ from torch.utils.data.distributed import DistributedSampler | |||
from torch.nn.parallel import DistributedDataParallel as DDP | |||
import os | |||
from tqdm import tqdm | |||
import logging | |||
import time | |||
from datetime import datetime, timedelta | |||
from functools import partial | |||
@@ -22,7 +21,8 @@ from .optimizer import Optimizer | |||
from .utils import _build_args | |||
from .utils import _move_dict_value_to_device | |||
from .utils import _get_func_signature | |||
from ..io.logger import init_logger | |||
from ..io import logger | |||
import logging | |||
from pkg_resources import parse_version | |||
__all__ = [ | |||
@@ -140,8 +140,8 @@ class DistTrainer(): | |||
self.cp_save_path = None | |||
# use INFO in the master, WARN for others | |||
init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING) | |||
self.logger = logging.getLogger(__name__) | |||
logger.setLevel(logging.INFO if self.is_master else logging.WARNING) | |||
self.logger = logger | |||
self.logger.info("Setup Distributed Trainer") | |||
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( | |||
os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False)) | |||
@@ -56,7 +56,7 @@ from .utils import _move_model_to_device | |||
from ._parallel_utils import _data_parallel_wrapper | |||
from ._parallel_utils import _model_contains_inner_module | |||
from functools import partial | |||
from ..io.logger import init_logger, get_logger | |||
from ..io import logger | |||
__all__ = [ | |||
"Tester" | |||
@@ -104,8 +104,7 @@ class Tester(object): | |||
self.batch_size = batch_size | |||
self.verbose = verbose | |||
self.use_tqdm = use_tqdm | |||
init_logger(stdout='tqdm' if use_tqdm else 'plain') | |||
self.logger = get_logger(__name__) | |||
self.logger = logger | |||
if isinstance(data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
@@ -353,7 +353,7 @@ from .utils import _get_func_signature | |||
from .utils import _get_model_device | |||
from .utils import _move_model_to_device | |||
from ._parallel_utils import _model_contains_inner_module | |||
from ..io.logger import init_logger, get_logger | |||
from ..io import logger | |||
class Trainer(object): | |||
@@ -548,11 +548,7 @@ class Trainer(object): | |||
else: | |||
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) | |||
log_path = None | |||
if save_path is not None: | |||
log_path = os.path.join(os.path.dirname(save_path), 'log') | |||
init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain') | |||
self.logger = get_logger(__name__) | |||
self.logger = logger | |||
self.use_tqdm = use_tqdm | |||
self.pbar = None | |||
@@ -701,7 +697,7 @@ class Trainer(object): | |||
self.n_steps) + \ | |||
self.tester._format_eval_results(eval_res) | |||
# pbar.write(eval_str + '\n') | |||
self.logger.info(eval_str) | |||
self.logger.info(eval_str + '\n') | |||
# ================= mini-batch end ==================== # | |||
# lr decay; early stopping | |||
@@ -72,6 +72,8 @@ __all__ = [ | |||
'ModelLoader', | |||
'ModelSaver', | |||
'logger', | |||
] | |||
from .embed_loader import EmbedLoader | |||
@@ -81,3 +83,4 @@ from .model_io import ModelLoader, ModelSaver | |||
from .loader import * | |||
from .pipe import * | |||
from ._logger import * |
@@ -6,8 +6,11 @@ import os | |||
import sys | |||
import warnings | |||
__all__ = [ | |||
'logger', | |||
] | |||
__all__ = ['logger'] | |||
ROOT_NAME = 'fastNLP' | |||
try: | |||
import fitlog | |||
@@ -39,7 +42,7 @@ else: | |||
self.setLevel(level) | |||
def get_level(level): | |||
def _get_level(level): | |||
if isinstance(level, int): | |||
pass | |||
else: | |||
@@ -50,22 +53,45 @@ def get_level(level): | |||
return level | |||
def init_logger(path=None, stdout='tqdm', level='INFO'): | |||
"""initialize logger""" | |||
if stdout not in ['none', 'plain', 'tqdm']: | |||
raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm'])) | |||
def _add_file_handler(logger, path, level='INFO'): | |||
for h in logger.handlers: | |||
if isinstance(h, logging.FileHandler): | |||
if os.path.abspath(path) == h.baseFilename: | |||
# file path already added | |||
return | |||
level = get_level(level) | |||
# File Handler | |||
if os.path.exists(path): | |||
assert os.path.isfile(path) | |||
warnings.warn('log already exists in {}'.format(path)) | |||
dirname = os.path.abspath(os.path.dirname(path)) | |||
os.makedirs(dirname, exist_ok=True) | |||
logger = logging.getLogger('fastNLP') | |||
logger.setLevel(level) | |||
handlers_type = set([type(h) for h in logger.handlers]) | |||
file_handler = logging.FileHandler(path, mode='a') | |||
file_handler.setLevel(_get_level(level)) | |||
file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(message)s', | |||
datefmt='%Y/%m/%d %H:%M:%S') | |||
file_handler.setFormatter(file_formatter) | |||
logger.addHandler(file_handler) | |||
def _set_stdout_handler(logger, stdout='tqdm', level='INFO'): | |||
level = _get_level(level) | |||
if stdout not in ['none', 'plain', 'tqdm']: | |||
raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm'])) | |||
# make sure to initialize logger only once | |||
stream_handler = None | |||
for i, h in enumerate(logger.handlers): | |||
if isinstance(h, (logging.StreamHandler, TqdmLoggingHandler)): | |||
stream_handler = h | |||
break | |||
if stream_handler is not None: | |||
logger.removeHandler(stream_handler) | |||
# Stream Handler | |||
if stdout == 'plain' and (logging.StreamHandler not in handlers_type): | |||
if stdout == 'plain': | |||
stream_handler = logging.StreamHandler(sys.stdout) | |||
elif stdout == 'tqdm' and (TqdmLoggingHandler not in handlers_type): | |||
elif stdout == 'tqdm': | |||
stream_handler = TqdmLoggingHandler(level) | |||
else: | |||
stream_handler = None | |||
@@ -76,52 +102,44 @@ def init_logger(path=None, stdout='tqdm', level='INFO'): | |||
stream_handler.setFormatter(stream_formatter) | |||
logger.addHandler(stream_handler) | |||
# File Handler | |||
if path is not None and (logging.FileHandler not in handlers_type): | |||
if os.path.exists(path): | |||
assert os.path.isfile(path) | |||
warnings.warn('log already exists in {}'.format(path)) | |||
dirname = os.path.abspath(os.path.dirname(path)) | |||
os.makedirs(dirname, exist_ok=True) | |||
file_handler = logging.FileHandler(path, mode='a') | |||
file_handler.setLevel(level) | |||
file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s', | |||
datefmt='%Y/%m/%d %H:%M:%S') | |||
file_handler.setFormatter(file_formatter) | |||
logger.addHandler(file_handler) | |||
return logger | |||
def _init_logger(path=None, stdout='tqdm', level='INFO'): | |||
"""initialize logger""" | |||
level = _get_level(level) | |||
# logger = logging.getLogger(ROOT_NAME) | |||
logger = logging.getLogger() | |||
logger.setLevel(level) | |||
# init logger when import | |||
logger = init_logger() | |||
_set_stdout_handler(logger, stdout, level) | |||
# File Handler | |||
if path is not None: | |||
_add_file_handler(logger, path, level) | |||
def get_logger(name=None): | |||
if name is None: | |||
return logging.getLogger('fastNLP') | |||
return logging.getLogger(name) | |||
return logger | |||
def set_file(path, level='INFO'): | |||
for h in logger.handlers: | |||
if isinstance(h, logging.FileHandler): | |||
if os.path.abspath(path) == h.baseFilename: | |||
# file path already added | |||
return | |||
def _get_logger(name=None, level='INFO'): | |||
level = _get_level(level) | |||
if name is None: | |||
name = ROOT_NAME | |||
assert isinstance(name, str) | |||
if not name.startswith(ROOT_NAME): | |||
name = '{}.{}'.format(ROOT_NAME, name) | |||
logger = logging.getLogger(name) | |||
logger.setLevel(level) | |||
return logger | |||
# File Handler | |||
if os.path.exists(path): | |||
assert os.path.isfile(path) | |||
warnings.warn('log already exists in {}'.format(path)) | |||
dirname = os.path.abspath(os.path.dirname(path)) | |||
os.makedirs(dirname, exist_ok=True) | |||
file_handler = logging.FileHandler(path, mode='a') | |||
file_handler.setLevel(get_level(level)) | |||
file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s', | |||
datefmt='%Y/%m/%d %H:%M:%S') | |||
file_handler.setFormatter(file_formatter) | |||
logger.addHandler(file_handler) | |||
class FastNLPLogger(logging.Logger): | |||
def add_file(self, path, level): | |||
_add_file_handler(self, path, level) | |||
def set_stdout(self, stdout, level): | |||
_set_stdout_handler(self, stdout, level) | |||
_logger = _init_logger(path=None) | |||
logger = FastNLPLogger(ROOT_NAME) | |||
logger.__dict__.update(_logger.__dict__) | |||
del _logger |
@@ -15,13 +15,14 @@ from fastNLP.core.const import Const as C | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.core.dist_trainer import DistTrainer | |||
from utils.util_init import set_rng_seeds | |||
from fastNLP.io import logger | |||
import os | |||
# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | |||
# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | |||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |||
# hyper | |||
logger.add_file('log', 'INFO') | |||
class Config(): | |||
seed = 12345 | |||
@@ -46,11 +47,11 @@ class Config(): | |||
self.datapath = {k: os.path.join(self.datadir, v) | |||
for k, v in self.datafile.items()} | |||
ops = Config() | |||
set_rng_seeds(ops.seed) | |||
print('RNG SEED: {}'.format(ops.seed)) | |||
# print('RNG SEED: {}'.format(ops.seed)) | |||
logger.info('RNG SEED %d'%ops.seed) | |||
# 1.task相关信息:利用dataloader载入dataInfo | |||
@@ -81,8 +82,9 @@ print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.st | |||
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)]) | |||
datainfo.datasets['train'] = datainfo.datasets['train'][:1000] | |||
datainfo.datasets['test'] = datainfo.datasets['test'][:1000] | |||
print(datainfo) | |||
print(datainfo.datasets['train'][0]) | |||
# print(datainfo) | |||
# print(datainfo.datasets['train'][0]) | |||
logger.info(datainfo) | |||
model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]), | |||
embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout) | |||
@@ -108,12 +110,13 @@ callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5))) | |||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | |||
print(device) | |||
# print(device) | |||
logger.info(device) | |||
# 4.定义train方法 | |||
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, | |||
sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size), | |||
metrics=[metric], use_tqdm=False, | |||
metrics=[metric], use_tqdm=False, save_path='save', | |||
dev_data=datainfo.datasets['test'], device=device, | |||
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks, | |||
n_epochs=ops.train_epoch, num_workers=4) | |||