Browse Source

- update transformer docs

tags/v0.4.0
yunfan 5 years ago
parent
commit
5241e30bdd
1 changed files with 15 additions and 9 deletions
  1. +15
    -9
      fastNLP/modules/encoder/transformer.py

+ 15
- 9
fastNLP/modules/encoder/transformer.py View File

@@ -5,17 +5,18 @@ from ..dropout import TimestepDropout




class TransformerEncoder(nn.Module): class TransformerEncoder(nn.Module):
"""transformer的encoder模块,不包含embedding层

:param num_layers: int, transformer的层数
:param model_size: int, 输入维度的大小。同时也是输出维度的大小。
:param inner_size: int, FFN层的hidden大小
:param key_size: int, 每个head的维度大小。
:param value_size: int,每个head中value的维度。
:param num_head: int,head的数量。
:param dropout: float。
"""
class SubLayer(nn.Module): class SubLayer(nn.Module):
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
"""

:param model_size: int, 输入维度的大小。同时也是输出维度的大小。
:param inner_size: int, FFN层的hidden大小
:param key_size: int, 每个head的维度大小。
:param value_size: int,每个head中value的维度。
:param num_head: int,head的数量。
:param dropout: float。
"""
super(TransformerEncoder.SubLayer, self).__init__() super(TransformerEncoder.SubLayer, self).__init__()
self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
self.norm1 = nn.LayerNorm(model_size) self.norm1 = nn.LayerNorm(model_size)
@@ -45,6 +46,11 @@ class TransformerEncoder(nn.Module):
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])


def forward(self, x, seq_mask=None): def forward(self, x, seq_mask=None):
"""
:param x: [batch, seq_len, model_size] 输入序列
:param seq_mask: [batch, seq_len] 输入序列的padding mask
:return: [batch, seq_len, model_size] 输出序列
"""
output = x output = x
if seq_mask is None: if seq_mask is None:
atte_mask_out = None atte_mask_out = None


Loading…
Cancel
Save