
[update] logger: support importing the logger directly

tags/v0.4.10
yunfan, 6 years ago
parent commit 3b8bc469ba
7 changed files with 93 additions and 74 deletions
1. fastNLP/core/callback.py (+2, -2)
2. fastNLP/core/dist_trainer.py (+4, -4)
3. fastNLP/core/tester.py (+2, -3)
4. fastNLP/core/trainer.py (+3, -7)
5. fastNLP/io/__init__.py (+3, -0)
6. fastNLP/io/_logger.py (+69, -51)
7. reproduction/text_classification/train_dpcnn.py (+10, -7)
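
In short, the commit retires the init_logger/get_logger pair: fastNLP/io/logger.py becomes the private module fastNLP/io/_logger.py, which builds one shared logger at import time and exports it, so callers can import it directly. A minimal sketch of the new pattern, pieced together from the hunks below:

    from fastNLP.io import logger

    logger.info('usable right after import')  # a tqdm-aware stdout handler is installed at import time
    logger.add_file('log', 'INFO')            # optionally also log to a file, as train_dpcnn.py does below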

fastNLP/core/callback.py (+2, -2)

@@ -86,7 +86,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
 from .tester import Tester
-import logging
+from ..io import logger

 try:
     import fitlog

@@ -178,7 +178,7 @@ class Callback(object):

     @property
     def logger(self):
-        return getattr(self._trainer, 'logger', logging.getLogger(__name__))
+        return getattr(self._trainer, 'logger', logger)

     def on_train_begin(self):
         """


fastNLP/core/dist_trainer.py (+4, -4)

@@ -9,7 +9,6 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 import os
 from tqdm import tqdm
-import logging
 import time
 from datetime import datetime, timedelta
 from functools import partial

@@ -22,7 +21,8 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
-from ..io.logger import init_logger
+from ..io import logger
+import logging
 from pkg_resources import parse_version

 __all__ = [

@@ -140,8 +140,8 @@ class DistTrainer():
         self.cp_save_path = None

         # use INFO in the master, WARN for others
-        init_logger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
-        self.logger = logging.getLogger(__name__)
+        logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
+        self.logger = logger
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
             os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
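
The shared logger turns the per-rank verbosity setup above into plain level twiddling. A sketch of the same "INFO in the master, WARN for others" pattern outside DistTrainer (is_master is an assumption standing in for self.is_master, e.g. rank 0):

    import logging
    from fastNLP.io import logger

    is_master = True  # e.g. torch.distributed.get_rank() == 0
    logger.setLevel(logging.INFO if is_master else logging.WARNING)
    logger.info('visible on the master process only')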


fastNLP/core/tester.py (+2, -3)

@@ -56,7 +56,7 @@ from .utils import _move_model_to_device
 from ._parallel_utils import _data_parallel_wrapper
 from ._parallel_utils import _model_contains_inner_module
 from functools import partial
-from ..io.logger import init_logger, get_logger
+from ..io import logger

 __all__ = [
     "Tester"

@@ -104,8 +104,7 @@ class Tester(object):
         self.batch_size = batch_size
         self.verbose = verbose
         self.use_tqdm = use_tqdm
-        init_logger(stdout='tqdm' if use_tqdm else 'plain')
-        self.logger = get_logger(__name__)
+        self.logger = logger

         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(


fastNLP/core/trainer.py (+3, -7)

@@ -353,7 +353,7 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _model_contains_inner_module
-from ..io.logger import init_logger, get_logger
+from ..io import logger


 class Trainer(object):

@@ -548,11 +548,7 @@ class Trainer(object):
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

-        log_path = None
-        if save_path is not None:
-            log_path = os.path.join(os.path.dirname(save_path), 'log')
-        init_logger(path=log_path, stdout='tqdm' if use_tqdm else 'plain')
-        self.logger = get_logger(__name__)
+        self.logger = logger

         self.use_tqdm = use_tqdm
         self.pbar = None

@@ -701,7 +697,7 @@ class Trainer(object):
                         self.n_steps) + \
                         self.tester._format_eval_results(eval_res)
                     # pbar.write(eval_str + '\n')
-                    self.logger.info(eval_str)
+                    self.logger.info(eval_str + '\n')
 # ================= mini-batch end ==================== #

         # lr decay; early stopping


fastNLP/io/__init__.py (+3, -0)

@@ -72,6 +72,8 @@ __all__ = [

     'ModelLoader',
     'ModelSaver',
+
+    'logger',
 ]

 from .embed_loader import EmbedLoader

@@ -81,3 +83,4 @@ from .model_io import ModelLoader, ModelSaver

 from .loader import *
 from .pipe import *
+from ._logger import *

fastNLP/io/logger.py → fastNLP/io/_logger.py (+69, -51)

@@ -6,8 +6,11 @@ import os
 import sys
 import warnings

-__all__ = ['logger']
+__all__ = [
+    'logger',
+]
+
+ROOT_NAME = 'fastNLP'

 try:
     import fitlog
@@ -39,7 +42,7 @@ else:
             self.setLevel(level)


-def get_level(level):
+def _get_level(level):
     if isinstance(level, int):
         pass
     else:
@@ -50,22 +53,45 @@ def get_level(level):
     return level


-def init_logger(path=None, stdout='tqdm', level='INFO'):
-    """initialize logger"""
-    if stdout not in ['none', 'plain', 'tqdm']:
-        raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
+def _add_file_handler(logger, path, level='INFO'):
+    for h in logger.handlers:
+        if isinstance(h, logging.FileHandler):
+            if os.path.abspath(path) == h.baseFilename:
+                # file path already added
+                return

-    level = get_level(level)
+    # File Handler
+    if os.path.exists(path):
+        assert os.path.isfile(path)
+        warnings.warn('log already exists in {}'.format(path))
+    dirname = os.path.abspath(os.path.dirname(path))
+    os.makedirs(dirname, exist_ok=True)

-    logger = logging.getLogger('fastNLP')
-    logger.setLevel(level)
-    handlers_type = set([type(h) for h in logger.handlers])
+    file_handler = logging.FileHandler(path, mode='a')
+    file_handler.setLevel(_get_level(level))
+    file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(message)s',
+                                       datefmt='%Y/%m/%d %H:%M:%S')
+    file_handler.setFormatter(file_formatter)
+    logger.addHandler(file_handler)
+
+
+def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
+    level = _get_level(level)
+    if stdout not in ['none', 'plain', 'tqdm']:
+        raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
     # make sure to initialize logger only once
+    stream_handler = None
+    for i, h in enumerate(logger.handlers):
+        if isinstance(h, (logging.StreamHandler, TqdmLoggingHandler)):
+            stream_handler = h
+            break
+    if stream_handler is not None:
+        logger.removeHandler(stream_handler)

     # Stream Handler
-    if stdout == 'plain' and (logging.StreamHandler not in handlers_type):
+    if stdout == 'plain':
         stream_handler = logging.StreamHandler(sys.stdout)
-    elif stdout == 'tqdm' and (TqdmLoggingHandler not in handlers_type):
+    elif stdout == 'tqdm':
         stream_handler = TqdmLoggingHandler(level)
     else:
         stream_handler = None
@@ -76,52 +102,44 @@ def init_logger(path=None, stdout='tqdm', level='INFO'):
         stream_handler.setFormatter(stream_formatter)
         logger.addHandler(stream_handler)

-    # File Handler
-    if path is not None and (logging.FileHandler not in handlers_type):
-        if os.path.exists(path):
-            assert os.path.isfile(path)
-            warnings.warn('log already exists in {}'.format(path))
-        dirname = os.path.abspath(os.path.dirname(path))
-        os.makedirs(dirname, exist_ok=True)
-
-        file_handler = logging.FileHandler(path, mode='a')
-        file_handler.setLevel(level)
-        file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
-                                           datefmt='%Y/%m/%d %H:%M:%S')
-        file_handler.setFormatter(file_formatter)
-        logger.addHandler(file_handler)
-
-    return logger
+def _init_logger(path=None, stdout='tqdm', level='INFO'):
+    """initialize logger"""
+    level = _get_level(level)
+
+    # logger = logging.getLogger(ROOT_NAME)
+    logger = logging.getLogger()
+    logger.setLevel(level)

-# init logger when import
-logger = init_logger()
+    _set_stdout_handler(logger, stdout, level)

+    # File Handler
+    if path is not None:
+        _add_file_handler(logger, path, level)

-def get_logger(name=None):
-    if name is None:
-        return logging.getLogger('fastNLP')
-    return logging.getLogger(name)
+    return logger


-def set_file(path, level='INFO'):
-    for h in logger.handlers:
-        if isinstance(h, logging.FileHandler):
-            if os.path.abspath(path) == h.baseFilename:
-                # file path already added
-                return
+def _get_logger(name=None, level='INFO'):
+    level = _get_level(level)
+    if name is None:
+        name = ROOT_NAME
+    assert isinstance(name, str)
+    if not name.startswith(ROOT_NAME):
+        name = '{}.{}'.format(ROOT_NAME, name)
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    return logger

-    # File Handler
-    if os.path.exists(path):
-        assert os.path.isfile(path)
-        warnings.warn('log already exists in {}'.format(path))
-    dirname = os.path.abspath(os.path.dirname(path))
-    os.makedirs(dirname, exist_ok=True)

-    file_handler = logging.FileHandler(path, mode='a')
-    file_handler.setLevel(get_level(level))
-    file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(name)s - %(message)s',
-                                       datefmt='%Y/%m/%d %H:%M:%S')
-    file_handler.setFormatter(file_formatter)
-    logger.addHandler(file_handler)
+class FastNLPLogger(logging.Logger):
+    def add_file(self, path, level):
+        _add_file_handler(self, path, level)
+
+    def set_stdout(self, stdout, level):
+        _set_stdout_handler(self, stdout, level)
+
+
+_logger = _init_logger(path=None)
+logger = FastNLPLogger(ROOT_NAME)
+logger.__dict__.update(_logger.__dict__)
+del _logger
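
Net effect: importing the module builds one configured root logger, copies its state into a FastNLPLogger instance named fastNLP, and exports that as logger, with add_file/set_stdout reconfiguring handlers in place. A short usage sketch (train.log is an illustrative path; note both methods take explicit level arguments in this commit):

    from fastNLP.io import logger

    logger.set_stdout('plain', 'INFO')    # swap the tqdm-aware handler for plain sys.stdout
    logger.add_file('train.log', 'INFO')  # append records to a file; a second call with the
                                          # same path is a no-op thanks to the baseFilename check
    logger.info('hello fastNLP')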

reproduction/text_classification/train_dpcnn.py (+10, -7)

@@ -15,13 +15,14 @@ from fastNLP.core.const import Const as C
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.core.dist_trainer import DistTrainer
 from utils.util_init import set_rng_seeds
+from fastNLP.io import logger
 import os
 # os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
 # os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

 # hyper
+logger.add_file('log', 'INFO')

 class Config():
     seed = 12345
@@ -46,11 +47,11 @@ class Config():
         self.datapath = {k: os.path.join(self.datadir, v)
                          for k, v in self.datafile.items()}

 ops = Config()

 set_rng_seeds(ops.seed)
-print('RNG SEED: {}'.format(ops.seed))
+# print('RNG SEED: {}'.format(ops.seed))
+logger.info('RNG SEED %d'%ops.seed)

 # 1. Task info: load dataInfo via the dataloader


@@ -81,8 +82,9 @@ print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.st
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
 datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
 datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
-print(datainfo)
-print(datainfo.datasets['train'][0])
+# print(datainfo)
+# print(datainfo.datasets['train'][0])
+logger.info(datainfo)

 model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
               embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
@@ -108,12 +110,13 @@ callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))

 device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

-print(device)
+# print(device)
+logger.info(device)

 # 4. Define the train routine
 trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                   sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
-                  metrics=[metric], use_tqdm=False,
+                  metrics=[metric], use_tqdm=False, save_path='save',
                   dev_data=datainfo.datasets['test'], device=device,
                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
                   n_epochs=ops.train_epoch, num_workers=4)

