| @@ -5,17 +5,18 @@ from ..dropout import TimestepDropout | |||||
| class TransformerEncoder(nn.Module): | class TransformerEncoder(nn.Module): | ||||
| """transformer的encoder模块,不包含embedding层 | |||||
| :param num_layers: int, transformer的层数 | |||||
| :param model_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||||
| :param inner_size: int, FFN层的hidden大小 | |||||
| :param key_size: int, 每个head的维度大小。 | |||||
| :param value_size: int,每个head中value的维度。 | |||||
| :param num_head: int,head的数量。 | |||||
| :param dropout: float。 | |||||
| """ | |||||
| class SubLayer(nn.Module): | class SubLayer(nn.Module): | ||||
| def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | ||||
| """ | |||||
| :param model_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||||
| :param inner_size: int, FFN层的hidden大小 | |||||
| :param key_size: int, 每个head的维度大小。 | |||||
| :param value_size: int,每个head中value的维度。 | |||||
| :param num_head: int,head的数量。 | |||||
| :param dropout: float。 | |||||
| """ | |||||
| super(TransformerEncoder.SubLayer, self).__init__() | super(TransformerEncoder.SubLayer, self).__init__() | ||||
| self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) | self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) | ||||
| self.norm1 = nn.LayerNorm(model_size) | self.norm1 = nn.LayerNorm(model_size) | ||||
| @@ -45,6 +46,11 @@ class TransformerEncoder(nn.Module): | |||||
| self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | ||||
| def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
| """ | |||||
| :param x: [batch, seq_len, model_size] 输入序列 | |||||
| :param seq_mask: [batch, seq_len] 输入序列的padding mask | |||||
| :return: [batch, seq_len, model_size] 输出序列 | |||||
| """ | |||||
| output = x | output = x | ||||
| if seq_mask is None: | if seq_mask is None: | ||||
| atte_mask_out = None | atte_mask_out = None | ||||