From e89894d48eb551ff59e2994c8e771a16c9c2bb35 Mon Sep 17 00:00:00 2001
From: Huyf9 <98331980+Huyf9@users.noreply.github.com>
Date: Wed, 26 Jul 2023 10:02:12 +0800
Subject: [PATCH] update ci_pipeline (#1)

* Update ci_pipeline.yaml
* update ci_pipeline
* update ci_pipeline
* update pylint
* A2Attention
* ACmixAttention
* CondConv
* DepthwiseSeparableConvolution
---
 .github/pylint.conf                         |   3 +-
 .github/workflows/ci_pipeline.yaml          |   8 +-
 model/attention/A2Attention.py              |  29 +++--
 model/attention/ACmixAttention.py           | 127 ++++++++++++++++++++
 model/conv/CondConv.py                      |  90 ++++++++++++++
 model/conv/DepthwiseSeparableConvolution.py |  39 ++++++
 6 files changed, 280 insertions(+), 16 deletions(-)
 create mode 100644 model/attention/ACmixAttention.py
 create mode 100644 model/conv/CondConv.py
 create mode 100644 model/conv/DepthwiseSeparableConvolution.py

diff --git a/.github/pylint.conf b/.github/pylint.conf
index b550edb..76189fa 100644
--- a/.github/pylint.conf
+++ b/.github/pylint.conf
@@ -161,7 +161,8 @@ disable=raw-checker-failed,
         too-few-public-methods,
         no-member,
         protected-access,
-        abstract-method
+        abstract-method,
+        C0103
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
diff --git a/.github/workflows/ci_pipeline.yaml b/.github/workflows/ci_pipeline.yaml
index e852e81..1292604 100644
--- a/.github/workflows/ci_pipeline.yaml
+++ b/.github/workflows/ci_pipeline.yaml
@@ -7,12 +7,12 @@ on:
   pull_request:
     branches: [ "main" ]
     paths:
-      - 'External-Attention-MindSpore/**'
+      - 'model/**'
       - '.github/workflows/**'
   push:
     branches: [ "main" ]
     paths:
-      - 'External-Attention-MindSpore/**'
+      - 'model/**'
 
 permissions:
   contents: read
@@ -38,9 +38,9 @@ jobs:
 #      run: |
 #        python .github/install_mindspore.py
 #        pip install -r download.txt
-    - name: Analysing the External-Attention-MindSpore code with pylint
+    - name: Analysing the model code with pylint
       run: |
-        pylint External-Attention-MindSpore --rcfile=.github/pylint.conf
+        pylint model --rcfile=.github/pylint.conf
 #    - name: Analysing the tests code with pylint
 #      run: |
 #        pylint tests --rcfile=.github/pylint.conf
diff --git a/model/attention/A2Attention.py b/model/attention/A2Attention.py
index e9b2f1c..8724fbf 100644
--- a/model/attention/A2Attention.py
+++ b/model/attention/A2Attention.py
@@ -1,9 +1,16 @@
+"""
+DoubleAttention
+"""
+
 import mindspore as ms
 from mindspore import nn
 from mindspore.common.initializer import initializer, HeNormal, Normal
-
+
 
 class DoubleAttention(nn.Cell):
+    """
+    Double Attention
+    """
     def __init__(self, in_channels, c_m, c_n, reconstruct=True):
         super().__init__()
         self.in_channels = in_channels
@@ -21,6 +28,7 @@ class DoubleAttention(nn.Cell):
         self.apply(self.init_weights)
 
     def init_weights(self, cell):
+        """ init weight """
         if isinstance(cell, nn.Conv2d):
             cell.weight.set_data(initializer(HeNormal(mode='fan_out'), cell.weight.shape, cell.weight.dtype))
             if cell.bias is not None:
@@ -34,28 +42,27 @@
                 cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
 
     def construct(self, x):
-        B, C, H, W = x.shape
-        assert C == self.in_channels
+        b, c, h, w = x.shape
+        assert c == self.in_channels
         a = self.convA(x)  # b, c_m, h, w
-        b = self.convB(x)  # b, c_n, h, w
+        b_map = self.convB(x)  # b, c_n, h, w
         v = self.convV(x)  # b, c_n, h, w
-        tmpA = a.view(B, self.c_m, -1)
-        attention_maps = ms.ops.softmax(b.view(B, self.c_n, -1))
-        attention_vectors = ms.ops.softmax(v.view(B, self.c_n, -1))
+        tmpA = a.view(b, self.c_m, -1)
+        attention_maps = ms.ops.softmax(b_map.view(b, self.c_n, -1))
+        attention_vectors = ms.ops.softmax(v.view(b, self.c_n, -1))
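+        # feature gathering: bmm of tmpA (b, c_m, h*w) with attention_maps^T (b, h*w, c_n) -> (b, c_m, c_n)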
         global_descriptors = ms.ops.bmm(tmpA, attention_maps.permute(0, 2, 1))
         tmpZ = ms.ops.matmul(global_descriptors, attention_vectors)
-        tmpZ = tmpZ.view(B, self.c_m, H, W)
+        tmpZ = tmpZ.view(b, self.c_m, h, w)
         if self.reconstruct:
             tmpZ = self.conv_reconstruct(tmpZ)
         return tmpZ
-
+
 
 if __name__ == "__main__":
-    input = ms.ops.randn([12, 512, 7, 7])
+    in_tensor = ms.ops.randn([12, 512, 7, 7])
     a2 = DoubleAttention(512, 128, 128)
-    output = a2(input)
+    output = a2(in_tensor)
     print(output.shape)
-
diff --git a/model/attention/ACmixAttention.py b/model/attention/ACmixAttention.py
new file mode 100644
index 0000000..fa76d58
--- /dev/null
+++ b/model/attention/ACmixAttention.py
@@ -0,0 +1,127 @@
+""" ACmix Attention """
+import mindspore as ms
+from mindspore import nn
+
+
+def position(H, W):
+    """ get position encode """
+    loc_w = ms.ops.linspace(-1., 1., W).unsqueeze(0).repeat(H, axis=0)
+    loc_h = ms.ops.linspace(-1., 1., H).unsqueeze(1).repeat(W, axis=1)
+    loc = ms.ops.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0)
+    loc = loc.reshape(-1, *loc.shape)
+    # print(loc)
+    return loc
+
+
+def stride(x, strides):
+    """ split x with strides in last two dimension """
+    # B, C, H, W = x.shape
+    return x[:, :, ::strides, ::strides]
+
+
+def init_rate_half(tensor):
+    """ fill data with 0.5 """
+    if isinstance(tensor, ms.Parameter):
+        fill_data = ms.ops.ones(tensor.shape) * 0.5
+        tensor.set_dtype(fill_data.dtype)
+        tensor.data.set_data(fill_data)
+    return tensor
+
+
+def init_rate_0(tensor):
+    """ fill data with 0 """
+    if isinstance(tensor, ms.Parameter):
+        fill_data = ms.ops.zeros(tensor.shape)
+        tensor.set_dtype(fill_data.dtype)
+        tensor.data.set_data(fill_data)
+    return tensor
+
+
+class ACmix(nn.Cell):
+    """ ACmix """
+    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, strides=1, dilation=1):
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.head = head
+        self.kernel_att = kernel_att
+        self.kernel_conv = kernel_conv
+        self.stride = strides
+        self.dilation = dilation
+        self.rate1 = ms.Parameter(ms.Tensor(1))
+        self.rate2 = ms.Parameter(ms.Tensor(1))
+        self.head_dim = self.out_planes // self.head
+
+        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
+        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
+        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
+        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)
+
+        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
+        self.pad_att = nn.ReflectionPad2d(self.padding_att)
+        self.unfold = nn.Unfold(ksizes=[1, self.kernel_att, self.kernel_att, 1],
+                                strides=[1, self.stride, self.stride, 1],
+                                rates=[1, 1, 1, 1])
+        self.softmax = nn.Softmax(axis=1)
+
+        self.fc = nn.Conv2d(3 * self.head, self.kernel_conv ** 2, kernel_size=1)
+        self.dep_conv = nn.Conv2d(self.kernel_conv ** 2 * self.head_dim, out_planes, kernel_size=self.kernel_conv,
+                                  stride=strides, pad_mode='pad', padding=1, has_bias=True, group=self.head_dim)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        """ reset parameters """
+        init_rate_half(self.rate1)
+        init_rate_half(self.rate2)
+        kernel = ms.ops.zeros((self.kernel_conv ** 2, self.kernel_conv, self.kernel_conv))
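+        # build kernel_conv**2 one-hot filters: filter i selects the spatial offset (i // kernel_conv, i % kernel_conv)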
+        for i in range(self.kernel_conv ** 2):
+            kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
+        kernel = kernel.reshape(1, *kernel.shape).repeat(self.out_planes, axis=0)
+        self.dep_conv.weight = ms.Parameter(default_input=kernel, requires_grad=True)
+        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)
+
+    def construct(self, x):
+        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
+        scaling = float(self.head_dim) ** 0.5
+        B, _, H, W = q.shape
+        h_out, w_out = H // self.stride, W // self.stride
+
+        pe = self.conv_p(position(H, W))
+
+        q_att = q.view(B * self.head, self.head_dim, H, W) * scaling
+        k_att = k.view(B * self.head, self.head_dim, H, W)
+        v_att = v.view(B * self.head, self.head_dim, H, W)
+
+        if self.stride > 1:
+            q_att = stride(q_att, self.stride)
+            q_pe = stride(pe, self.stride)
+        else:
+            q_pe = pe
+
+        unfold_k = self.unfold(self.pad_att(k_att)).view(B * self.head, self.head_dim,
+                                                         self.kernel_att ** 2, h_out, w_out)
+        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att ** 2,
+                                                        h_out, w_out)
+
+        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)
+        att = self.softmax(att)
+
+        out_att = self.unfold(self.pad_att(v_att)).view(B * self.head, self.head_dim, self.kernel_att ** 2,
+                                                        h_out, w_out)
+        out_att = (att.unsqueeze(1) * out_att).sum(2).view(B, self.out_planes, h_out, w_out)
+
+        f_all = self.fc(
+            ms.ops.cat([q.view(B, self.head, self.head_dim, H * W), k.view(B, self.head, self.head_dim, H * W),
+                        v.view(B, self.head, self.head_dim, H * W)], 1))
+        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])
+
+        out_conv = self.dep_conv(f_conv)
+        return self.rate1 * out_att + self.rate2 * out_conv
+
+
+if __name__ == "__main__":
+    in_tensor = ms.ops.randn((50, 256, 7, 7))
+    acmix = ACmix(in_planes=256, out_planes=256)
+    out = acmix(in_tensor)
+    print(out.shape)
diff --git a/model/conv/CondConv.py b/model/conv/CondConv.py
new file mode 100644
index 0000000..72412c8
--- /dev/null
+++ b/model/conv/CondConv.py
@@ -0,0 +1,90 @@
+""" CondConv """
+import mindspore as ms
+from mindspore import nn
+from mindspore.common.initializer import initializer, HeNormal, HeUniform
+
+
+class Attention(nn.Cell):
+    """ Attention """
+    def __init__(self, in_channels, K, init_weight=True):
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.net = nn.Conv2d(in_channels=in_channels, out_channels=K, kernel_size=1)
+        self.sigmoid = nn.Sigmoid()
+
+        if init_weight:
+            self.apply(self.init_weights)
+
+    def init_weights(self, cell):
+        """ initialize weights """
+        if isinstance(cell, nn.Conv2d):
+            cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
+                                             cell.weight.shape, cell.weight.dtype))
+            if cell.bias is not None:
+                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
+        elif isinstance(cell, nn.BatchNorm2d):
+            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
+            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))
+
+    def construct(self, x):
+        att = self.avgpool(x)
+        att = self.net(att).view(x.shape[0], -1)
+        return self.sigmoid(att)
+
+
+class CondConv(nn.Cell):
+    """ CondConv """
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
+                 dilation=1, groups=1, bias=True, K=4, init_weight=True):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.bias = bias
+        self.K = K
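+        # routing function: global average pool + 1x1 conv + sigmoid -> K per-sample weights for mixing the K expert kernels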
+        self.attention = Attention(in_channels, K)
+
+        self.weight = ms.Parameter(ms.ops.randn(K, out_channels, in_channels // groups, kernel_size, kernel_size),
+                                   requires_grad=True)
+        if bias:
+            self.bias = ms.Parameter(ms.ops.randn(K, out_channels), requires_grad=True)
+        else:
+            self.bias = None
+
+        if init_weight:
+            self.init_weights()
+
+    def init_weights(self):
+        """ initialize weights """
+        for i in range(self.K):
+            self.weight[i] = initializer(HeUniform(), self.weight[i].shape, self.weight[i].dtype)
+
+    def construct(self, x):
+        B, _, H, W = x.shape
+        softmax_att = self.attention(x)
+        x = x.view(1, -1, H, W)
+        weight = self.weight.view(self.K, -1)
+        aggregate_weight = ms.ops.mm(softmax_att, weight).view(B * self.out_channels, self.in_channels // self.groups,
+                                                               self.kernel_size, self.kernel_size)
+        if self.bias is not None:
+            bias = self.bias.view(self.K, -1)
+            aggregate_bias = ms.ops.mm(softmax_att, bias).view(-1)
+            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, pad_mode="pad",
+                                   padding=self.padding, dilation=self.dilation, groups=self.groups * B)
+        else:
+            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, pad_mode="pad",
+                                   padding=self.padding, dilation=self.dilation, groups=self.groups * B)
+
+        output = output.view(B, self.out_channels, H, W)
+        return output
+
+
+if __name__ == "__main__":
+    in_tensor = ms.ops.randn((2, 32, 64, 64))
+    cconv = CondConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
+    out = cconv(in_tensor)
+    print(out.shape)
diff --git a/model/conv/DepthwiseSeparableConvolution.py b/model/conv/DepthwiseSeparableConvolution.py
new file mode 100644
index 0000000..41ab2ff
--- /dev/null
+++ b/model/conv/DepthwiseSeparableConvolution.py
@@ -0,0 +1,39 @@
+""" Depthwise and Separable Convolution """
+import mindspore as ms
+from mindspore import nn
+
+
+class DepthwiseSeparableConvolution(nn.Cell):
+    """ DepthwiseSeparableConvolution """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1):
+        super().__init__()
+
+        self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
+                                        out_channels=in_channels,
+                                        kernel_size=kernel_size,
+                                        stride=stride,
+                                        pad_mode='pad',
+                                        padding=padding,
+                                        group=in_channels)  # depthwise: one filter group per input channel
+
+        self.pointwise_conv = nn.Conv2d(in_channels=in_channels,
+                                        out_channels=out_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        group=1)
+
+    def construct(self, x):
+        x = self.depthwise_conv(x)
+        out = self.pointwise_conv(x)
+        return out
+
+
+if __name__ == '__main__':
+    in_tensor = ms.ops.randn((1, 3, 224, 224), dtype=ms.float32)
+    conv = DepthwiseSeparableConvolution(3, 64)
+    output = conv(in_tensor)
+    print(output.shape)