@@ -1,11 +1,11 @@ | |||
""" | |||
大部分用于 NLP 任务的神经网络都可以看做由编码 :mod:`~fastNLP.modules.encoder` 、 | |||
聚合 :mod:`~fastNLP.modules.aggregator` 、解码 :mod:`~fastNLP.modules.decoder` 三种模块组成。 | |||
解码 :mod:`~fastNLP.modules.decoder` 两种模块组成。 | |||
.. image:: figures/text_classification.png | |||
:mod:`~fastNLP.modules` 中实现了 fastNLP 提供的诸多模块组件,可以帮助用户快速搭建自己所需的网络。 | |||
三种模块的功能和常见组件如下: | |||
两种模块的功能和常见组件如下: | |||
+-----------------------+-----------------------+-----------------------+ | |||
| module type | functionality | example | | |||
@@ -13,9 +13,6 @@ | |||
| encoder | 将输入编码为具有 | embedding, RNN, CNN, | | |||
| | 表示能力的向量 | transformer | | |||
+-----------------------+-----------------------+-----------------------+ | |||
| aggregator | 从多个向量中聚合信息 | self-attention, | | |||
| | | max-pooling | | |||
+-----------------------+-----------------------+-----------------------+ | |||
| decoder | 将具有某种表示意义的 | MLP, CRF | | |||
| | 向量解码为需要的输出 | | | |||
| | 形式 | | | |||
@@ -46,10 +43,8 @@ __all__ = [ | |||
"allowed_transitions", | |||
] | |||
from . import aggregator | |||
from . import decoder | |||
from . import encoder | |||
from .aggregator import * | |||
from .decoder import * | |||
from .dropout import TimestepDropout | |||
from .encoder import * | |||
@@ -22,7 +22,14 @@ __all__ = [ | |||
"VarRNN", | |||
"VarLSTM", | |||
"VarGRU" | |||
"VarGRU", | |||
"MaxPool", | |||
"MaxPoolWithMask", | |||
"AvgPool", | |||
"AvgPoolWithMask", | |||
"MultiHeadAttention", | |||
] | |||
from ._bert import BertModel | |||
from .bert import BertWordPieceEncoder | |||
@@ -34,3 +41,6 @@ from .lstm import LSTM | |||
from .star_transformer import StarTransformer | |||
from .transformer import TransformerEncoder | |||
from .variational_rnn import VarRNN, VarLSTM, VarGRU | |||
from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask | |||
from .attention import MultiHeadAttention |
@@ -45,8 +45,7 @@ class DotAttention(nn.Module): | |||
class MultiHeadAttention(nn.Module): | |||
""" | |||
别名::class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.aggregator.attention.MultiHeadAttention` | |||
别名::class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.encoder.attention.MultiHeadAttention` | |||
:param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||
:param key_size: int, 每个head的维度大小。 | |||
@@ -2,35 +2,22 @@ | |||
import os | |||
from torch import nn | |||
import torch | |||
from ...io.file_utils import _get_base_url, cached_path | |||
from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ._bert import _WordPieceBertModel, BertModel | |||
class BertWordPieceEncoder(nn.Module): | |||
""" | |||
读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 | |||
:param fastNLP.Vocabulary vocab: 词表 | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` | |||
:param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | |||
:param bool requires_grad: 是否需要gradient。 | |||
""" | |||
def __init__(self, model_dir_or_name:str='en-base-uncased', layers:str='-1', | |||
requires_grad:bool=False): | |||
def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||
requires_grad: bool=False): | |||
super().__init__() | |||
PRETRAIN_URL = _get_base_url('bert') | |||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
} | |||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
@@ -89,4 +76,4 @@ class BertWordPieceEncoder(nn.Module): | |||
outputs = self.model(word_pieces, token_type_ids) | |||
outputs = torch.cat([*outputs], dim=-1) | |||
return outputs | |||
return outputs |
@@ -1,7 +1,8 @@ | |||
__all__ = [ | |||
"MaxPool", | |||
"MaxPoolWithMask", | |||
"AvgPool" | |||
"AvgPool", | |||
"AvgPoolWithMask" | |||
] | |||
import torch | |||
import torch.nn as nn | |||
@@ -9,7 +10,7 @@ import torch.nn as nn | |||
class MaxPool(nn.Module): | |||
""" | |||
别名::class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.aggregator.pooling.MaxPool` | |||
别名::class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.encoder.pooling.MaxPool` | |||
Max-pooling模块。 | |||
@@ -58,7 +59,7 @@ class MaxPool(nn.Module): | |||
class MaxPoolWithMask(nn.Module): | |||
""" | |||
别名::class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.aggregator.pooling.MaxPoolWithMask` | |||
别名::class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.encoder.pooling.MaxPoolWithMask` | |||
带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。 | |||
""" | |||
@@ -98,7 +99,7 @@ class KMaxPool(nn.Module): | |||
class AvgPool(nn.Module): | |||
""" | |||
别名::class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.aggregator.pooling.AvgPool` | |||
别名::class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.encoder.pooling.AvgPool` | |||
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size] | |||
""" | |||
@@ -125,7 +126,7 @@ class AvgPool(nn.Module): | |||
class AvgPoolWithMask(nn.Module): | |||
""" | |||
别名::class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.aggregator.pooling.AvgPoolWithMask` | |||
别名::class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.encoder.pooling.AvgPoolWithMask` | |||
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size], pooling | |||
的时候只会考虑mask为1的位置 | |||
@@ -9,7 +9,7 @@ | |||
# 任务复现 | |||
## Text Classification (文本分类) | |||
- still in progress | |||
- [Text Classification 文本分类任务复现](text_classification) | |||
## Matching (自然语言推理/句子匹配) | |||
@@ -21,11 +21,11 @@ | |||
## Coreference resolution (指代消解) | |||
- still in progress | |||
- [Coreference resolution 指代消解任务复现](coreference_resolution) | |||
## Summarization (摘要) | |||
- still in progress | |||
- [BertSum](Summmarization) | |||
## ... |