@@ -1,3 +1,7 @@ | |||||
""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | __all__ = [ | ||||
"MLP", | "MLP", | ||||
"ConditionalRandomField", | "ConditionalRandomField", | ||||
@@ -6,6 +10,6 @@ __all__ = [ | |||||
] | ] | ||||
from .crf import ConditionalRandomField | from .crf import ConditionalRandomField | ||||
from .crf import allowed_transitions | |||||
from .mlp import MLP | from .mlp import MLP | ||||
from .utils import viterbi_decode | from .utils import viterbi_decode | ||||
from .crf import allowed_transitions |
@@ -1,3 +1,5 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"ConditionalRandomField", | "ConditionalRandomField", | ||||
"allowed_transitions" | "allowed_transitions" | ||||
@@ -9,13 +11,14 @@ from torch import nn | |||||
from ..utils import initial_parameter | from ..utils import initial_parameter | ||||
from ...core import Vocabulary | from ...core import Vocabulary | ||||
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): | def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): | ||||
""" | """ | ||||
别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions` | 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions` | ||||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | ||||
:param dict,Vocabulary id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
:param dict, ~fastNLP.Vocabulary id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | ||||
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 | :param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 | ||||
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | ||||
@@ -1,3 +1,5 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"MLP" | "MLP" | ||||
] | ] | ||||
@@ -1,3 +1,5 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"viterbi_decode" | "viterbi_decode" | ||||
] | ] | ||||
@@ -1,4 +1,8 @@ | |||||
__all__ = [] | |||||
"""undocumented""" | |||||
__all__ = [ | |||||
"TimestepDropout" | |||||
] | |||||
import torch | import torch | ||||
@@ -1,3 +1,8 @@ | |||||
""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | __all__ = [ | ||||
# "BertModel", | # "BertModel", | ||||
@@ -24,13 +29,12 @@ __all__ = [ | |||||
"MultiHeadAttention", | "MultiHeadAttention", | ||||
] | ] | ||||
from .attention import MultiHeadAttention | |||||
from .bert import BertModel | from .bert import BertModel | ||||
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | ||||
from .conv_maxpool import ConvMaxpool | from .conv_maxpool import ConvMaxpool | ||||
from .lstm import LSTM | from .lstm import LSTM | ||||
from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask | |||||
from .star_transformer import StarTransformer | from .star_transformer import StarTransformer | ||||
from .transformer import TransformerEncoder | from .transformer import TransformerEncoder | ||||
from .variational_rnn import VarRNN, VarLSTM, VarGRU | from .variational_rnn import VarRNN, VarLSTM, VarGRU | ||||
from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask | |||||
from .attention import MultiHeadAttention |
@@ -1,7 +1,9 @@ | |||||
""" | |||||
"""undocumented | |||||
这个页面的代码大量参考了 allenNLP | 这个页面的代码大量参考了 allenNLP | ||||
""" | """ | ||||
__all__ = [] | |||||
from typing import Optional, Tuple, List, Callable | from typing import Optional, Tuple, List, Callable | ||||
import torch | import torch | ||||
@@ -1,3 +1,5 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"MultiHeadAttention" | "MultiHeadAttention" | ||||
] | ] | ||||
@@ -1,4 +1,4 @@ | |||||
""" | |||||
"""undocumented | |||||
这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 | 这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 | ||||
有用,也请引用一下他们。 | 有用,也请引用一下他们。 | ||||
""" | """ | ||||
@@ -8,17 +8,17 @@ __all__ = [ | |||||
] | ] | ||||
import collections | import collections | ||||
import unicodedata | |||||
import copy | import copy | ||||
import json | import json | ||||
import math | import math | ||||
import os | import os | ||||
import unicodedata | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
from ...core import logger | |||||
from ..utils import _get_file_name_base_on_postfix | from ..utils import _get_file_name_base_on_postfix | ||||
from ...core import logger | |||||
CONFIG_FILE = 'bert_config.json' | CONFIG_FILE = 'bert_config.json' | ||||
VOCAB_NAME = 'vocab.txt' | VOCAB_NAME = 'vocab.txt' | ||||
@@ -1,3 +1,5 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"ConvolutionCharEncoder", | "ConvolutionCharEncoder", | ||||
"LSTMCharEncoder" | "LSTMCharEncoder" | ||||
@@ -1,3 +1,5 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"ConvMaxpool" | "ConvMaxpool" | ||||
] | ] | ||||
@@ -1,7 +1,8 @@ | |||||
""" | |||||
"""undocumented | |||||
轻量封装的 Pytorch LSTM 模块. | 轻量封装的 Pytorch LSTM 模块. | ||||
可在 forward 时传入序列的长度, 自动对padding做合适的处理. | 可在 forward 时传入序列的长度, 自动对padding做合适的处理. | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"LSTM" | "LSTM" | ||||
] | ] | ||||
@@ -1,3 +1,5 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"MaxPool", | "MaxPool", | ||||
"MaxPoolWithMask", | "MaxPoolWithMask", | ||||
@@ -1,6 +1,7 @@ | |||||
""" | |||||
"""undocumented | |||||
Star-Transformer 的encoder部分的 Pytorch 实现 | Star-Transformer 的encoder部分的 Pytorch 实现 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"StarTransformer" | "StarTransformer" | ||||
] | ] | ||||
@@ -1,3 +1,5 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | __all__ = [ | ||||
"TransformerEncoder" | "TransformerEncoder" | ||||
] | ] | ||||
@@ -1,6 +1,7 @@ | |||||
""" | |||||
"""undocumented | |||||
Variational RNN 的 Pytorch 实现 | Variational RNN 的 Pytorch 实现 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"VarRNN", | "VarRNN", | ||||
"VarLSTM", | "VarLSTM", | ||||
@@ -1,10 +1,20 @@ | |||||
""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | |||||
"initial_parameter", | |||||
"summary" | |||||
] | |||||
import os | |||||
from functools import reduce | from functools import reduce | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
import torch.nn.init as init | import torch.nn.init as init | ||||
import glob | |||||
import os | |||||
def initial_parameter(net, initial_method=None): | def initial_parameter(net, initial_method=None): | ||||
"""A method used to initialize the weights of PyTorch models. | """A method used to initialize the weights of PyTorch models. | ||||
@@ -40,7 +50,7 @@ def initial_parameter(net, initial_method=None): | |||||
init_method = init.uniform_ | init_method = init.uniform_ | ||||
else: | else: | ||||
init_method = init.xavier_normal_ | init_method = init.xavier_normal_ | ||||
def weights_init(m): | def weights_init(m): | ||||
# classname = m.__class__.__name__ | # classname = m.__class__.__name__ | ||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn | ||||
@@ -66,7 +76,7 @@ def initial_parameter(net, initial_method=None): | |||||
else: | else: | ||||
init.normal_(w.data) # bias | init.normal_(w.data) # bias | ||||
# print("init else") | # print("init else") | ||||
net.apply(weights_init) | net.apply(weights_init) | ||||
@@ -79,11 +89,11 @@ def summary(model: nn.Module): | |||||
""" | """ | ||||
train = [] | train = [] | ||||
nontrain = [] | nontrain = [] | ||||
def layer_summary(module: nn.Module): | def layer_summary(module: nn.Module): | ||||
def count_size(sizes): | def count_size(sizes): | ||||
return reduce(lambda x, y: x*y, sizes) | |||||
return reduce(lambda x, y: x * y, sizes) | |||||
for p in module.parameters(recurse=False): | for p in module.parameters(recurse=False): | ||||
if p.requires_grad: | if p.requires_grad: | ||||
train.append(count_size(p.shape)) | train.append(count_size(p.shape)) | ||||
@@ -91,7 +101,7 @@ def summary(model: nn.Module): | |||||
nontrain.append(count_size(p.shape)) | nontrain.append(count_size(p.shape)) | ||||
for subm in module.children(): | for subm in module.children(): | ||||
layer_summary(subm) | layer_summary(subm) | ||||
layer_summary(model) | layer_summary(model) | ||||
total_train = sum(train) | total_train = sum(train) | ||||
total_nontrain = sum(nontrain) | total_nontrain = sum(nontrain) | ||||
@@ -101,7 +111,7 @@ def summary(model: nn.Module): | |||||
strings.append('Trainable params: {:,}'.format(total_train)) | strings.append('Trainable params: {:,}'.format(total_train)) | ||||
strings.append('Non-trainable params: {:,}'.format(total_nontrain)) | strings.append('Non-trainable params: {:,}'.format(total_nontrain)) | ||||
max_len = len(max(strings, key=len)) | max_len = len(max(strings, key=len)) | ||||
bar = '-'*(max_len + 3) | |||||
bar = '-' * (max_len + 3) | |||||
strings = [bar] + strings + [bar] | strings = [bar] + strings + [bar] | ||||
print('\n'.join(strings)) | print('\n'.join(strings)) | ||||
return total, total_train, total_nontrain | return total, total_train, total_nontrain | ||||
@@ -128,9 +138,9 @@ def _get_file_name_base_on_postfix(dir_path, postfix): | |||||
:param postfix: 形如".bin", ".json"等 | :param postfix: 形如".bin", ".json"等 | ||||
:return: str,文件的路径 | :return: str,文件的路径 | ||||
""" | """ | ||||
files = list(filter(lambda filename:filename.endswith(postfix), os.listdir(os.path.join(dir_path)))) | |||||
files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(os.path.join(dir_path)))) | |||||
if len(files) == 0: | if len(files) == 0: | ||||
raise FileNotFoundError(f"There is no file endswith *{postfix} file in {dir_path}") | raise FileNotFoundError(f"There is no file endswith *{postfix} file in {dir_path}") | ||||
elif len(files) > 1: | elif len(files) > 1: | ||||
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") | raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") | ||||
return os.path.join(dir_path, files[0]) | |||||
return os.path.join(dir_path, files[0]) |