[update] logger, support importing the logger directly

yunfan · 5 years ago · commit 3b8bc469ba · tags/v0.4.10
7 changed files with 93 additions and 74 deletions
  1. fastNLP/core/callback.py (+2, -2)
  2. fastNLP/core/dist_trainer.py (+4, -4)
  3. fastNLP/core/tester.py (+2, -3)
  4. fastNLP/core/trainer.py (+3, -7)
  5. fastNLP/io/__init__.py (+3, -0)
  6. fastNLP/io/_logger.py (+69, -51)
  7. reproduction/text_classification/train_dpcnn.py (+10, -7)
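
A minimal usage sketch of what the commit message describes: the shared logger is now exported from fastNLP.io and can be imported directly (assumes this commit, i.e. v0.4.10, is applied; the seed value is illustrative).

    from fastNLP.io import logger        # newly exported in fastNLP/io/__init__.py

    logger.add_file('log', 'INFO')       # same call used in train_dpcnn.py below
    logger.info('RNG SEED %d' % 12345)   # behaves like a standard logging.Logger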

+2 -2  fastNLP/core/callback.py

@@ -86,7 +86,7 @@ except:
from ..io.model_io import ModelSaver, ModelLoader
from .dataset import DataSet
from .tester import Tester
import logging
from ..io import logger

try:
import fitlog
@@ -178,7 +178,7 @@ class Callback(object):

@property
def logger(self):
return getattr(self._trainer, 'logger', logging.getLogger(__name__))
return getattr(self._trainer, 'logger', logger)

def on_train_begin(self):
"""


+4 -4  fastNLP/core/dist_trainer.py

@@ -9,7 +9,6 @@ from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import os
from tqdm import tqdm
import logging
import time
from datetime import datetime, timedelta
from functools import partial
@@ -22,7 +21,8 @@ from .optimizer import Optimizer
from .utils import _build_args
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from ..io.logger import init_logger
from ..io import logger
import logging
from pkg_resources import parse_version

__all__ = [
@@ -140,8 +140,8 @@ class DistTrainer():
self.cp_save_path = None

# use INFO in the master, WARN for others
init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
self.logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
self.logger = logger
self.logger.info("Setup Distributed Trainer")
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
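
A minimal sketch of the per-rank level pattern used above, outside DistTrainer (is_master is a hypothetical stand-in for the trainer's rank check).

    import logging
    from fastNLP.io import logger

    is_master = True   # hypothetical flag; DistTrainer derives this from its rank
    logger.setLevel(logging.INFO if is_master else logging.WARNING)
    logger.info('Setup Distributed Trainer')   # visible only on the master process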


+2 -3  fastNLP/core/tester.py

@@ -56,7 +56,7 @@ from .utils import _move_model_to_device
from ._parallel_utils import _data_parallel_wrapper
from ._parallel_utils import _model_contains_inner_module
from functools import partial
from ..io.logger import init_logger, get_logger
from ..io import logger

__all__ = [
"Tester"
@@ -104,8 +104,7 @@ class Tester(object):
self.batch_size = batch_size
self.verbose = verbose
self.use_tqdm = use_tqdm
init_logger(stdout='tqdm' if use_tqdm else 'plain')
self.logger = get_logger(__name__)
self.logger = logger

if isinstance(data, DataSet):
self.data_iterator = DataSetIter(


+3 -7  fastNLP/core/trainer.py

@@ -353,7 +353,7 @@ from .utils import _get_func_signature
from .utils import _get_model_device
from .utils import _move_model_to_device
from ._parallel_utils import _model_contains_inner_module
from ..io.logger import init_logger, get_logger
from ..io import logger


class Trainer(object):
@@ -548,11 +548,7 @@ class Trainer(object):
else:
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

log_path = None
if save_path is not None:
log_path = os.path.join(os.path.dirname(save_path), 'log')
init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain')
self.logger = get_logger(__name__)
self.logger = logger

self.use_tqdm = use_tqdm
self.pbar = None
@@ -701,7 +697,7 @@ class Trainer(object):
self.n_steps) + \
self.tester._format_eval_results(eval_res)
# pbar.write(eval_str + '\n')
self.logger.info(eval_str)
self.logger.info(eval_str + '\n')
# ================= mini-batch end ==================== #

# lr decay; early stopping


+3 -0  fastNLP/io/__init__.py

@@ -72,6 +72,8 @@ __all__ = [

'ModelLoader',
'ModelSaver',

'logger',
]

from .embed_loader import EmbedLoader
@@ -81,3 +83,4 @@ from .model_io import ModelLoader, ModelSaver

from .loader import *
from .pipe import *
from ._logger import *

+69 -51  fastNLP/io/logger.py → fastNLP/io/_logger.py

@@ -6,8 +6,11 @@ import os
import sys
import warnings

__all__ = [
'logger',
]

__all__ = ['logger']
ROOT_NAME = 'fastNLP'

try:
import fitlog
@@ -39,7 +42,7 @@ else:
self.setLevel(level)


def get_level(level):
def _get_level(level):
if isinstance(level, int):
pass
else:
@@ -50,22 +53,45 @@ def get_level(level):
return level


def init_logger(path=None, stdout='tqdm', level='INFO'):
"""initialize logger"""
if stdout not in ['none', 'plain', 'tqdm']:
raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
def _add_file_handler(logger, path, level='INFO'):
for h in logger.handlers:
if isinstance(h, logging.FileHandler):
if os.path.abspath(path) == h.baseFilename:
# file path already added
return

level = get_level(level)
# File Handler
if os.path.exists(path):
assert os.path.isfile(path)
warnings.warn('log already exists in {}'.format(path))
dirname = os.path.abspath(os.path.dirname(path))
os.makedirs(dirname, exist_ok=True)

logger = logging.getLogger('fastNLP')
logger.setLevel(level)
handlers_type = set([type(h) for h in logger.handlers])
file_handler = logging.FileHandler(path, mode='a')
file_handler.setLevel(_get_level(level))
file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(message)s',
datefmt='%Y/%m/%d %H:%M:%S')
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)


def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
level = _get_level(level)
if stdout not in ['none', 'plain', 'tqdm']:
raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
# make sure to initialize logger only once
stream_handler = None
for i, h in enumerate(logger.handlers):
if isinstance(h, (logging.StreamHandler, TqdmLoggingHandler)):
stream_handler = h
break
if stream_handler is not None:
logger.removeHandler(stream_handler)

# Stream Handler
if stdout == 'plain' and (logging.StreamHandler not in handlers_type):
if stdout == 'plain':
stream_handler = logging.StreamHandler(sys.stdout)
elif stdout == 'tqdm' and (TqdmLoggingHandler not in handlers_type):
elif stdout == 'tqdm':
stream_handler = TqdmLoggingHandler(level)
else:
stream_handler = None
@@ -76,52 +102,44 @@ def init_logger(path=None, stdout='tqdm', level='INFO'):
stream_handler.setFormatter(stream_formatter)
logger.addHandler(stream_handler)

# File Handler
if path is not None and (logging.FileHandler not in handlers_type):
if os.path.exists(path):
assert os.path.isfile(path)
warnings.warn('log already exists in {}'.format(path))
dirname = os.path.abspath(os.path.dirname(path))
os.makedirs(dirname, exist_ok=True)

file_handler = logging.FileHandler(path, mode='a')
file_handler.setLevel(level)
file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
datefmt='%Y/%m/%d %H:%M:%S')
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)

return logger
def _init_logger(path=None, stdout='tqdm', level='INFO'):
"""initialize logger"""
level = _get_level(level)

# logger = logging.getLogger(ROOT_NAME)
logger = logging.getLogger()
logger.setLevel(level)

# init logger when import
logger = init_logger()
_set_stdout_handler(logger, stdout, level)

# File Handler
if path is not None:
_add_file_handler(logger, path, level)

def get_logger(name=None):
if name is None:
return logging.getLogger('fastNLP')
return logging.getLogger(name)
return logger


def set_file(path, level='INFO'):
for h in logger.handlers:
if isinstance(h, logging.FileHandler):
if os.path.abspath(path) == h.baseFilename:
# file path already added
return
def _get_logger(name=None, level='INFO'):
level = _get_level(level)
if name is None:
name = ROOT_NAME
assert isinstance(name, str)
if not name.startswith(ROOT_NAME):
name = '{}.{}'.format(ROOT_NAME, name)
logger = logging.getLogger(name)
logger.setLevel(level)
return logger

# File Handler
if os.path.exists(path):
assert os.path.isfile(path)
warnings.warn('log already exists in {}'.format(path))
dirname = os.path.abspath(os.path.dirname(path))
os.makedirs(dirname, exist_ok=True)

file_handler = logging.FileHandler(path, mode='a')
file_handler.setLevel(get_level(level))
file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
datefmt='%Y/%m/%d %H:%M:%S')
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
class FastNLPLogger(logging.Logger):
def add_file(self, path, level):
_add_file_handler(self, path, level)

def set_stdout(self, stdout, level):
_set_stdout_handler(self, stdout, level)

_logger = _init_logger(path=None)
logger = FastNLPLogger(ROOT_NAME)
logger.__dict__.update(_logger.__dict__)
del _logger
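
The rewritten module exposes a single shared FastNLPLogger instance; a minimal sketch of the resulting API based on the methods defined above ('debug.log' and the message are illustrative only).

    from fastNLP.io import logger

    # choose the stdout handler: 'tqdm' (tqdm-friendly), 'plain', or 'none'
    logger.set_stdout('tqdm', 'INFO')

    # attach a FileHandler; repeated calls with the same path are no-ops
    logger.add_file('debug.log', 'INFO')

    logger.info('loaded training data')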

+10 -7  reproduction/text_classification/train_dpcnn.py

@@ -15,13 +15,14 @@ from fastNLP.core.const import Const as C
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.core.dist_trainer import DistTrainer
from utils.util_init import set_rng_seeds
from fastNLP.io import logger
import os
# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"


# hyper
logger.add_file('log', 'INFO')

class Config():
seed = 12345
@@ -46,11 +47,11 @@ class Config():
self.datapath = {k: os.path.join(self.datadir, v)
for k, v in self.datafile.items()}


ops = Config()

set_rng_seeds(ops.seed)
print('RNG SEED: {}'.format(ops.seed))
# print('RNG SEED: {}'.format(ops.seed))
logger.info('RNG SEED %d'%ops.seed)

# 1.task相关信息:利用dataloader载入dataInfo

@@ -81,8 +82,9 @@ print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.st
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
print(datainfo)
print(datainfo.datasets['train'][0])
# print(datainfo)
# print(datainfo.datasets['train'][0])
logger.info(datainfo)

model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
@@ -108,12 +110,13 @@ callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(device)
# print(device)
logger.info(device)

# 4.定义train方法
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
metrics=[metric], use_tqdm=False,
metrics=[metric], use_tqdm=False, save_path='save',
dev_data=datainfo.datasets['test'], device=device,
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
n_epochs=ops.train_epoch, num_workers=4)

