@@ -10,8 +10,72 @@ The core module implements the core framework of fastNLP; commonly used functionality can all be imported from fa
For commonly used functionality, you only need to look in :doc:`fastNLP`. To learn what each submodule does in detail, you can find each submodule's documentation below.
"""
__all__ = [
    "DataSet",
    "Instance",
    "FieldArray",
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",
    "Vocabulary",
    "DataSetIter",
    "BatchIter",
    "TorchLoaderIter",
    "Const",
    "Tester",
    "Trainer",
    "cache_results",
    "seq_len_to_mask",
    "get_seq_len",
    "logger",
    "Callback",
    "GradientClipCallback",
    "EarlyStopCallback",
    "FitlogCallback",
    "EvaluateCallback",
    "LRScheduler",
    "ControlC",
    "LRFinder",
    "TensorboardCallback",
    "WarmupCallback",
    "SaveModelCallback",
    "EchoCallback",
    "TesterCallback",
    "CallbackException",
    "EarlyStopError",
    "LossFunc",
    "CrossEntropyLoss",
    "L1Loss",
    "BCELoss",
    "NLLLoss",
    "LossInForward",
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "ExtractiveQAMetric",
    "Optimizer",
    "SGD",
    "Adam",
    "AdamW",
    "SequentialSampler",
    "BucketSampler",
    "RandomSampler",
    "Sampler",
]
from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
    LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
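# A minimal usage sketch of the public API re-exported above (illustrative
# only; `raw_words`/`words` are example field names, not defined by this
# module):
from fastNLP import DataSet, Instance, Vocabulary

ds = DataSet()
ds.append(Instance(raw_words="the cat sat".split()))
vocab = Vocabulary()
vocab.from_dataset(ds, field_name="raw_words")
# Map each token to its index, writing the result into a new field.
vocab.index_dataset(ds, field_name="raw_words", new_field_name="words")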
@@ -1,15 +1,15 @@
"""undocumented"""

__all__ = [
    'logger',
]

import logging
import logging.config
import os
import sys
import warnings

ROOT_NAME = 'fastNLP'

try:
@@ -25,7 +25,7 @@ if tqdm is not None:
    class TqdmLoggingHandler(logging.Handler):
        def __init__(self, level=logging.INFO):
            super().__init__(level)

        def emit(self, record):
            try:
                msg = self.format(record)
@@ -59,14 +59,14 @@ def _add_file_handler(logger, path, level='INFO'):
            if os.path.abspath(path) == h.baseFilename:
                # file path already added
                return

    # File Handler
    if os.path.exists(path):
        assert os.path.isfile(path)
        warnings.warn('log already exists in {}'.format(path))
    dirname = os.path.abspath(os.path.dirname(path))
    os.makedirs(dirname, exist_ok=True)

    file_handler = logging.FileHandler(path, mode='a')
    file_handler.setLevel(_get_level(level))
    file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
@@ -87,7 +87,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
            break
    if stream_handler is not None:
        logger.removeHandler(stream_handler)

    # Stream Handler
    if stdout == 'plain':
        stream_handler = logging.StreamHandler(sys.stdout)
@@ -95,7 +95,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
        stream_handler = TqdmLoggingHandler(level)
    else:
        stream_handler = None

    if stream_handler is not None:
        stream_formatter = logging.Formatter('%(message)s')
        stream_handler.setLevel(level)
@@ -103,38 +103,40 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
        logger.addHandler(stream_handler)


class FastNLPLogger(logging.getLoggerClass()):
    def __init__(self, name):
        super().__init__(name)

    def add_file(self, path='./log.txt', level='INFO'):
        """add log output file and level"""
        _add_file_handler(self, path, level)

    def set_stdout(self, stdout='tqdm', level='INFO'):
        """set stdout format and level"""
        _set_stdout_handler(self, stdout, level)


logging.setLoggerClass(FastNLPLogger)


def _init_logger(path=None, stdout='tqdm', level='INFO'):
    """initialize logger"""
    level = _get_level(level)

    logger = logging.getLogger(ROOT_NAME)
    logger.propagate = False
    logger.setLevel(level)

    _set_stdout_handler(logger, stdout, level)

    # File Handler
    if path is not None:
        _add_file_handler(logger, path, level)

    return logger
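# A short sketch of the FastNLPLogger defined above, assuming fastNLP is
# importable; './demo.log' is an arbitrary illustrative path.
from fastNLP import logger

logger.add_file('./demo.log', level='DEBUG')  # also write records to a file
logger.set_stdout('plain', level='INFO')      # plain stdout instead of tqdm
logger.info("training started")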
@@ -1,11 +1,14 @@
"""undocumented"""

__all__ = []

import threading

import torch
from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -27,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
@@ -47,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
    :param output_device: the output_device of nn.DataParallel
    :return:
    """

    def wrapper(network, *inputs, **kwargs):
        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
        if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
        replicas = replicate(network, device_ids[:len(inputs)])
        outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
        return gather(outputs, output_device)

    return wrapper


@@ -99,4 +104,4 @@ def _model_contains_inner_module(model):
    if isinstance(model, nn.Module):
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return True
    return False
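# A quick sketch of _model_contains_inner_module's contract: it returns True
# only for wrappers that keep the real model in a `.module` attribute
# (DataParallel / DistributedDataParallel).
import torch.nn as nn
from fastNLP.core._parallel_utils import _model_contains_inner_module

net = nn.Linear(4, 2)
assert not _model_contains_inner_module(net)
assert _model_contains_inner_module(nn.DataParallel(net))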
@@ -1,3 +1,13 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Const"
]


class Const:
    """
    Field name constants used across fastNLP.
@@ -25,47 +35,47 @@ class Const:
    LOSS = 'loss'
    RAW_WORD = 'raw_words'
    RAW_CHAR = 'raw_chars'

    @staticmethod
    def INPUTS(i):
        """Get the name of the i-th ``INPUT``"""
        i = int(i) + 1
        return Const.INPUT + str(i)

    @staticmethod
    def CHAR_INPUTS(i):
        """Get the name of the i-th ``CHAR_INPUT``"""
        i = int(i) + 1
        return Const.CHAR_INPUT + str(i)

    @staticmethod
    def RAW_WORDS(i):
        """Get the name of the i-th ``RAW_WORD``"""
        i = int(i) + 1
        return Const.RAW_WORD + str(i)

    @staticmethod
    def RAW_CHARS(i):
        """Get the name of the i-th ``RAW_CHAR``"""
        i = int(i) + 1
        return Const.RAW_CHAR + str(i)

    @staticmethod
    def INPUT_LENS(i):
        """Get the name of the i-th ``INPUT_LEN``"""
        i = int(i) + 1
        return Const.INPUT_LEN + str(i)

    @staticmethod
    def OUTPUTS(i):
        """Get the name of the i-th ``OUTPUT``"""
        i = int(i) + 1
        return Const.OUTPUT + str(i)

    @staticmethod
    def TARGETS(i):
        """Get the name of the i-th ``TARGET``"""
        i = int(i) + 1
        return Const.TARGET + str(i)

    @staticmethod
    def LOSSES(i):
        """Get the name of the i-th ``LOSS``"""
@@ -1,29 +1,29 @@
"""undocumented
Distributed training code, still under development.
"""
import logging
import os
import time
from datetime import datetime, timedelta
from functools import partial

import torch
import torch.cuda
import torch.distributed as dist
import torch.optim
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from ._logger import logger
from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
from .utils import _build_args
from .utils import _get_func_signature
from .utils import _move_dict_value_to_device

__all__ = [
    'get_local_rank',
@@ -1,18 +1,25 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Padder",
    "AutoPadder",
    "EngChar2DPadder",
]

from abc import abstractmethod
from collections import Counter
from copy import deepcopy
from numbers import Number
from typing import Any

import numpy as np
import torch

from ._logger import logger
from .utils import _is_iterable


class SetInputOrTargetException(Exception):
@@ -1,13 +1,15 @@
"""undocumented"""

__all__ = [
    "Predictor"
]

from collections import defaultdict

import torch

from . import DataSet
from . import DataSetIter
from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device
@@ -21,7 +23,7 @@ class Predictor(object):
    :param torch.nn.Module network: the model used to perform the prediction task
    """

    def __init__(self, network):
        if not isinstance(network, torch.nn.Module):
            raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
        self.network = network
        self.batch_size = 1
        self.batch_output = []

    def predict(self, data: DataSet, seq_len_field_name=None):
        """Run inference with an already trained model.

@@ -41,27 +43,27 @@
            raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
            raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))

        prev_training = self.network.training
        self.network.eval()
        network_device = _get_model_device(self.network)
        batch_output = defaultdict(list)
        data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)

        if hasattr(self.network, "predict"):
            predict_func = self.network.predict
        else:
            predict_func = self.network.forward

        with torch.no_grad():
            for batch_x, _ in data_iterator:
                _move_dict_value_to_device(batch_x, _, device=network_device)
                refined_batch_x = _build_args(predict_func, **batch_x)
                prediction = predict_func(**refined_batch_x)

                if seq_len_field_name is not None:
                    seq_lens = batch_x[seq_len_field_name].tolist()

                for key, value in prediction.items():
                    value = value.cpu().numpy()
                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
                        batch_output[key].extend(tmp_batch)
                    else:
                        batch_output[key].append(value)

        self.network.train(prev_training)
        return batch_output
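# A hedged sketch of Predictor: any nn.Module whose forward returns a dict of
# tensors works; `ToyModel` and the field name `x` are made up for
# illustration, assuming fastNLP's DataSet/Instance API.
import torch
import torch.nn as nn
from fastNLP import DataSet, Instance
from fastNLP.core.predictor import Predictor

class ToyModel(nn.Module):
    def forward(self, x):
        return {'pred': x.sum(dim=1, keepdim=True)}

ds = DataSet()
ds.append(Instance(x=[1.0, 2.0, 3.0]))
ds.set_input('x')

predictor = Predictor(ToyModel())
outputs = predictor.predict(ds)  # e.g. {'pred': [...]}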
@@ -1,16 +1,22 @@
"""
.. todo::
    doc
"""

__all__ = [
    "Vocabulary",
    "VocabularyOption",
]

from collections import Counter
from functools import partial
from functools import wraps

import numpy as np

from ._logger import logger
from .dataset import DataSet
from .utils import Option
from .utils import _is_iterable


class VocabularyOption(Option):
    def __init__(self,
@@ -51,7 +57,7 @@ def _check_build_status(func):
            self.rebuild = True
            if self.max_size is not None and len(self.word_count) >= self.max_size:
                logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
                            "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                    self.max_size, func.__name__))
        return func(self, *args, **kwargs)
@@ -199,7 +205,7 @@ class Vocabulary(object):
        self.build_reverse_vocab()
        self.rebuild = False
        return self

    def build_reverse_vocab(self):
        """
        Build the `index to word` dict from the `word to index` dict.
@@ -279,19 +285,19 @@ class Vocabulary(object):
            if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
                raise RuntimeError("Only support field with 2 dimensions.")
            return [[self.to_index(c) for c in w] for w in field]

        new_field_name = new_field_name or field_name

        if type(new_field_name) == type(field_name):
            if isinstance(new_field_name, list):
                assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
                                                               "field_name."
            elif isinstance(new_field_name, str):
                field_name = [field_name]
                new_field_name = [new_field_name]
            else:
                raise TypeError("field_name and new_field_name can only be str or List[str].")

        for idx, dataset in enumerate(datasets):
            if isinstance(dataset, DataSet):
                try:
@@ -377,7 +383,7 @@ class Vocabulary(object):
        :return: bool
        """
        return word in self._no_create_word

    def to_index(self, w):
        """
        Convert a word to its index. A word not recorded in the vocabulary is treated as unknown; if ``unknown=None``, a ``ValueError`` is raised::
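# A short sketch of to_index (illustrative words; assumes the default unknown
# token is kept): known words map to their index, unknown words fall back to
# the unknown token's index instead of raising.
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst(['positive', 'negative'])
vocab.build_vocab()
assert vocab.to_index('positive') == vocab['positive']
unk_idx = vocab.to_index('neutral')  # index of the unknown token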
@@ -8,15 +8,17 @@ __all__ = [
]

from abc import abstractmethod

import torch

from ..core import logger
from ..core.batch import DataSetIter
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from ..core.vocabulary import Vocabulary
from .embedding import TokenEmbedding


class ContextualEmbedding(TokenEmbedding):
    def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):