diff --git a/fastNLP/modules/decoder/__init__.py b/fastNLP/modules/decoder/__init__.py
index 664618b2..57acb172 100644
--- a/fastNLP/modules/decoder/__init__.py
+++ b/fastNLP/modules/decoder/__init__.py
@@ -1,3 +1,7 @@
+"""
+.. todo::
+    doc
+"""
 __all__ = [
     "MLP",
     "ConditionalRandomField",
@@ -6,6 +10,6 @@ __all__ = [
 ]
 
 from .crf import ConditionalRandomField
+from .crf import allowed_transitions
 from .mlp import MLP
 from .utils import viterbi_decode
-from .crf import allowed_transitions
diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py
index 9f19afef..b47d0162 100644
--- a/fastNLP/modules/decoder/crf.py
+++ b/fastNLP/modules/decoder/crf.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
 __all__ = [
     "ConditionalRandomField",
     "allowed_transitions"
@@ -9,13 +11,14 @@ from torch import nn
 from ..utils import initial_parameter
 from ...core import Vocabulary
 
+
 def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):
     """
     Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions`
 
     Given a mapping from ids to labels, return the list of all allowed (from_tag_id, to_tag_id) transitions.
 
-    :param dict,Vocabulary id2target: keys are label indices, values are str tags or tag-labels. A value may be a bare tag such as "B" or "M", or
+    :param dict, ~fastNLP.Vocabulary id2target: keys are label indices, values are str tags or tag-labels. A value may be a bare tag such as "B" or "M", or
         a tag-label such as "B-NN" or "M-NN"; tag and label must be separated by "-". id2label can usually be obtained via Vocabulary.idx2word.
     :param str encoding_type: supports "bio", "bmes", "bmeso", "bioes".
     :param bool include_start_end: whether to include transitions at the start and end. For example, in BIO, B/O may appear at the start, but I may not;
diff --git a/fastNLP/modules/decoder/mlp.py b/fastNLP/modules/decoder/mlp.py
index 9d9d80f2..f6e687a7 100644
--- a/fastNLP/modules/decoder/mlp.py
+++ b/fastNLP/modules/decoder/mlp.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
 __all__ = [
     "MLP"
 ]
diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py
index 3d5ac3f8..118b1414 100644
--- a/fastNLP/modules/decoder/utils.py
+++ b/fastNLP/modules/decoder/utils.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
 __all__ = [
     "viterbi_decode"
 ]
diff --git a/fastNLP/modules/dropout.py b/fastNLP/modules/dropout.py
index 0ea2a2d9..24c20cc6 100644
--- a/fastNLP/modules/dropout.py
+++ b/fastNLP/modules/dropout.py
@@ -1,4 +1,8 @@
-__all__ = []
+"""undocumented"""
+
+__all__ = [
+    "TimestepDropout"
+]
 
 import torch
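The `allowed_transitions` helper promoted into `__all__` above is the one piece of these hunks with real API surface, so a short usage sketch may help. The signature and semantics below come from the docstring in the diff itself; the tag ids are hypothetical.

```python
from fastNLP.modules.decoder import allowed_transitions

# Hypothetical BIO tag vocabulary; keys are label indices, values are
# "tag" or "tag-label" strings separated by "-", per the docstring.
id2target = {0: 'B-PER', 1: 'I-PER', 2: 'B-LOC', 3: 'I-LOC', 4: 'O'}

# Returns every legal (from_tag_id, to_tag_id) pair, e.g. B-PER -> I-PER
# is allowed while O -> I-PER is not. include_start_end=True also adds
# transitions from the start position and to the end position.
transitions = allowed_transitions(id2target, encoding_type='bio',
                                  include_start_end=True)
```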
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index 1e99a0fd..0dfc18de 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -1,3 +1,8 @@
+"""
+.. todo::
+    doc
+"""
+
 __all__ = [
 
     # "BertModel",
@@ -24,13 +29,12 @@ __all__ = [
     "MultiHeadAttention",
 ]
 
+from .attention import MultiHeadAttention
 from .bert import BertModel
 from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
 from .conv_maxpool import ConvMaxpool
 from .lstm import LSTM
+from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask
 from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
 from .variational_rnn import VarRNN, VarLSTM, VarGRU
-
-from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask
-from .attention import MultiHeadAttention
diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py
index befae8bc..554cf8a9 100644
--- a/fastNLP/modules/encoder/_elmo.py
+++ b/fastNLP/modules/encoder/_elmo.py
@@ -1,7 +1,9 @@
-"""
+"""undocumented
 The code in this file draws heavily on allenNLP.
 """
 
+__all__ = []
+
 from typing import Optional, Tuple, List, Callable
 
 import torch
diff --git a/fastNLP/modules/encoder/attention.py b/fastNLP/modules/encoder/attention.py
index fe3f7fd8..02bd078a 100644
--- a/fastNLP/modules/encoder/attention.py
+++ b/fastNLP/modules/encoder/attention.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
 __all__ = [
     "MultiHeadAttention"
 ]
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index b74c4da0..5026f48a 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -1,4 +1,4 @@
-"""
+"""undocumented
 The code in this file borrows heavily from (copy-pastes) https://github.com/huggingface/pytorch-pretrained-BERT;
 if you find it useful, please cite them as well.
 """
@@ -8,17 +8,17 @@ __all__ = [
 ]
 
 import collections
-
-import unicodedata
 import copy
 import json
 import math
 import os
+import unicodedata
 
 import torch
 from torch import nn
-from ...core import logger
+
 from ..utils import _get_file_name_base_on_postfix
+from ...core import logger
 
 CONFIG_FILE = 'bert_config.json'
 VOCAB_NAME = 'vocab.txt'
diff --git a/fastNLP/modules/encoder/char_encoder.py b/fastNLP/modules/encoder/char_encoder.py
index 6a6e1470..e40bd0dd 100644
--- a/fastNLP/modules/encoder/char_encoder.py
+++ b/fastNLP/modules/encoder/char_encoder.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
 __all__ = [
     "ConvolutionCharEncoder",
     "LSTMCharEncoder"
diff --git a/fastNLP/modules/encoder/conv_maxpool.py b/fastNLP/modules/encoder/conv_maxpool.py
index 8ce6b163..68415189 100644
--- a/fastNLP/modules/encoder/conv_maxpool.py
+++ b/fastNLP/modules/encoder/conv_maxpool.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
 __all__ = [
     "ConvMaxpool"
 ]
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index e2358132..1f3eae6d 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -1,7 +1,8 @@
-"""
+"""undocumented
 A lightweight wrapper around the PyTorch LSTM module.
 Sequence lengths can be passed to forward, and padding is then handled appropriately.
 """
+
 __all__ = [
     "LSTM"
 ]
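The lstm.py docstring above promises that sequence lengths can be passed to `forward` with padding handled automatically. As a rough sketch of the underlying technique (not fastNLP's exact internals), this is typically done by packing the batch before the LSTM and unpacking afterwards:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

x = torch.randn(4, 10, 8)              # 4 padded sequences, max length 10
seq_len = torch.tensor([10, 7, 5, 2])  # true lengths, sorted descending

# Pack so the LSTM never processes padding positions...
packed = pack_padded_sequence(x, seq_len, batch_first=True)
packed_out, (h, c) = lstm(packed)

# ...then unpack back to a padded (4, 10, 16) tensor.
output, _ = pad_packed_sequence(packed_out, batch_first=True)
```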
""" + __all__ = [ "LSTM" ] diff --git a/fastNLP/modules/encoder/pooling.py b/fastNLP/modules/encoder/pooling.py index d8aa54ad..b1272284 100644 --- a/fastNLP/modules/encoder/pooling.py +++ b/fastNLP/modules/encoder/pooling.py @@ -1,3 +1,5 @@ +"""undocumented""" + __all__ = [ "MaxPool", "MaxPoolWithMask", diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py index 3927a494..02d7a6a0 100644 --- a/fastNLP/modules/encoder/star_transformer.py +++ b/fastNLP/modules/encoder/star_transformer.py @@ -1,6 +1,7 @@ -""" +"""undocumented Star-Transformer 的encoder部分的 Pytorch 实现 """ + __all__ = [ "StarTransformer" ] diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py index bc488e54..ce9172d5 100644 --- a/fastNLP/modules/encoder/transformer.py +++ b/fastNLP/modules/encoder/transformer.py @@ -1,3 +1,5 @@ +"""undocumented""" + __all__ = [ "TransformerEncoder" ] diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py index 8e5e804b..933555c8 100644 --- a/fastNLP/modules/encoder/variational_rnn.py +++ b/fastNLP/modules/encoder/variational_rnn.py @@ -1,6 +1,7 @@ -""" +"""undocumented Variational RNN 的 Pytorch 实现 """ + __all__ = [ "VarRNN", "VarLSTM", diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index ead75711..09574782 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -1,10 +1,20 @@ +""" +.. todo:: + doc +""" + +__all__ = [ + "initial_parameter", + "summary" +] + +import os from functools import reduce import torch import torch.nn as nn import torch.nn.init as init -import glob -import os + def initial_parameter(net, initial_method=None): """A method used to initialize the weights of PyTorch models. 
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index ead75711..09574782 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -1,10 +1,20 @@
+"""
+.. todo::
+    doc
+"""
+
+__all__ = [
+    "initial_parameter",
+    "summary"
+]
+
+import os
 from functools import reduce
 
 import torch
 import torch.nn as nn
 import torch.nn.init as init
-import glob
-import os
+
 
 def initial_parameter(net, initial_method=None):
     """A method used to initialize the weights of PyTorch models.
@@ -40,7 +50,7 @@ def initial_parameter(net, initial_method=None):
         init_method = init.uniform_
     else:
         init_method = init.xavier_normal_
-    
+
     def weights_init(m):
         # classname = m.__class__.__name__
         if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d):  # for all the cnn
@@ -66,7 +76,7 @@ def initial_parameter(net, initial_method=None):
             else:
                 init.normal_(w.data)  # bias
                 # print("init else")
-    
+
     net.apply(weights_init)
 
 
@@ -79,11 +89,11 @@ def summary(model: nn.Module):
     """
     train = []
    nontrain = []
-    
+
    def layer_summary(module: nn.Module):
        def count_size(sizes):
-            return reduce(lambda x, y: x*y, sizes)
-        
+            return reduce(lambda x, y: x * y, sizes)
+
        for p in module.parameters(recurse=False):
            if p.requires_grad:
                train.append(count_size(p.shape))
@@ -91,7 +101,7 @@ def summary(model: nn.Module):
                 nontrain.append(count_size(p.shape))
         for subm in module.children():
             layer_summary(subm)
-    
+
     layer_summary(model)
     total_train = sum(train)
     total_nontrain = sum(nontrain)
@@ -101,7 +111,7 @@ def summary(model: nn.Module):
     strings.append('Trainable params: {:,}'.format(total_train))
     strings.append('Non-trainable params: {:,}'.format(total_nontrain))
     max_len = len(max(strings, key=len))
-    bar = '-'*(max_len + 3)
+    bar = '-' * (max_len + 3)
     strings = [bar] + strings + [bar]
     print('\n'.join(strings))
     return total, total_train, total_nontrain
@@ -128,9 +138,9 @@ def _get_file_name_base_on_postfix(dir_path, postfix):
     :param postfix: e.g. ".bin", ".json", etc.
     :return: str, the path to the file
     """
-    files = list(filter(lambda filename:filename.endswith(postfix), os.listdir(os.path.join(dir_path))))
+    files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(os.path.join(dir_path))))
     if len(files) == 0:
         raise FileNotFoundError(f"There is no file endswith *{postfix} file in {dir_path}")
     elif len(files) > 1:
         raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}")
-    return os.path.join(dir_path, files[0])
\ No newline at end of file
+    return os.path.join(dir_path, files[0])
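The two helpers that utils.py now exports can be exercised on any `nn.Module`. A minimal sketch; the `'uniform'` string for `initial_method` is inferred from the `init.uniform_` branch visible in the hunk above, and `summary`'s return triple comes straight from its `return` statement:

```python
import torch.nn as nn
from fastNLP.modules.utils import initial_parameter, summary

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

# Re-initialize all weights in place; with no initial_method given,
# the code above falls back to init.xavier_normal_.
initial_parameter(model, initial_method='uniform')

# Prints a small parameter table and returns the three counts.
total, trainable, non_trainable = summary(model)
print(total, trainable, non_trainable)
```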