Browse Source

Merge pull request #141 from wlhgtc/master

Another bug in Star Transformer
tags/v0.4.0
Yunfan Shao GitHub 5 years ago
parent
commit
90d112c07c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 3 deletions
  1. fastNLP/modules/encoder/star_transformer.py (+1, -3)

+ 1
- 3
fastNLP/modules/encoder/star_transformer.py View File

@@ -7,7 +7,6 @@ import numpy as NP
class StarTransformer(nn.Module): class StarTransformer(nn.Module):
"""Star-Transformer Encoder part。 """Star-Transformer Encoder part。
paper: https://arxiv.org/abs/1902.09113 paper: https://arxiv.org/abs/1902.09113

:param hidden_size: int, 输入维度的大小。同时也是输出维度的大小。 :param hidden_size: int, 输入维度的大小。同时也是输出维度的大小。
:param num_layers: int, star-transformer的层数 :param num_layers: int, star-transformer的层数
:param num_head: int,head的数量。 :param num_head: int,head的数量。
@@ -137,11 +136,10 @@ class MSA2(nn.Module):


q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L
- v = k.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h
+ v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h
pre_a = torch.matmul(q, k) / NP.sqrt(head_dim) pre_a = torch.matmul(q, k) / NP.sqrt(head_dim)
if mask is not None: if mask is not None:
pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L
att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1
return self.WO(att) return self.WO(att)


Loading…
Cancel
Save