@@ -86,7 +86,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
 from .tester import Tester
-import logging
+from ..io import logger

 try:
     import fitlog
@@ -178,7 +178,7 @@ class Callback(object):
     @property
     def logger(self):
-        return getattr(self._trainer, 'logger', logging.getLogger(__name__))
+        return getattr(self._trainer, 'logger', logger)

     def on_train_begin(self):
         """
@@ -9,7 +9,6 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 import os
 from tqdm import tqdm
-import logging
 import time
 from datetime import datetime, timedelta
 from functools import partial
@@ -22,7 +21,8 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
-from ..io.logger import init_logger
+from ..io import logger
+import logging
 from pkg_resources import parse_version

 __all__ = [
@@ -140,8 +140,8 @@ class DistTrainer():
         self.cp_save_path = None

         # use INFO in the master, WARN for others
-        init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
-        self.logger = logging.getLogger(__name__)
+        logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
+        self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
             os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
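Note (illustrative sketch, not part of the diff): with a single shared logger, the distributed trainer no longer re-initializes logging per process; non-master ranks just raise the level so their info() output is dropped while warnings still appear on every rank, roughly:

    import logging
    from fastNLP.io import logger   # the shared FastNLPLogger instance introduced in this PR

    is_master = False               # hypothetical flag; DistTrainer derives it from the process rank
    logger.setLevel(logging.INFO if is_master else logging.WARNING)
    logger.info("Setup Distributed Trainer")   # suppressed on non-master ranks
    logger.warning("rank-specific details")    # still emitted everywhere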
@@ -56,7 +56,7 @@ from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
 from ._parallel_utils import _model_contains_inner_module
 from functools import partial
-from ..io.logger import init_logger, get_logger
+from ..io import logger

 __all__ = [
     "Tester"
@@ -104,8 +104,7 @@ class Tester(object):
         self.batch_size = batch_size
         self.verbose = verbose
         self.use_tqdm = use_tqdm
-        init_logger(stdout='tqdm' if use_tqdm else 'plain')
-        self.logger = get_logger(__name__)
+        self.logger = logger

         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(
@@ -353,7 +353,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _model_contains_inner_module
-from ..io.logger import init_logger, get_logger
+from ..io import logger


 class Trainer(object):
@@ -548,11 +548,7 @@ class Trainer(object):
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

-        log_path = None
-        if save_path is not None:
-            log_path = os.path.join(os.path.dirname(save_path), 'log')
-        init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain')
-        self.logger = get_logger(__name__)
+        self.logger = logger

         self.use_tqdm = use_tqdm
         self.pbar = None
@@ -701,7 +697,7 @@ class Trainer(object):
                                        self.n_steps) + \
                            self.tester._format_eval_results(eval_res)
                 # pbar.write(eval_str + '\n')
-                self.logger.info(eval_str)
+                self.logger.info(eval_str + '\n')

             # ================= mini-batch end ==================== #
             # lr decay; early stopping
@@ -72,6 +72,8 @@ __all__ = [
     'ModelLoader',
     'ModelSaver',

+    'logger',
 ]

 from .embed_loader import EmbedLoader
@@ -81,3 +83,4 @@ from .model_io import ModelLoader, ModelSaver
 from .loader import *
 from .pipe import *
+from ._logger import *
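A quick sanity-check sketch (illustrative, not in the diff): with the wildcard re-export above, the package-level name and the private module's name should refer to the same object.

    from fastNLP.io import logger
    from fastNLP.io._logger import logger as raw_logger   # assumes the new module is fastNLP/io/_logger.py

    assert logger is raw_logger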
@@ -6,8 +6,11 @@ import os
 import sys
 import warnings

-__all__ = [
-    'logger',
-]
+__all__ = ['logger']
+
+ROOT_NAME = 'fastNLP'

 try:
     import fitlog
@@ -39,7 +42,7 @@ else:
             self.setLevel(level)

-def get_level(level):
+def _get_level(level):
     if isinstance(level, int):
         pass
     else:
@@ -50,22 +53,45 @@ def get_level(level):
     return level

-def init_logger(path=None, stdout='tqdm', level='INFO'):
-    """initialize logger"""
-    if stdout not in ['none', 'plain', 'tqdm']:
-        raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
+def _add_file_handler(logger, path, level='INFO'):
+    for h in logger.handlers:
+        if isinstance(h, logging.FileHandler):
+            if os.path.abspath(path) == h.baseFilename:
+                # file path already added
+                return

-    level = get_level(level)
+    # File Handler
+    if os.path.exists(path):
+        assert os.path.isfile(path)
+        warnings.warn('log already exists in {}'.format(path))
+    dirname = os.path.abspath(os.path.dirname(path))
+    os.makedirs(dirname, exist_ok=True)

-    logger = logging.getLogger('fastNLP')
-    logger.setLevel(level)
-    handlers_type = set([type(h) for h in logger.handlers])
+    file_handler = logging.FileHandler(path, mode='a')
+    file_handler.setLevel(_get_level(level))
+    file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(message)s',
+                                       datefmt='%Y/%m/%d %H:%M:%S')
+    file_handler.setFormatter(file_formatter)
+    logger.addHandler(file_handler)
+
+
+def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
+    level = _get_level(level)
+    if stdout not in ['none', 'plain', 'tqdm']:
+        raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))

     # make sure to initialize logger only once
+    stream_handler = None
+    for i, h in enumerate(logger.handlers):
+        if isinstance(h, (logging.StreamHandler, TqdmLoggingHandler)):
+            stream_handler = h
+            break
+    if stream_handler is not None:
+        logger.removeHandler(stream_handler)
+
     # Stream Handler
-    if stdout == 'plain' and (logging.StreamHandler not in handlers_type):
+    if stdout == 'plain':
         stream_handler = logging.StreamHandler(sys.stdout)
-    elif stdout == 'tqdm' and (TqdmLoggingHandler not in handlers_type):
+    elif stdout == 'tqdm':
         stream_handler = TqdmLoggingHandler(level)
     else:
         stream_handler = None
@@ -76,52 +102,44 @@ def init_logger(path=None, stdout='tqdm', level='INFO'):
         stream_handler.setFormatter(stream_formatter)
         logger.addHandler(stream_handler)

-    # File Handler
-    if path is not None and (logging.FileHandler not in handlers_type):
-        if os.path.exists(path):
-            assert os.path.isfile(path)
-            warnings.warn('log already exists in {}'.format(path))
-        dirname = os.path.abspath(os.path.dirname(path))
-        os.makedirs(dirname, exist_ok=True)
-        file_handler = logging.FileHandler(path, mode='a')
-        file_handler.setLevel(level)
-        file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
-                                           datefmt='%Y/%m/%d %H:%M:%S')
-        file_handler.setFormatter(file_formatter)
-        logger.addHandler(file_handler)
-
-    return logger
+
+def _init_logger(path=None, stdout='tqdm', level='INFO'):
+    """initialize logger"""
+    level = _get_level(level)
+    # logger = logging.getLogger(ROOT_NAME)
+    logger = logging.getLogger()
+    logger.setLevel(level)

-# init logger when import
-logger = init_logger()
+    _set_stdout_handler(logger, stdout, level)

+    # File Handler
+    if path is not None:
+        _add_file_handler(logger, path, level)

-def get_logger(name=None):
-    if name is None:
-        return logging.getLogger('fastNLP')
-    return logging.getLogger(name)
+    return logger


-def set_file(path, level='INFO'):
-    for h in logger.handlers:
-        if isinstance(h, logging.FileHandler):
-            if os.path.abspath(path) == h.baseFilename:
-                # file path already added
-                return
+def _get_logger(name=None, level='INFO'):
+    level = _get_level(level)
+    if name is None:
+        name = ROOT_NAME
+    assert isinstance(name, str)
+    if not name.startswith(ROOT_NAME):
+        name = '{}.{}'.format(ROOT_NAME, name)
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    return logger

-    # File Handler
-    if os.path.exists(path):
-        assert os.path.isfile(path)
-        warnings.warn('log already exists in {}'.format(path))
-    dirname = os.path.abspath(os.path.dirname(path))
-    os.makedirs(dirname, exist_ok=True)
-    file_handler = logging.FileHandler(path, mode='a')
-    file_handler.setLevel(get_level(level))
-    file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
-                                       datefmt='%Y/%m/%d %H:%M:%S')
-    file_handler.setFormatter(file_formatter)
-    logger.addHandler(file_handler)
+
+class FastNLPLogger(logging.Logger):
+    def add_file(self, path, level):
+        _add_file_handler(self, path, level)
+
+    def set_stdout(self, stdout, level):
+        _set_stdout_handler(self, stdout, level)
+
+
+_logger = _init_logger(path=None)
+logger = FastNLPLogger(ROOT_NAME)
+logger.__dict__.update(_logger.__dict__)
+del _logger
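For orientation, a minimal usage sketch of the resulting module-level logger (illustrative only; 'train.log' is a made-up path): the exported instance subclasses logging.Logger, so the usual logging calls work, plus the two helpers defined above.

    from fastNLP.io import logger

    logger.set_stdout('tqdm', 'INFO')      # route console output through the tqdm-friendly handler
    logger.add_file('train.log', 'INFO')   # also append records to a file
    logger.info('RNG SEED %d' % 12345)     # reaches both handlers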
@@ -15,13 +15,14 @@ from fastNLP.core.const import Const as C
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.core.dist_trainer import DistTrainer
 from utils.util_init import set_rng_seeds
+from fastNLP.io import logger
 import os
 # os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
 # os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

 # hyper
+logger.add_file('log', 'INFO')

 class Config():
     seed = 12345
@@ -46,11 +47,11 @@ class Config():
         self.datapath = {k: os.path.join(self.datadir, v)
                          for k, v in self.datafile.items()}

 ops = Config()

 set_rng_seeds(ops.seed)
-print('RNG SEED: {}'.format(ops.seed))
+# print('RNG SEED: {}'.format(ops.seed))
+logger.info('RNG SEED %d'%ops.seed)

 # 1. Task-related info: load dataInfo with the dataloader
@@ -81,8 +82,9 @@ print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.st
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
 datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
 datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
-print(datainfo)
-print(datainfo.datasets['train'][0])
+# print(datainfo)
+# print(datainfo.datasets['train'][0])
+logger.info(datainfo)

 model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
               embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
@@ -108,12 +110,13 @@ callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))

 device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-print(device)
+# print(device)
+logger.info(device)

 # 4. Define the train procedure
 trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                   sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
-                  metrics=[metric], use_tqdm=False,
+                  metrics=[metric], use_tqdm=False, save_path='save',
                   dev_data=datainfo.datasets['test'], device=device,
                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
                   n_epochs=ops.train_epoch, num_workers=4)