From 4bc0d7bcd2e95513f403376f59eb623659b562fa Mon Sep 17 00:00:00 2001
From: bin_open_source <342748465@qq.com>
Date: Tue, 3 Oct 2023 22:44:05 +0800
Subject: [PATCH] ADD file via upload

---
 models/encoder.py | 98 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 98 insertions(+)
 create mode 100644 models/encoder.py

diff --git a/models/encoder.py b/models/encoder.py
new file mode 100644
index 0000000..3dd4b76
--- /dev/null
+++ b/models/encoder.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class ConvLayer(nn.Module):
+    def __init__(self, c_in):
+        super(ConvLayer, self).__init__()
+        padding = 1 if torch.__version__>='1.5.0' else 2
+        self.downConv = nn.Conv1d(in_channels=c_in,
+                                  out_channels=c_in,
+                                  kernel_size=3,
+                                  padding=padding,
+                                  padding_mode='circular')
+        self.norm = nn.BatchNorm1d(c_in)
+        self.activation = nn.ELU()
+        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+    def forward(self, x):
+        x = self.downConv(x.permute(0, 2, 1))
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.maxPool(x)
+        x = x.transpose(1,2)
+        return x
+
+class EncoderLayer(nn.Module):
+    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
+        super(EncoderLayer, self).__init__()
+        d_ff = d_ff or 4*d_model
+        self.attention = attention
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+        self.activation = F.relu if activation == "relu" else F.gelu
+
+    def forward(self, x, attn_mask=None):
+        # x [B, L, D]
+        # x = x + self.dropout(self.attention(
+        #     x, x, x,
+        #     attn_mask = attn_mask
+        # ))
+        new_x, attn = self.attention(
+            x, x, x,
+            attn_mask = attn_mask
+        )
+        x = x + self.dropout(new_x)
+
+        y = x = self.norm1(x)
+        y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
+        y = self.dropout(self.conv2(y).transpose(-1,1))
+
+        return self.norm2(x+y), attn
+
+class Encoder(nn.Module):
+    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
+        super(Encoder, self).__init__()
+        self.attn_layers = nn.ModuleList(attn_layers)
+        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
+        self.norm = norm_layer
+
+    def forward(self, x, attn_mask=None):
+        # x [B, L, D]
+        attns = []
+        if self.conv_layers is not None:
+            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
+                x, attn = attn_layer(x, attn_mask=attn_mask)
+                x = conv_layer(x)
+                attns.append(attn)
+            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
+            attns.append(attn)
+        else:
+            for attn_layer in self.attn_layers:
+                x, attn = attn_layer(x, attn_mask=attn_mask)
+                attns.append(attn)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x, attns
+
+class EncoderStack(nn.Module):
+    def __init__(self, encoders, inp_lens):
+        super(EncoderStack, self).__init__()
+        self.encoders = nn.ModuleList(encoders)
+        self.inp_lens = inp_lens
+
+    def forward(self, x, attn_mask=None):
+        # x [B, L, D]
+        x_stack = []; attns = []
+        for i_len, encoder in zip(self.inp_lens, self.encoders):
+            inp_len = x.shape[1]//(2**i_len)
+            x_s, attn = encoder(x[:, -inp_len:, :])
+            x_stack.append(x_s); attns.append(attn)
+        x_stack = torch.cat(x_stack, -2)
+
+        return x_stack, attns
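
A minimal usage sketch of the classes this patch adds (ConvLayer, EncoderLayer, Encoder). The SimpleSelfAttention wrapper is a hypothetical stand-in built on torch.nn.MultiheadAttention, not this project's own attention module; the only thing assumed from the patch is the interface EncoderLayer.forward uses, namely attention(queries, keys, values, attn_mask=...) returning an (output, attention-weights) pair. Shapes follow the # x [B, L, D] convention in the file, and the import path assumes the script runs from the repository root.

import torch
import torch.nn as nn

from models.encoder import ConvLayer, Encoder, EncoderLayer

class SimpleSelfAttention(nn.Module):
    # Hypothetical stand-in exposing the (queries, keys, values, attn_mask) -> (out, attn)
    # interface that EncoderLayer expects; the project would plug in its own attention here.
    def __init__(self, d_model, n_heads=8):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, queries, keys, values, attn_mask=None):
        out, attn = self.mha(queries, keys, values, attn_mask=attn_mask, need_weights=True)
        return out, attn

d_model, n_layers = 512, 3
encoder = Encoder(
    attn_layers=[EncoderLayer(SimpleSelfAttention(d_model), d_model) for _ in range(n_layers)],
    conv_layers=[ConvLayer(d_model) for _ in range(n_layers - 1)],  # one fewer conv than attention layer
    norm_layer=nn.LayerNorm(d_model),
)

x = torch.randn(32, 96, d_model)   # [B, L, D]
y, attns = encoder(x)              # on the padding=1 branch each ConvLayer roughly halves L: 96 -> 48 -> 24
print(y.shape, len(attns))         # torch.Size([32, 24, 512]) 3

Encoder pairs each conv_layer with one attention layer and runs the final attention layer on its own, so passing n_layers attention layers together with n_layers - 1 ConvLayers matches the loop in Encoder.forward; EncoderStack can then combine several such encoders, each fed a progressively shorter slice of the input as set by inp_lens.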