import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.initializer as init
import mindspore.common.dtype as mstype
from mindspore import Tensor


class ConvLayer(nn.Cell):
    """Distilling block: Conv1d + BatchNorm + ELU + MaxPool, halving the sequence length."""

    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        # MindSpore's Conv1d has no circular padding mode, so zero padding
        # (pad_mode='pad', padding=1) is used here; has_bias=True matches the
        # PyTorch default used by the original implementation.
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=3,
                                  pad_mode='pad',
                                  padding=1,
                                  has_bias=True)
        self.norm = nn.BatchNorm1d(c_in)  # expects channel-first (N, C, L) inputs
        self.activation = nn.ELU()
        # In MindSpore (>= 2.0), MaxPool1d takes an explicit padding only with pad_mode='pad'.
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, pad_mode='pad', padding=1)

    def construct(self, x):
        # x: [B, L, D] -> [B, D, L] for the channel-first Conv1d/BatchNorm1d.
        x = self.downConv(x.transpose(0, 2, 1))
        x = self.norm(x)
        x = self.activation(x)
        x = self.maxPool(x)
        # Back to [B, L', D]; MindSpore's transpose takes a full permutation.
        x = x.transpose(0, 2, 1)
        return x

class EncoderLayer(nn.Cell):
    """Self-attention sub-layer followed by a position-wise feed-forward (1x1 conv) sub-layer."""

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, has_bias=True)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, has_bias=True)
        # MindSpore's LayerNorm expects normalized_shape as a tuple or list.
        self.norm1 = nn.LayerNorm((d_model,))
        self.norm2 = nn.LayerNorm((d_model,))
        # MindSpore >= 2.0: `p` is the drop probability (earlier versions use keep_prob).
        self.dropout = nn.Dropout(p=dropout)
        self.activation = ops.ReLU() if activation == "relu" else ops.GeLU()

    def construct(self, x, attn_mask=None):
        # Self-attention with a residual connection.
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)

        # Position-wise feed-forward; Conv1d works on [B, D, L], hence the transposes.
        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(0, 2, 1))))
        y = self.dropout(self.conv2(y).transpose(0, 2, 1))

        return self.norm2(x + y), attn

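# Encoder chains the attention layers, optionally interleaving ConvLayer
# "distilling" blocks between them; there is one fewer conv layer than
# attention layers, so the final attention layer is not followed by
# distilling. An optional final normalization layer is applied at the end.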
class Encoder(nn.Cell):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.CellList(attn_layers)
        self.conv_layers = nn.CellList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def construct(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns

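# EncoderStack runs several encoders on progressively shorter suffixes of the
# input (the most recent x.shape[1] // 2**i_len steps each) and concatenates
# their outputs along the sequence dimension, following the Informer design.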
class EncoderStack(nn.Cell):
    def __init__(self, encoders, inp_lens):
        super(EncoderStack, self).__init__()
        self.encoders = nn.CellList(encoders)
        self.inp_lens = inp_lens

    def construct(self, x, attn_mask=None):
        x_stack = []
        attns = []
        for i_len, encoder in zip(self.inp_lens, self.encoders):
            # Each encoder sees only the most recent x.shape[1] // 2**i_len steps.
            inp_len = x.shape[1] // (2 ** i_len)
            x_s, attn = encoder(x[:, -inp_len:, :])
            x_stack.append(x_s)
            attns.append(attn)
        # Concatenate along the sequence dimension (axis 1), as in the original Informer.
        x_stack = ops.Concat(1)(x_stack)

        return x_stack, attns
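

# A minimal usage sketch (not part of the original code): it wires the cells
# above together with a toy full self-attention stand-in, since the Informer
# attention wrapper (ProbSparse/full attention) is defined elsewhere. Names
# such as _ToyAttention are hypothetical and exist only for this example.
class _ToyAttention(nn.Cell):
    """Stand-in attention: plain scaled dot-product over the whole sequence."""

    def __init__(self, d_model):
        super(_ToyAttention, self).__init__()
        self.scale = d_model ** -0.5
        self.softmax = nn.Softmax(axis=-1)
        self.bmm_t = ops.BatchMatMul(transpose_b=True)
        self.bmm = ops.BatchMatMul()

    def construct(self, queries, keys, values, attn_mask=None):
        # attn_mask is ignored in this toy example.
        scores = self.bmm_t(queries, keys) * self.scale  # [B, L, L]
        attn = self.softmax(scores)
        return self.bmm(attn, values), attn


if __name__ == "__main__":
    import numpy as np

    d_model, seq_len, batch = 16, 32, 2
    x = Tensor(np.random.randn(batch, seq_len, d_model), mstype.float32)

    encoder = Encoder(
        attn_layers=[EncoderLayer(_ToyAttention(d_model), d_model) for _ in range(2)],
        conv_layers=[ConvLayer(d_model)],  # one distilling conv between the two layers
        norm_layer=nn.LayerNorm((d_model,)),
    )
    out, attns = encoder(x)
    # The single ConvLayer halves the sequence length: out.shape == (2, 16, 16).
    print(out.shape)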