
update ci_pipeline (#1)

* Update ci_pipeline.yaml

* update ci_pipeline

* update ci_pipeline

* update pylint

* A2Attention

* ACmixAttention

* CondConv

* DepthwiseSeparableConvolution
v1
Huyf9 committed 2 years ago · commit e89894d48e
6 changed files with 280 additions and 16 deletions:

  1. .github/pylint.conf (+2, -1)
  2. .github/workflows/ci_pipeline.yaml (+4, -4)
  3. model/attention/A2Attention.py (+18, -11)
  4. model/attention/ACmixAttention.py (+127, -0)
  5. model/conv/CondConv.py (+90, -0)
  6. model/conv/DepthwiseSeparableConvolution.py (+39, -0)

.github/pylint.conf (+2, -1)

@@ -161,7 +161,8 @@ disable=raw-checker-failed,
        too-few-public-methods,
        no-member,
        protected-access,
-       abstract-method
+       abstract-method,
+       C0103

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option


.github/workflows/ci_pipeline.yaml (+4, -4)

@@ -7,12 +7,12 @@ on:
  pull_request:
    branches: [ "main" ]
    paths:
-      - 'External-Attention-MindSpore/**'
+      - 'model/**'
      - '.github/workflows/**'
  push:
    branches: [ "main" ]
    paths:
-      - 'External-Attention-MindSpore/**'
+      - 'model/**'

permissions:
  contents: read
@@ -38,9 +38,9 @@ jobs:
      #   run: |
      #     python .github/install_mindspore.py
      #     pip install -r download.txt
-      - name: Analysing the External-Attention-MindSpore code with pylint
+      - name: Analysing the model code with pylint
        run: |
-          pylint External-Attention-MindSpore --rcfile=.github/pylint.conf
+          pylint model --rcfile=.github/pylint.conf
      # - name: Analysing the tests code with pylint
      #   run: |
      #     pylint tests --rcfile=.github/pylint.conf

model/attention/A2Attention.py (+18, -11)

@@ -1,9 +1,16 @@
+"""
+DoubleAttention
+"""
+
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import initializer, HeNormal, Normal

class DoubleAttention(nn.Cell):
+    """
+    Double Attention
+    """
    def __init__(self, in_channels, c_m, c_n, reconstruct=True):
        super().__init__()
        self.in_channels = in_channels
@@ -21,6 +28,7 @@ class DoubleAttention(nn.Cell):
        self.apply(self.init_weights)

    def init_weights(self, cell):
+        """ init weight """
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(initializer(HeNormal(mode='fan_out'), cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
@@ -34,28 +42,27 @@ class DoubleAttention(nn.Cell):
                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))

    def construct(self, x):
-        B, C, H, W = x.shape
-        assert C == self.in_channels
+        b, c, h, w = x.shape
+        assert c == self.in_channels
        a = self.convA(x)  # b, c_m, h, w
        b = self.convB(x)  # b, c_n, h, w
        v = self.convV(x)  # b, c_n, h, w
-        tmpA = a.view(B, self.c_m, -1)
-        attention_maps = ms.ops.softmax(b.view(B, self.c_n, -1))
-        attention_vectors = ms.ops.softmax(v.view(B, self.c_n, -1))
+        tmpA = a.view(b, self.c_m, -1)
+        attention_maps = ms.ops.softmax(b.view(b, self.c_n, -1))
+        attention_vectors = ms.ops.softmax(v.view(b, self.c_n, -1))

        global_descriptors = ms.ops.bmm(tmpA, attention_maps.permute(0, 2, 1))

        tmpZ = ms.ops.matmul(global_descriptors, attention_vectors)
-        tmpZ = tmpZ.view(B, self.c_m, H, W)
+        tmpZ = tmpZ.view(b, self.c_m, h, w)
        if self.reconstruct:
            tmpZ = self.conv_reconstruct(tmpZ)

        return tmpZ

if __name__ == "__main__":
-    input = ms.ops.randn([12, 512, 7, 7])
+    in_tensor = ms.ops.randn([12, 512, 7, 7])
    a2 = DoubleAttention(512, 128, 128)
-    output = a2(input)
+    output = a2(in_tensor)
    print(output.shape)

model/attention/ACmixAttention.py (+127, -0)

@@ -0,0 +1,127 @@
""" ACmix Attention """
import mindspore as ms
from mindspore import nn


def position(H, W):
    """ get position encode """
    loc_w = ms.ops.linspace(-1., 1., W).unsqueeze(0).repeat(H, axis=0)
    loc_h = ms.ops.linspace(-1., 1., H).unsqueeze(1).repeat(W, axis=1)
    loc = ms.ops.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0)
    loc = loc.reshape(-1, *loc.shape)
    # print(loc)
    return loc


def stride(x, strides):
    """ split x with strides in last two dimension """
    # B, C, H, W = x.shape
    return x[:, :, ::strides, ::strides]


def init_rate_half(tensor):
    """ fill data with 0.5 """
    if isinstance(tensor, ms.Parameter):
        fill_data = ms.ops.ones(tensor.shape) * 0.5
        tensor.set_dtype(fill_data.dtype)
        tensor.data.set_data(fill_data)
    return tensor


def init_rate_0(tensor):
    """ fill data with 0 """
    if isinstance(tensor, ms.Parameter):
        fill_data = ms.ops.zeros(tensor.shape)
        tensor.set_dtype(fill_data.dtype)
        tensor.data.set_data(fill_data)
    return tensor


class ACmix(nn.Cell):
    """ ACmix """
    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, strides=1, dilation=1):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = strides
        self.dilation = dilation
        self.rate1 = ms.Parameter(ms.Tensor(1))
        self.rate2 = ms.Parameter(ms.Tensor(1))
        self.head_dim = self.out_planes // self.head

        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)

        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = nn.ReflectionPad2d(self.padding_att)
        self.unfold = nn.Unfold(ksizes=[1, self.kernel_att, self.kernel_att, 1],
                                strides=[1, self.stride, self.stride, 1],
                                rates=[1, 1, 1, 1])
        self.softmax = nn.Softmax(axis=1)

        self.fc = nn.Conv2d(3 * self.head, self.kernel_conv ** 2, kernel_size=1)
        self.dep_conv = nn.Conv2d(self.kernel_conv ** 2 * self.head_dim, out_planes, kernel_size=self.kernel_conv,
                                  stride=strides, pad_mode='pad', padding=1, has_bias=True, group=self.head_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """ reset parameters """
        init_rate_half(self.rate1)
        init_rate_half(self.rate2)
        kernel = ms.ops.zeros((self.kernel_conv ** 2, self.kernel_conv, self.kernel_conv))
        for i in range(self.kernel_conv ** 2):
            kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
        kernel = kernel.reshape(1, *kernel.shape).repeat(self.out_planes, axis=0)
        self.dep_conv.weight = ms.Parameter(default_input=kernel, requires_grad=True)
        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)

    def construct(self, x):
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        scaling = float(self.head_dim) ** 0.5
        B, _, H, W = q.shape
        h_out, w_out = H // self.stride, W // self.stride

        pe = self.conv_p(position(H, W))

        q_att = q.view(B * self.head, self.head_dim, H, W) * scaling
        k_att = k.view(B * self.head, self.head_dim, H, W)
        v_att = v.view(B * self.head, self.head_dim, H, W)

        if self.stride > 1:
            q_att = stride(q_att, self.stride)
            q_pe = stride(pe, self.stride)
        else:
            q_pe = pe

        unfold_k = self.unfold(self.pad_att(k_att)).view(B * self.head, self.head_dim,
                                                         self.kernel_att ** 2, h_out, w_out)
        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att ** 2,
                                                        h_out, w_out)

        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)
        att = self.softmax(att)

        out_att = self.unfold(self.pad_att(v_att)).view(B * self.head, self.head_dim, self.kernel_att ** 2,
                                                        h_out, w_out)
        out_att = (att.unsqueeze(1) * out_att).sum(2).view(B, self.out_planes, h_out, w_out)

        f_all = self.fc(
            ms.ops.cat([q.view(B, self.head, self.head_dim, H * W), k.view(B, self.head, self.head_dim, H * W),
                        v.view(B, self.head, self.head_dim, H * W)], 1))
        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])

        out_conv = self.dep_conv(f_conv)
        return self.rate1 * out_att + self.rate2 * out_conv


if __name__ == "__main__":
    in_tensor = ms.ops.randn((50, 256, 7, 7))
    acmix = ACmix(in_planes=256, out_planes=256)
    out = acmix(in_tensor)
    print(out.shape)
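
For intuition about position(H, W) above: it builds a constant 2-channel coordinate grid (x and y positions normalized to [-1, 1]) of shape (1, 2, H, W), which conv_p then projects to head_dim channels as a positional encoding. A minimal plain-Python sketch of that grid, assuming a 3x3 feature map:

# The coordinate grid that position(H, W) produces, spelled out for H = W = 3.
H, W = 3, 3
xs = [-1.0 + 2.0 * j / (W - 1) for j in range(W)]   # linspace(-1, 1, W)
ys = [-1.0 + 2.0 * i / (H - 1) for i in range(H)]   # linspace(-1, 1, H)

loc_w = [xs for _ in range(H)]          # channel 0: x-coordinate of every pixel
loc_h = [[y] * W for y in ys]           # channel 1: y-coordinate of every pixel

print(loc_w)   # [[-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0]]
print(loc_h)   # [[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]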

model/conv/CondConv.py (+90, -0)

@@ -0,0 +1,90 @@
""" CondConv """
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import initializer, HeNormal, HeUniform


class Attention(nn.Cell):
    """ Attention """
    def __init__(self, in_channels, K, init_weight=True):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.net = nn.Conv2d(in_channels=in_channels, out_channels=K, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

        if init_weight:
            self.apply(self.init_weights)

    def init_weights(self, cell):
        """ initialize weights """
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
                                             cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
        elif isinstance(cell, nn.BatchNorm2d):
            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))

    def construct(self, x):
        att = self.avgpool(x)
        att = self.net(att).view(x.shape[0], -1)
        return self.sigmoid(att)


class CondConv(nn.Cell):
    """ CondConv """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
                 dilation=1, groups=1, bias=True, K=4, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.K = K
        self.attention = Attention(in_channels, K)

        self.weight = ms.Parameter(ms.ops.randn(K, out_channels, in_channels // groups, kernel_size, kernel_size),
                                   requires_grad=True)
        if bias:
            self.bias = ms.Parameter(ms.ops.randn(K, out_channels), requires_grad=True)
        else:
            self.bias = None

        if init_weight:
            self.init_weights()

    def init_weights(self):
        """ initialize weights """
        for i in range(self.K):
            self.weight[i] = initializer(HeUniform(), self.weight[i].shape, self.weight[i].dtype)

    def construct(self, x):
        B, _, H, W = x.shape
        softmax_att = self.attention(x)
        # fold the batch into the channel axis so a single grouped conv can apply
        # a different aggregated kernel to every sample
        x = x.view(1, -1, H, W)
        weight = self.weight.view(self.K, -1)
        aggregate_weight = ms.ops.mm(softmax_att, weight).view(B * self.out_channels, self.in_channels // self.groups,
                                                               self.kernel_size, self.kernel_size)
        if self.bias is not None:
            bias = self.bias.view(self.K, -1)
            aggregate_bias = ms.ops.mm(softmax_att, bias).view(-1)
            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, pad_mode="pad",
                                   padding=self.padding, dilation=self.dilation, groups=self.groups * B)
        else:
            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, pad_mode="pad",
                                   padding=self.padding, dilation=self.dilation, groups=self.groups * B)

        output = output.view(B, self.out_channels, H, W)
        return output


if __name__ == "__main__":
    in_tensor = ms.ops.randn((2, 32, 64, 64))
    cconv = CondConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
    out = cconv(in_tensor)
    print(out.shape)
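
A note on the aggregation step in CondConv.construct above: ms.ops.mm(softmax_att, weight.view(K, -1)) gives each sample its own kernel as an attention-weighted mixture of the K expert kernels, and the single conv2d call with groups = groups * B applies one such kernel per sample after the batch is folded into the channel axis. A tiny pure-Python check of the mixing with made-up numbers (2 samples, 2 experts, kernels flattened to 3 values):

# Each output row is sum_k att[i][k] * expert_k, i.e. a per-sample mixture of experts.
att = [[0.25, 0.75],
       [1.00, 0.00]]                 # per-sample routing weights, shape (B=2, K=2)
experts = [[1.0, 2.0, 3.0],
           [10.0, 20.0, 30.0]]       # flattened expert kernels, shape (K=2, 3)

mixed = [[sum(a * e for a, e in zip(row, col)) for col in zip(*experts)] for row in att]
print(mixed)   # [[7.75, 15.5, 23.25], [1.0, 2.0, 3.0]]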

model/conv/DepthwiseSeparableConvolution.py (+39, -0)

@@ -0,0 +1,39 @@
""" Depthwise Separable Convolution """
import mindspore as ms
from mindspore import nn


class DepthwiseSeparableConvolution(nn.Cell):
    """ DepthwiseSeparableConvolution """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1):
        super().__init__()

        # depthwise: one k x k filter per input channel (group=in_channels)
        self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=in_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        pad_mode='pad',
                                        padding=padding,
                                        group=in_channels)

        # pointwise: 1 x 1 convolution that mixes channels up to out_channels
        self.pointwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=1,
                                        stride=1,
                                        group=1)

    def construct(self, x):
        x = self.depthwise_conv(x)
        out = self.pointwise_conv(x)
        return out


if __name__ == '__main__':
    in_tensor = ms.ops.randn((1, 3, 224, 224), dtype=ms.float32)
    conv = DepthwiseSeparableConvolution(3, 64)
    output = conv(in_tensor)
    print(output.shape)
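
For context on why the depthwise/pointwise factorization above is worthwhile, a quick weight count for the 3 -> 64 channel, 3x3 case used in the demo (biases ignored):

# Weight-parameter count: standard 3x3 conv vs. depthwise separable, for 3 -> 64 channels.
in_ch, out_ch, k = 3, 64, 3

standard = in_ch * out_ch * k * k            # ordinary 3x3 convolution
depthwise = in_ch * k * k                    # one 3x3 filter per input channel
pointwise = in_ch * out_ch                   # 1x1 channel-mixing projection

print(standard, depthwise + pointwise)       # 1728 vs 219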
