@@ -59,7 +59,9 @@ __all__ = [ | |||||
"NLLLoss", | "NLLLoss", | ||||
"LossInForward", | "LossInForward", | ||||
"cache_results" | |||||
"cache_results", | |||||
'logger' | |||||
] | ] | ||||
__version__ = '0.4.5' | __version__ = '0.4.5' | ||||
@@ -28,3 +28,4 @@ from .tester import Tester | |||||
from .trainer import Trainer | from .trainer import Trainer | ||||
from .utils import cache_results, seq_len_to_mask, get_seq_len | from .utils import cache_results, seq_len_to_mask, get_seq_len | ||||
from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
from ._logger import logger |
@@ -69,7 +69,7 @@ def _add_file_handler(logger, path, level='INFO'): | |||||
file_handler = logging.FileHandler(path, mode='a') | file_handler = logging.FileHandler(path, mode='a') | ||||
file_handler.setLevel(_get_level(level)) | file_handler.setLevel(_get_level(level)) | ||||
file_formatter = logging.Formatter(fmt='%(asctime)s - [%(levelname)s] - %(message)s', | |||||
file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s', | |||||
datefmt='%Y/%m/%d %H:%M:%S') | datefmt='%Y/%m/%d %H:%M:%S') | ||||
file_handler.setFormatter(file_formatter) | file_handler.setFormatter(file_formatter) | ||||
logger.addHandler(file_handler) | logger.addHandler(file_handler) | ||||
@@ -97,18 +97,36 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'): | |||||
stream_handler = None | stream_handler = None | ||||
if stream_handler is not None: | if stream_handler is not None: | ||||
stream_formatter = logging.Formatter('[%(levelname)s] %(message)s') | |||||
stream_formatter = logging.Formatter('%(message)s') | |||||
stream_handler.setLevel(level) | stream_handler.setLevel(level) | ||||
stream_handler.setFormatter(stream_formatter) | stream_handler.setFormatter(stream_formatter) | ||||
logger.addHandler(stream_handler) | logger.addHandler(stream_handler) | ||||
class FastNLPLogger(logging.getLoggerClass()): | |||||
def __init__(self, name): | |||||
super().__init__(name) | |||||
def add_file(self, path='./log.txt', level='INFO'): | |||||
"""add log output file and level""" | |||||
_add_file_handler(self, path, level) | |||||
def set_stdout(self, stdout='tqdm', level='INFO'): | |||||
"""set stdout format and level""" | |||||
_set_stdout_handler(self, stdout, level) | |||||
logging.setLoggerClass(FastNLPLogger) | |||||
# print(logging.getLoggerClass()) | |||||
# print(logging.getLogger()) | |||||
def _init_logger(path=None, stdout='tqdm', level='INFO'): | def _init_logger(path=None, stdout='tqdm', level='INFO'): | ||||
"""initialize logger""" | """initialize logger""" | ||||
level = _get_level(level) | level = _get_level(level) | ||||
# logger = logging.getLogger(ROOT_NAME) | |||||
logger = logging.getLogger() | |||||
# logger = logging.getLogger() | |||||
logger = logging.getLogger(ROOT_NAME) | |||||
logger.propagate = False | |||||
logger.setLevel(level) | logger.setLevel(level) | ||||
_set_stdout_handler(logger, stdout, level) | _set_stdout_handler(logger, stdout, level) | ||||
@@ -132,16 +150,4 @@ def _get_logger(name=None, level='INFO'): | |||||
return logger | return logger | ||||
class FastNLPLogger(logging.Logger): | |||||
def add_file(self, path='./log.txt', level='INFO'): | |||||
"""add log output file and level""" | |||||
_add_file_handler(self, path, level) | |||||
def set_stdout(self, stdout='tqdm', level='INFO'): | |||||
"""set stdout format and level""" | |||||
_set_stdout_handler(self, stdout, level) | |||||
_logger = _init_logger(path=None) | |||||
logger = FastNLPLogger(ROOT_NAME) | |||||
logger.__dict__.update(_logger.__dict__) | |||||
del _logger | |||||
logger = _init_logger(path=None) |
@@ -86,7 +86,7 @@ except: | |||||
from ..io.model_io import ModelSaver, ModelLoader | from ..io.model_io import ModelSaver, ModelLoader | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .tester import Tester | from .tester import Tester | ||||
from ..io import logger | |||||
from ._logger import logger | |||||
try: | try: | ||||
import fitlog | import fitlog | ||||
@@ -56,7 +56,7 @@ from .utils import _move_model_to_device | |||||
from ._parallel_utils import _data_parallel_wrapper | from ._parallel_utils import _data_parallel_wrapper | ||||
from ._parallel_utils import _model_contains_inner_module | from ._parallel_utils import _model_contains_inner_module | ||||
from functools import partial | from functools import partial | ||||
from ..io import logger | |||||
from ._logger import logger | |||||
__all__ = [ | __all__ = [ | ||||
"Tester" | "Tester" | ||||
@@ -353,8 +353,7 @@ from .utils import _get_func_signature | |||||
from .utils import _get_model_device | from .utils import _get_model_device | ||||
from .utils import _move_model_to_device | from .utils import _move_model_to_device | ||||
from ._parallel_utils import _model_contains_inner_module | from ._parallel_utils import _model_contains_inner_module | ||||
from ..io import logger | |||||
from ._logger import logger | |||||
class Trainer(object): | class Trainer(object): | ||||
""" | """ | ||||
@@ -17,7 +17,7 @@ import numpy as np | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from typing import List | from typing import List | ||||
import logging | |||||
from ._logger import logger | |||||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | ||||
'varargs']) | 'varargs']) | ||||
@@ -661,7 +661,7 @@ class _pseudo_tqdm: | |||||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | ||||
""" | """ | ||||
def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
self.logger = logging.getLogger(__name__) | |||||
self.logger = logger | |||||
def write(self, info): | def write(self, info): | ||||
self.logger.info(info) | self.logger.info(info) | ||||
@@ -74,7 +74,6 @@ __all__ = [ | |||||
'ModelLoader', | 'ModelLoader', | ||||
'ModelSaver', | 'ModelSaver', | ||||
'logger', | |||||
] | ] | ||||
from .embed_loader import EmbedLoader | from .embed_loader import EmbedLoader | ||||
@@ -84,4 +83,3 @@ from .model_io import ModelLoader, ModelSaver | |||||
from .loader import * | from .loader import * | ||||
from .pipe import * | from .pipe import * | ||||
from ._logger import * |
@@ -15,14 +15,14 @@ from fastNLP.core.const import Const as C | |||||
from fastNLP.core.vocabulary import VocabularyOption | from fastNLP.core.vocabulary import VocabularyOption | ||||
from fastNLP.core.dist_trainer import DistTrainer | from fastNLP.core.dist_trainer import DistTrainer | ||||
from utils.util_init import set_rng_seeds | from utils.util_init import set_rng_seeds | ||||
from fastNLP.io import logger | |||||
from fastNLP import logger | |||||
import os | import os | ||||
# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | # os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | ||||
# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | # os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | ||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | ||||
# hyper | # hyper | ||||
logger.add_file('log', 'INFO') | logger.add_file('log', 'INFO') | ||||
print(logger.handlers) | |||||
class Config(): | class Config(): | ||||
seed = 12345 | seed = 12345 | ||||