diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index 1feaf3fb..efee08b5 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -10,8 +10,72 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa
 对于常用的功能,你只需要在 :doc:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。
 
 
-
 """
+__all__ = [
+    "DataSet",
+
+    "Instance",
+
+    "FieldArray",
+    "Padder",
+    "AutoPadder",
+    "EngChar2DPadder",
+
+    "Vocabulary",
+
+    "DataSetIter",
+    "BatchIter",
+    "TorchLoaderIter",
+
+    "Const",
+
+    "Tester",
+    "Trainer",
+
+    "cache_results",
+    "seq_len_to_mask",
+    "get_seq_len",
+    "logger",
+
+    "Callback",
+    "GradientClipCallback",
+    "EarlyStopCallback",
+    "FitlogCallback",
+    "EvaluateCallback",
+    "LRScheduler",
+    "ControlC",
+    "LRFinder",
+    "TensorboardCallback",
+    "WarmupCallback",
+    'SaveModelCallback',
+    "EchoCallback",
+    "TesterCallback",
+    "CallbackException",
+    "EarlyStopError",
+
+    "LossFunc",
+    "CrossEntropyLoss",
+    "L1Loss",
+    "BCELoss",
+    "NLLLoss",
+    "LossInForward",
+
+    "AccuracyMetric",
+    "SpanFPreRecMetric",
+    "ExtractiveQAMetric",
+
+    "Optimizer",
+    "SGD",
+    "Adam",
+    "AdamW",
+
+    "SequentialSampler",
+    "BucketSampler",
+    "RandomSampler",
+    "Sampler",
+]
+
+from ._logger import logger
 from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
     LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
 from .trainer import Trainer
 from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
-from ._logger import logger
diff --git a/fastNLP/core/_logger.py b/fastNLP/core/_logger.py
index 50266d7a..7198cfbd 100644
--- a/fastNLP/core/_logger.py
+++ b/fastNLP/core/_logger.py
@@ -1,15 +1,15 @@
+"""undocumented"""
+
+__all__ = [
+    'logger',
+]
+
 import logging
 import logging.config
-import torch
-import _pickle as pickle
 import os
 import sys
 import warnings
 
-__all__ = [
-    'logger',
-]
-
 ROOT_NAME = 'fastNLP'
 
 try:
@@ -25,7 +25,7 @@ if tqdm is not None:
     class TqdmLoggingHandler(logging.Handler):
         def __init__(self, level=logging.INFO):
             super().__init__(level)
-        
+
         def emit(self, record):
             try:
                 msg = self.format(record)
@@ -59,14 +59,14 @@ def _add_file_handler(logger, path, level='INFO'):
             if os.path.abspath(path) == h.baseFilename:
                 # file path already added
                 return
-    
+
     # File Handler
     if os.path.exists(path):
         assert os.path.isfile(path)
         warnings.warn('log already exists in {}'.format(path))
     dirname = os.path.abspath(os.path.dirname(path))
     os.makedirs(dirname, exist_ok=True)
-    
+
     file_handler = logging.FileHandler(path, mode='a')
     file_handler.setLevel(_get_level(level))
     file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
@@ -87,7 +87,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
                 break
     if stream_handler is not None:
         logger.removeHandler(stream_handler)
-    
+
     # Stream Handler
     if stdout == 'plain':
         stream_handler = logging.StreamHandler(sys.stdout)
@@ -95,7 +95,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
         stream_handler = TqdmLoggingHandler(level)
     else:
         stream_handler = None
-    
+
     if stream_handler is not None:
         stream_formatter = logging.Formatter('%(message)s')
         stream_handler.setLevel(level)
@@ -103,38 +103,40 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
         stream_handler.setFormatter(stream_formatter)
         logger.addHandler(stream_handler)
 
 
-
-
 class FastNLPLogger(logging.getLoggerClass()):
     def __init__(self, name):
         super().__init__(name)
-    
+
     def add_file(self, path='./log.txt', level='INFO'):
         """add log output file and level"""
         _add_file_handler(self, path, level)
-    
+
     def set_stdout(self, stdout='tqdm', level='INFO'):
         """set stdout format and level"""
         _set_stdout_handler(self, stdout, level)
 
+
 logging.setLoggerClass(FastNLPLogger)
+
+
 # print(logging.getLoggerClass())
 # print(logging.getLogger())
 
+
 def _init_logger(path=None, stdout='tqdm', level='INFO'):
     """initialize logger"""
     level = _get_level(level)
-    
+
     # logger = logging.getLogger()
     logger = logging.getLogger(ROOT_NAME)
     logger.propagate = False
     logger.setLevel(level)
-    
+
     _set_stdout_handler(logger, stdout, level)
-    
+
     # File Handler
     if path is not None:
         _add_file_handler(logger, path, level)
-    
+
     return logger
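The net effect of the two files above: fastNLP.core now declares an explicit __all__, and the shared logger is a FastNLPLogger that gains add_file/set_stdout helpers. A minimal usage sketch, assuming a fastNLP install with this patch applied (the log path here is just an example):

    from fastNLP.core import logger

    logger.info('goes to stdout through the tqdm-aware handler')
    # attach a FileHandler; _add_file_handler warns if ./train.log already exists
    logger.add_file(path='./train.log', level='INFO')
    # swap the console handler for a plain StreamHandler at WARNING level
    logger.set_stdout(stdout='plain', level='WARNING')
    logger.warning('recorded on stdout and appended to train.log')
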
diff --git a/fastNLP/core/_parallel_utils.py b/fastNLP/core/_parallel_utils.py
index 6b24d9f9..ce745820 100644
--- a/fastNLP/core/_parallel_utils.py
+++ b/fastNLP/core/_parallel_utils.py
@@ -1,11 +1,14 @@
+"""undocumented"""
+
+__all__ = []
 import threading
+
 import torch
 from torch import nn
 from torch.nn.parallel.parallel_apply import get_a_var
-
-from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
 
 from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
 
 
 def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -27,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
         assert len(modules) == len(devices)
     else:
         devices = [None] * len(modules)
-    
+
     lock = threading.Lock()
     results = {}
     grad_enabled = torch.is_grad_enabled()
-    
+
     def _worker(i, module, input, kwargs, device=None):
         torch.set_grad_enabled(grad_enabled)
         if device is None:
@@ -47,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
         except Exception as e:
             with lock:
                 results[i] = e
-    
+
     if len(modules) > 1:
         threads = [threading.Thread(target=_worker,
                                     args=(i, module, input, kwargs, device))
                    for i, (module, input, kwargs, device) in
                    enumerate(zip(modules, inputs, kwargs_tup, devices))]
-        
+
         for thread in threads:
             thread.start()
         for thread in threads:
             thread.join()
     else:
         _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
-    
+
     outputs = []
     for i in range(len(inputs)):
         output = results[i]
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
     :param output_device: nn.DataParallel中的output_device
     :return:
     """
+
     def wrapper(network, *inputs, **kwargs):
         inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
         if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
         replicas = replicate(network, device_ids[:len(inputs)])
         outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
         return gather(outputs, output_device)
+
     return wrapper
 
 
@@ -99,4 +104,4 @@ def _model_contains_inner_module(model):
     if isinstance(model, nn.Module):
         if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
             return True
-    return False
\ No newline at end of file
+    return False
diff --git a/fastNLP/core/const.py b/fastNLP/core/const.py
index 27e8d1cb..ad5d1f1e 100644
--- a/fastNLP/core/const.py
+++ b/fastNLP/core/const.py
@@ -1,3 +1,13 @@
+"""
+.. todo::
+    doc
+"""
+
+__all__ = [
+    "Const"
+]
+
+
 class Const:
     """
     fastNLP中field命名常量。
@@ -25,47 +35,47 @@ class Const:
     LOSS = 'loss'
     RAW_WORD = 'raw_words'
     RAW_CHAR = 'raw_chars'
-    
+
     @staticmethod
     def INPUTS(i):
         """得到第 i 个 ``INPUT`` 的命名"""
         i = int(i) + 1
         return Const.INPUT + str(i)
-    
+
     @staticmethod
     def CHAR_INPUTS(i):
         """得到第 i 个 ``CHAR_INPUT`` 的命名"""
         i = int(i) + 1
         return Const.CHAR_INPUT + str(i)
-    
+
     @staticmethod
     def RAW_WORDS(i):
         i = int(i) + 1
         return Const.RAW_WORD + str(i)
-    
+
     @staticmethod
     def RAW_CHARS(i):
         i = int(i) + 1
         return Const.RAW_CHAR + str(i)
-    
+
     @staticmethod
     def INPUT_LENS(i):
         """得到第 i 个 ``INPUT_LEN`` 的命名"""
         i = int(i) + 1
         return Const.INPUT_LEN + str(i)
-    
+
     @staticmethod
     def OUTPUTS(i):
         """得到第 i 个 ``OUTPUT`` 的命名"""
         i = int(i) + 1
         return Const.OUTPUT + str(i)
-    
+
     @staticmethod
     def TARGETS(i):
         """得到第 i 个 ``TARGET`` 的命名"""
         i = int(i) + 1
         return Const.TARGET + str(i)
-    
+
     @staticmethod
     def LOSSES(i):
         """得到第 i 个 ``LOSS`` 的命名"""
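Since Const is pure naming convention, the indexed helpers shown above are easy to sanity-check; note that the 0-based argument maps to a 1-based suffix. A short sketch grounded in the constants visible in this hunk:

    from fastNLP.core.const import Const

    assert Const.LOSS == 'loss'
    # INPUTS/OUTPUTS/LOSSES(i) append str(i + 1) to the base constant:
    assert Const.LOSSES(0) == 'loss1'
    assert Const.RAW_WORDS(1) == 'raw_words2'
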
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 7c64fee4..3a293447 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -1,29 +1,29 @@
-"""
+"""undocumented
 正在开发中的分布式训练代码
 """
+import logging
+import os
+import time
+from datetime import datetime
+
 import torch
 import torch.cuda
-import torch.optim
 import torch.distributed as dist
-from torch.utils.data.distributed import DistributedSampler
+import torch.optim
+from pkg_resources import parse_version
 from torch.nn.parallel import DistributedDataParallel as DDP
-import os
+from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
-import time
-from datetime import datetime, timedelta
-from functools import partial
 
+from ._logger import logger
 from .batch import DataSetIter, BatchIter
 from .callback import DistCallbackManager, CallbackException, TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
 from .utils import _build_args
-from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
-from ._logger import logger
-import logging
-from pkg_resources import parse_version
+from .utils import _move_dict_value_to_device
 
 __all__ = [
     'get_local_rank',
diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
index b3f024f8..05f987c2 100644
--- a/fastNLP/core/field.py
+++ b/fastNLP/core/field.py
@@ -1,18 +1,25 @@
+"""
+.. todo::
+    doc
+"""
+
 __all__ = [
     "Padder",
     "AutoPadder",
     "EngChar2DPadder",
 ]
 
-from numbers import Number
-import torch
-import numpy as np
-from typing import Any
 from abc import abstractmethod
-from copy import deepcopy
 from collections import Counter
-from .utils import _is_iterable
+from copy import deepcopy
+from numbers import Number
+from typing import Any
+
+import numpy as np
+import torch
+
 from ._logger import logger
+from .utils import _is_iterable
 
 
 class SetInputOrTargetException(Exception):
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index 2d6a7380..c6b8fc90 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -1,13 +1,15 @@
-"""
- ..todo::
- 检查这个类是否需要
-"""
+"""undocumented"""
+
+__all__ = [
+    "Predictor"
+]
+
 from collections import defaultdict
 
 import torch
 
-from . import DataSetIter
 from . import DataSet
+from . import DataSetIter
 from . import SequentialSampler
 from .utils import _build_args, _move_dict_value_to_device, _get_model_device
 
@@ -21,7 +23,7 @@ class Predictor(object):
 
     :param torch.nn.Module network: 用来完成预测任务的模型
     """
-    
+
    def __init__(self, network):
         if not isinstance(network, torch.nn.Module):
             raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
         self.network = network
         self.batch_size = 1
         self.batch_output = []
-    
+
     def predict(self, data: DataSet, seq_len_field_name=None):
         """用已经训练好的模型进行inference.
 
@@ -41,27 +43,27 @@ class Predictor(object):
             raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
         if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
             raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
-        
+
         prev_training = self.network.training
         self.network.eval()
         network_device = _get_model_device(self.network)
         batch_output = defaultdict(list)
         data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
-        
+
         if hasattr(self.network, "predict"):
             predict_func = self.network.predict
         else:
             predict_func = self.network.forward
-        
+
         with torch.no_grad():
             for batch_x, _ in data_iterator:
                 _move_dict_value_to_device(batch_x, _, device=network_device)
                 refined_batch_x = _build_args(predict_func, **batch_x)
                 prediction = predict_func(**refined_batch_x)
-                
+
                 if seq_len_field_name is not None:
                     seq_lens = batch_x[seq_len_field_name].tolist()
-                
+
                 for key, value in prediction.items():
                     value = value.cpu().numpy()
                     if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
                     batch_output[key].extend(tmp_batch)
                 else:
                     batch_output[key].append(value)
-        
+
         self.network.train(prev_training)
         return batch_output
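Predictor wraps a trained module and collects its dict outputs batch by batch; predict() prefers a predict method over forward when the model defines one. A minimal sketch of the flow above (the toy model and field name are hypothetical, assuming the usual DataSet API):

    import torch
    from fastNLP.core import DataSet
    from fastNLP.core.predictor import Predictor

    class ToyModel(torch.nn.Module):
        # Predictor only requires forward() (or predict()) to return a dict
        def forward(self, x):
            return {'pred': x.sum(dim=1)}

    data = DataSet({'x': [[1, 2], [3, 4], [5, 6]]})
    data.set_input('x')

    output = Predictor(ToyModel()).predict(data)
    # output['pred'] holds the per-instance results gathered across batches
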
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 92f54f9a..52d33a5a 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -1,16 +1,22 @@
+"""
+.. todo::
+    doc
+"""
+
 __all__ = [
     "Vocabulary",
     "VocabularyOption",
 ]
 
-from functools import wraps
 from collections import Counter
+from functools import partial
+from functools import wraps
+
+from ._logger import logger
 from .dataset import DataSet
 from .utils import Option
-from functools import partial
-import numpy as np
 from .utils import _is_iterable
-from ._logger import logger
+
 
 class VocabularyOption(Option):
     def __init__(self,
@@ -51,7 +57,7 @@ def _check_build_status(func):
             self.rebuild = True
             if self.max_size is not None and len(self.word_count) >= self.max_size:
                 logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
-                        "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
+                            "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                     self.max_size, func.__name__))
         return func(self, *args, **kwargs)
 
@@ -199,7 +205,7 @@ class Vocabulary(object):
             self.build_reverse_vocab()
         self.rebuild = False
         return self
-    
+
     def build_reverse_vocab(self):
         """
         基于 `word to index` dict, 构建 `index to word` dict.
@@ -279,19 +285,19 @@ class Vocabulary(object):
             if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
                 raise RuntimeError("Only support field with 2 dimensions.")
             return [[self.to_index(c) for c in w] for w in field]
-        
+
         new_field_name = new_field_name or field_name
-        
+
         if type(new_field_name) == type(field_name):
             if isinstance(new_field_name, list):
                 assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
-                    "field_name."
+                                                               "field_name."
             elif isinstance(new_field_name, str):
                 field_name = [field_name]
                 new_field_name = [new_field_name]
             else:
                 raise TypeError("field_name and new_field_name can only be str or List[str].")
-        
+
         for idx, dataset in enumerate(datasets):
             if isinstance(dataset, DataSet):
                 try:
@@ -377,7 +383,7 @@ class Vocabulary(object):
         :return: bool
         """
         return word in self._no_create_word
-    
+
     def to_index(self, w):
         """
         将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出``ValueError``::
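The Vocabulary changes are import reordering plus whitespace cleanup, so behaviour is unchanged; for reference, typical usage of the methods touched above (a sketch, assuming the default '<unk>'/'<pad>' settings):

    from fastNLP.core import Vocabulary

    vocab = Vocabulary()
    vocab.update(['apple', 'banana', 'apple'])  # counting happens first
    vocab.build_vocab()                         # builds word2idx, then build_reverse_vocab()
    idx = vocab.to_index('apple')               # known word -> its index
    unk = vocab.to_index('durian')              # unseen word -> index of '<unk>'
    assert vocab.to_word(unk) == '<unk>'
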
diff --git a/fastNLP/embeddings/contextual_embedding.py b/fastNLP/embeddings/contextual_embedding.py
index 2c304da7..9910a44b 100644
--- a/fastNLP/embeddings/contextual_embedding.py
+++ b/fastNLP/embeddings/contextual_embedding.py
@@ -8,15 +8,17 @@ __all__ = [
 ]
 
 from abc import abstractmethod
+
 import torch
 
-from ..core.vocabulary import Vocabulary
-from ..core.dataset import DataSet
+from .embedding import TokenEmbedding
+from ..core import logger
 from ..core.batch import DataSetIter
+from ..core.dataset import DataSet
 from ..core.sampler import SequentialSampler
 from ..core.utils import _move_model_to_device, _get_model_device
-from .embedding import TokenEmbedding
-from ..core import logger
+from ..core.vocabulary import Vocabulary
+
 
 class ContextualEmbedding(TokenEmbedding):
     def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):
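Taken together, the __all__ blocks introduced across this patch pin down the supported import surface, and everything listed for fastNLP.core should be importable from the package directly. A quick sanity check, assuming the patched package is on the path:

    import fastNLP.core as core
    from fastNLP.core import DataSet, Vocabulary, Trainer, Tester, Const, logger

    # every public name is declared exactly once in core.__all__
    assert len(core.__all__) == len(set(core.__all__))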