
Merge pull request #135 from choosewhatulike/pr

Add Star-Transformer
tags/v0.4.0
Xipeng Qiu committed 6 years ago (via GitHub)
commit 88d4de7c90
3 changed files with 171 additions and 9 deletions
  1. fastNLP/modules/encoder/star_transformer.py   +146  -0
  2. fastNLP/modules/encoder/transformer.py         +15   -9
  3. test/modules/test_other_modules.py             +10   -0

fastNLP/modules/encoder/star_transformer.py  +146 -0

@@ -0,0 +1,146 @@
import torch
from torch import nn
from torch.nn import functional as F
import numpy as NP


class StarTransformer(nn.Module):
    """Star-Transformer encoder (encoder part only).
    paper: https://arxiv.org/abs/1902.09113

    :param hidden_size: int, size of the input dimension; also the size of the output dimension.
    :param num_layers: int, number of Star-Transformer layers.
    :param num_head: int, number of attention heads.
    :param head_dim: int, dimension of each head.
    :param dropout: float, dropout probability.
    :param max_len: int or None. If an int, it is the maximum length of the input
        sequence and a position embedding is added to the input;
        if None, no position embedding is added.
    """
    def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None):
        super(StarTransformer, self).__init__()
        self.iters = num_layers

        self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)])
        self.ring_att = nn.ModuleList(
            [MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
             for _ in range(self.iters)])
        self.star_att = nn.ModuleList(
            [MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout)
             for _ in range(self.iters)])

        if max_len is not None:
            self.pos_emb = nn.Embedding(max_len, hidden_size)
        else:
            self.pos_emb = None

    def forward(self, data, mask):
        """
        :param FloatTensor data: [batch, length, hidden] the input sequence
        :param ByteTensor mask: [batch, length] the padding mask for the input, in which padding positions are 0
        :return: [batch, length, hidden] the output sequence,
            [batch, hidden] the global relay node
        """
        def norm_func(f, x):
            # x: B, H, L, 1
            return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        B, L, H = data.size()
        mask = (mask == 0)  # flip so padding positions are 1, matching the masked_fill calls below
        smask = torch.cat([torch.zeros(B, 1).byte().to(mask), mask], 1)

        embs = data.permute(0, 2, 1)[:, :, :, None]  # B H L 1
        if self.pos_emb is not None:
            P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device)
                             .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None]  # 1 H L 1
            embs = embs + P

        nodes = embs
        relay = embs.mean(2, keepdim=True)
        ex_mask = mask[:, None, :, None].expand(B, H, L, 1)
        r_embs = embs.view(B, H, 1, L)
        for i in range(self.iters):
            ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
            nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
            relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))

            nodes = nodes.masked_fill_(ex_mask, 0)

        nodes = nodes.view(B, H, L).permute(0, 2, 1)

        return nodes, relay.view(B, H)


class MSA1(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        super(MSA1, self).__init__()
        # Multi-head self-attention, case 1: self-attention over small local regions.
        # Due to GPU architecture, Hadamard product plus summation is faster than
        # dot product when unfold_size is very small.
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
        self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3

    def forward(self, x, ax=None):
        # x: B, H, L, 1;  ax: B, H, X, L, appended features
        nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
        B, H, L, _ = x.shape

        q, k, v = self.WQ(x), self.WK(x), self.WV(x)  # x: (B, H, L, 1)

        if ax is not None:
            aL = ax.shape[2]
            ak = self.WK(ax).view(B, nhead, head_dim, aL, L)
            av = self.WV(ax).view(B, nhead, head_dim, aL, L)
        q = q.view(B, nhead, head_dim, 1, L)
        k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\
            .view(B, nhead, head_dim, unfold_size, L)
        v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\
            .view(B, nhead, head_dim, unfold_size, L)
        if ax is not None:
            k = torch.cat([k, ak], 3)
            v = torch.cat([v, av], 3)

        alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3))  # B N L 1 U
        att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1)

        ret = self.WO(att)

        return ret


class MSA2(nn.Module):
    def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1):
        # Multi-head self-attention, case 2: a single broadcastable query attending
        # over a sequence of keys and values.
        super(MSA2, self).__init__()
        self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
        self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)

        self.drop = nn.Dropout(dropout)

        # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
        self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3

    def forward(self, x, y, mask=None):
        # x: B, H, 1, 1;  y: B, H, L, 1
        nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
        B, H, L, _ = y.shape

        q, k, v = self.WQ(x), self.WK(y), self.WV(y)

        q = q.view(B, nhead, 1, head_dim)  # B, H, 1, 1 -> B, N, 1, h
        k = k.view(B, nhead, head_dim, L)  # B, H, L, 1 -> B, N, h, L
        v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2)  # B, H, L, 1 -> B, N, L, h
        pre_a = torch.matmul(q, k) / NP.sqrt(head_dim)
        if mask is not None:
            pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf'))
        alphas = self.drop(F.softmax(pre_a, 3))  # B, N, 1, L
        att = torch.matmul(alphas, v).view(B, -1, 1, 1)  # B, N, 1, h -> B, N*h, 1, 1
        return self.WO(att)
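
For readers unfamiliar with F.unfold, here is a small standalone sketch (illustrative only, not part of the change) of what the ring-attention code in MSA1 relies on: unfolding a (B, C, L, 1) tensor with a (3, 1) kernel and (1, 0) zero padding gathers, for every position, its own value plus its left and right neighbours, analogously to how MSA1 reshapes k and v into local windows.

import torch
import torch.nn.functional as F

B, C, L = 1, 2, 5
x = torch.arange(float(B * C * L)).view(B, C, L, 1)
# Sliding windows of size 3 along the length dimension, zero-padded at both ends,
# reshaped analogously to MSA1's (B, nhead, head_dim, unfold_size, L) views.
windows = F.unfold(x, kernel_size=(3, 1), padding=(1, 0)).view(B, C, 3, L)
print(windows[0, 0])  # column i holds [x[i-1], x[i], x[i+1]], with 0 at the borders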


fastNLP/modules/encoder/transformer.py  +15 -9

@@ -5,17 +5,18 @@ from ..dropout import TimestepDropout




 class TransformerEncoder(nn.Module):
+    """Transformer encoder module; it does not include the embedding layer.
+
+    :param num_layers: int, number of transformer layers.
+    :param model_size: int, size of the input dimension; also the size of the output dimension.
+    :param inner_size: int, hidden size of the FFN layer.
+    :param key_size: int, dimension of each head.
+    :param value_size: int, dimension of the value in each head.
+    :param num_head: int, number of attention heads.
+    :param dropout: float, dropout probability.
+    """
     class SubLayer(nn.Module):
         def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
-            """
-
-            :param model_size: int, size of the input dimension; also the size of the output dimension.
-            :param inner_size: int, hidden size of the FFN layer.
-            :param key_size: int, dimension of each head.
-            :param value_size: int, dimension of the value in each head.
-            :param num_head: int, number of attention heads.
-            :param dropout: float, dropout probability.
-            """
             super(TransformerEncoder.SubLayer, self).__init__()
             self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
             self.norm1 = nn.LayerNorm(model_size)
@@ -45,6 +46,11 @@ class TransformerEncoder(nn.Module):
         self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
 
     def forward(self, x, seq_mask=None):
+        """
+        :param x: [batch, seq_len, model_size] the input sequence
+        :param seq_mask: [batch, seq_len] padding mask for the input sequence
+        :return: [batch, seq_len, model_size] the output sequence
+        """
         output = x
         if seq_mask is None:
             atte_mask_out = None
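
For reference, a minimal usage sketch matching the new docstrings. It assumes the constructor signature TransformerEncoder(num_layers, **kargs) implied by the self.SubLayer(**kargs) line above, so treat the exact keyword names here as an assumption rather than documented API.

import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder

# Assumed constructor: num_layers plus the SubLayer keyword arguments shown above.
encoder = TransformerEncoder(num_layers=2, model_size=128, inner_size=256,
                             key_size=16, value_size=16, num_head=8, dropout=0.1)
x = torch.rand(4, 20, 128)   # [batch, seq_len, model_size]
out = encoder(x)             # [batch, seq_len, model_size], per the forward docstring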


test/modules/test_other_modules.py  +10 -0

@@ -3,6 +3,7 @@ import unittest
 import torch
 
 from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
+from fastNLP.modules.encoder.star_transformer import StarTransformer
 
 
 class TestGroupNorm(unittest.TestCase):
@@ -49,3 +50,12 @@ class TestBiAffine(unittest.TestCase):
         encoder_input = torch.randn((batch_size, decoder_length, 10))
         y = layer(decoder_input, encoder_input)
         self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))
+
+class TestStarTransformer(unittest.TestCase):
+    def test_1(self):
+        model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100)
+        x = torch.rand(16, 45, 100)
+        mask = torch.ones(16, 45).byte()
+        y, yn = model(x, mask)
+        self.assertEqual(tuple(y.size()), (16, 45, 100))
+        self.assertEqual(tuple(yn.size()), (16, 100))
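
The test above uses an all-ones mask (no padding). For variable-length batches, the ByteTensor mask described in the docstring (1 for real tokens, 0 for padding) can be built from sequence lengths, for example with plain PyTorch (a sketch, not part of the change):

import torch

lengths = torch.tensor([45, 30, 12])   # per-example token counts
max_len = 45
mask = (torch.arange(max_len)[None, :] < lengths[:, None]).byte()  # [batch, max_len], 1 = token, 0 = pad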
