
update ci_pipeline (#1)

* Update ci_pipeline.yaml

* update ci_pipeline

* update ci_pipeline

* update pylint

* A2Attention

* ACmixAttention

* CondConv

* DepthwiseSeparableConvolution
Huyf9 committed 2 years ago
commit e89894d48e
6 changed files with 280 additions and 16 deletions
  1. .github/pylint.conf (+2 -1)
  2. .github/workflows/ci_pipeline.yaml (+4 -4)
  3. model/attention/A2Attention.py (+18 -11)
  4. model/attention/ACmixAttention.py (+127 -0)
  5. model/conv/CondConv.py (+90 -0)
  6. model/conv/DepthwiseSeparableConvolution.py (+39 -0)

.github/pylint.conf (+2 -1)

@@ -161,7 +161,8 @@ disable=raw-checker-failed,
         too-few-public-methods,
         no-member,
         protected-access,
-        abstract-method
+        abstract-method,
+        C0103

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option

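(C0103 is pylint's invalid-name check; adding it to the disable list keeps the short tensor-style names used in the model code, such as b, c, h, w, from failing the lint job.)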

.github/workflows/ci_pipeline.yaml (+4 -4)

@@ -7,12 +7,12 @@ on:
   pull_request:
     branches: [ "main" ]
     paths:
-      - 'External-Attention-MindSpore/**'
+      - 'model/**'
       - '.github/workflows/**'
   push:
     branches: [ "main" ]
     paths:
-      - 'External-Attention-MindSpore/**'
+      - 'model/**'

 permissions:
   contents: read
@@ -38,9 +38,9 @@ jobs:
       # run: |
       #   python .github/install_mindspore.py
       #   pip install -r download.txt
-      - name: Analysing the External-Attention-MindSpore code with pylint
+      - name: Analysing the model code with pylint
         run: |
-          pylint External-Attention-MindSpore --rcfile=.github/pylint.conf
+          pylint model --rcfile=.github/pylint.conf
       # - name: Analysing the tests code with pylint
       #   run: |
       #     pylint tests --rcfile=.github/pylint.conf
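With this change the lint job targets the model directory; the same check can be reproduced locally from the repository root with pylint model --rcfile=.github/pylint.conf (assuming pylint is installed in the current environment).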

model/attention/A2Attention.py (+18 -11)

@@ -1,9 +1,16 @@
+"""
+DoubleAttention
+"""
 import mindspore as ms
 from mindspore import nn
 from mindspore.common.initializer import initializer, HeNormal, Normal


 class DoubleAttention(nn.Cell):
+    """
+    Double Attention
+    """
     def __init__(self, in_channels, c_m, c_n, reconstruct=True):
         super().__init__()
         self.in_channels = in_channels
@@ -21,6 +28,7 @@ class DoubleAttention(nn.Cell):
         self.apply(self.init_weights)

     def init_weights(self, cell):
+        """ init weight """
         if isinstance(cell, nn.Conv2d):
             cell.weight.set_data(initializer(HeNormal(mode='fan_out'), cell.weight.shape, cell.weight.dtype))
             if cell.bias is not None:
@@ -34,28 +42,27 @@ class DoubleAttention(nn.Cell):
             cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))

     def construct(self, x):
-        B, C, H, W = x.shape
-        assert C == self.in_channels
+        b, c, h, w = x.shape
+        assert c == self.in_channels
         a = self.convA(x)  # b, c_m, h, w
         b = self.convB(x)  # b, c_n, h, w
         v = self.convV(x)  # b, c_n, h, w
-        tmpA = a.view(B, self.c_m, -1)
-        attention_maps = ms.ops.softmax(b.view(B, self.c_n, -1))
-        attention_vectors = ms.ops.softmax(v.view(B, self.c_n, -1))
+        tmpA = a.view(b, self.c_m, -1)
+        attention_maps = ms.ops.softmax(b.view(b, self.c_n, -1))
+        attention_vectors = ms.ops.softmax(v.view(b, self.c_n, -1))

         global_descriptors = ms.ops.bmm(tmpA, attention_maps.permute(0, 2, 1))

         tmpZ = ms.ops.matmul(global_descriptors, attention_vectors)
-        tmpZ = tmpZ.view(B, self.c_m, H, W)
+        tmpZ = tmpZ.view(b, self.c_m, h, w)
         if self.reconstruct:
             tmpZ = self.conv_reconstruct(tmpZ)

         return tmpZ


 if __name__ == "__main__":
-    input = ms.ops.randn([12, 512, 7, 7])
+    in_tensor = ms.ops.randn([12, 512, 7, 7])
     a2 = DoubleAttention(512, 128, 128)
-    output = a2(input)
+    output = a2(in_tensor)
     print(output.shape)


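A caveat with the lowercase rename in the last hunk: b first holds the batch size but is immediately shadowed by b = self.convB(x), so the later view(b, ...) calls receive a Tensor where an int is expected. A minimal collision-free sketch of construct (the names bs and feat_b are illustrative, not from the commit):

    def construct(self, x):
        bs, c, h, w = x.shape  # keep the batch size under its own name
        assert c == self.in_channels
        a = self.convA(x)       # bs, c_m, h, w
        feat_b = self.convB(x)  # bs, c_n, h, w
        v = self.convV(x)       # bs, c_n, h, w
        tmp_a = a.view(bs, self.c_m, -1)
        attention_maps = ms.ops.softmax(feat_b.view(bs, self.c_n, -1))
        attention_vectors = ms.ops.softmax(v.view(bs, self.c_n, -1))
        global_descriptors = ms.ops.bmm(tmp_a, attention_maps.permute(0, 2, 1))
        tmp_z = ms.ops.matmul(global_descriptors, attention_vectors).view(bs, self.c_m, h, w)
        if self.reconstruct:
            tmp_z = self.conv_reconstruct(tmp_z)
        return tmp_z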
model/attention/ACmixAttention.py (+127 -0)

@@ -0,0 +1,127 @@
""" ACmix Attention """
import mindspore as ms
from mindspore import nn


def position(H, W):
    """ get position encode """
    loc_w = ms.ops.linspace(-1., 1., W).unsqueeze(0).repeat(H, axis=0)
    loc_h = ms.ops.linspace(-1., 1., H).unsqueeze(1).repeat(W, axis=1)
    loc = ms.ops.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0)
    loc = loc.reshape(-1, *loc.shape)
    # print(loc)
    return loc


def stride(x, strides):
    """ split x with strides in last two dimension """
    # B, C, H, W = x.shape
    return x[:, :, ::strides, ::strides]


def init_rate_half(tensor):
    """ fill data with 0.5 """
    if isinstance(tensor, ms.Parameter):
        fill_data = ms.ops.ones(tensor.shape) * 0.5
        tensor.set_dtype(fill_data.dtype)
        tensor.data.set_data(fill_data)
    return tensor


def init_rate_0(tensor):
    """ fill data with 0 """
    if isinstance(tensor, ms.Parameter):
        fill_data = ms.ops.zeros(tensor.shape)
        tensor.set_dtype(fill_data.dtype)
        tensor.data.set_data(fill_data)
    return tensor


class ACmix(nn.Cell):
    """ ACmix """
    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, strides=1, dilation=1):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = strides
        self.dilation = dilation
        self.rate1 = ms.Parameter(ms.Tensor(1))
        self.rate2 = ms.Parameter(ms.Tensor(1))
        self.head_dim = self.out_planes // self.head

        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)

        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = nn.ReflectionPad2d(self.padding_att)
        self.unfold = nn.Unfold(ksizes=[1, self.kernel_att, self.kernel_att, 1],
                                strides=[1, self.stride, self.stride, 1],
                                rates=[1, 1, 1, 1])
        self.softmax = nn.Softmax(axis=1)

        self.fc = nn.Conv2d(3 * self.head, self.kernel_conv ** 2, kernel_size=1)
        self.dep_conv = nn.Conv2d(self.kernel_conv ** 2 * self.head_dim, out_planes, kernel_size=self.kernel_conv,
                                  stride=strides, pad_mode='pad', padding=1, has_bias=True, group=self.head_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """ reset parameters """
        init_rate_half(self.rate1)
        init_rate_half(self.rate2)
        kernel = ms.ops.zeros((self.kernel_conv ** 2, self.kernel_conv, self.kernel_conv))
        for i in range(self.kernel_conv ** 2):
            kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
        kernel = kernel.reshape(1, *kernel.shape).repeat(self.out_planes, axis=0)
        self.dep_conv.weight = ms.Parameter(default_input=kernel, requires_grad=True)
        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)

    def construct(self, x):
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        scaling = float(self.head_dim) ** 0.5
        B, _, H, W = q.shape
        h_out, w_out = H // self.stride, W // self.stride

        pe = self.conv_p(position(H, W))

        q_att = q.view(B * self.head, self.head_dim, H, W) * scaling
        k_att = k.view(B * self.head, self.head_dim, H, W)
        v_att = v.view(B * self.head, self.head_dim, H, W)

        if self.stride > 1:
            q_att = stride(q_att, self.stride)
            q_pe = stride(pe, self.stride)
        else:
            q_pe = pe

        unfold_k = self.unfold(self.pad_att(k_att)).view(B * self.head, self.head_dim,
                                                         self.kernel_att ** 2, h_out, w_out)
        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att ** 2,
                                                        h_out, w_out)

        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)
        att = self.softmax(att)

        out_att = self.unfold(self.pad_att(v_att)).view(B * self.head, self.head_dim, self.kernel_att ** 2,
                                                        h_out, w_out)
        out_att = (att.unsqueeze(1) * out_att).sum(2).view(B, self.out_planes, h_out, w_out)

        f_all = self.fc(
            ms.ops.cat([q.view(B, self.head, self.head_dim, H * W), k.view(B, self.head, self.head_dim, H * W),
                        v.view(B, self.head, self.head_dim, H * W)], 1))
        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])

        out_conv = self.dep_conv(f_conv)
        return self.rate1 * out_att + self.rate2 * out_conv


if __name__ == "__main__":
    in_tensor = ms.ops.randn((50, 256, 7, 7))
    acmix = ACmix(in_planes=256, out_planes=256)
    out = acmix(in_tensor)
    print(out.shape)

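ACmix shares the three 1x1 projections (conv1/conv2/conv3) between a local self-attention branch (out_att) and a convolution branch (out_conv), and mixes the two with the learned scalars rate1 and rate2, both initialised to 0.5 by reset_parameters. A quick shape check, assuming the new file is importable as a package module:

    import mindspore as ms
    from model.attention.ACmixAttention import ACmix  # assumed import path

    x = ms.ops.randn((2, 64, 14, 14))
    block = ACmix(in_planes=64, out_planes=64)  # head=4 -> head_dim=16
    y = block(x)
    print(y.shape)  # (2, 64, 14, 14) with strides=1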
model/conv/CondConv.py (+90 -0)

@@ -0,0 +1,90 @@
""" CondConv """
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import initializer, HeNormal, HeUniform


class Attention(nn.Cell):
    """ Attention """
    def __init__(self, in_channels, K, init_weight=True):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.net = nn.Conv2d(in_channels=in_channels, out_channels=K, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

        if init_weight:
            self.apply(self.init_weights)

    def init_weights(self, cell):
        """ initialize weights """
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
                                             cell.weight.shape, cell.weight.dtype))
            if cell.bias is not None:
                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
        elif isinstance(cell, nn.BatchNorm2d):
            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))

    def construct(self, x):
        att = self.avgpool(x)
        att = self.net(att).view(x.shape[0], -1)
        return self.sigmoid(att)


class CondConv(nn.Cell):
    """ CondConv """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
                 dilation=1, groups=1, bias=True, K=4, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = Attention(in_channels, K)

        self.weight = ms.Parameter(ms.ops.randn(K, out_channels, in_channels // groups, kernel_size, kernel_size),
                                   requires_grad=True)
        if bias:
            self.bias = ms.Parameter(ms.ops.randn(K, out_channels), requires_grad=True)
        else:
            self.bias = None

        if init_weight:
            self.init_weights()

    def init_weights(self):
        """ initialize weights """
        for i in range(self.K):
            self.weight[i] = initializer(HeUniform(), self.weight[i].shape, self.weight[i].dtype)

    def construct(self, x):
        B, _, H, W = x.shape
        softmax_att = self.attention(x)
        x = x.view(1, -1, H, W)
        weight = self.weight.view(self.K, -1)
        aggregate_weight = ms.ops.mm(softmax_att, weight).view(B * self.out_channels, self.in_channels // self.groups,
                                                               self.kernel_size, self.kernel_size)
        if self.bias is not None:
            bias = self.bias.view(self.K, -1)
            aggregate_bias = ms.ops.mm(softmax_att, bias).view(-1)
            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, pad_mode="pad",
                                   padding=self.padding, dilation=self.dilation, groups=self.groups * B)
        else:
            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, pad_mode="pad",
                                   padding=self.padding, dilation=self.dilation, groups=self.groups * B)

        output = output.view(B, self.out_channels, H, W)
        return output


if __name__ == "__main__":
    in_tensor = ms.ops.randn((2, 32, 64, 64))
    cconv = CondConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
    out = cconv(in_tensor)
    print(out.shape)

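construct implements the usual CondConv batching trick: the K expert kernels are mixed per sample by the attention weights, the batch is folded into the channel axis, and groups=B * groups lets a single conv2d call apply each sample's own aggregated kernel. A shape walk-through for the __main__ example (B=2, K=4, groups=1):

    # x                   : (2, 32, 64, 64)
    # softmax_att         : (2, 4)            one mixing weight per expert and sample
    # weight.view(K, -1)  : (4, 64*32*3*3)    each expert kernel flattened
    # aggregate_weight    : (128, 32, 3, 3)   per-sample kernels, 2*64 output channels
    # x.view(1, -1, H, W) : (1, 64, 64, 64)   batch folded into channels (2*32)
    # conv2d, groups=2    : (1, 128, 64, 64)  each sample convolved with its own kernels
    # output.view         : (2, 64, 64, 64)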
model/conv/DepthwiseSeparableConvolution.py (+39 -0)

@@ -0,0 +1,39 @@
""" Depthwise and Separable Convolution """
import mindspore as ms
from mindspore import nn


class DepthwiseSeparableConvolution(nn.Cell):
    """ DepthwiseSeparableConvolution """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1):
        super().__init__()

        self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=in_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        pad_mode='pad',
                                        padding=padding)

        self.pointwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=1,
                                        stride=1,
                                        group=1)

    def construct(self, x):
        x = self.depthwise_conv(x)
        out = self.pointwise_conv(x)
        return out


if __name__ == '__main__':
    in_tensor = ms.ops.randn((1, 3, 224, 224), dtype=ms.float32)
    conv = DepthwiseSeparableConvolution(3, 64)
    output = conv(in_tensor)
    print(output.shape)

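As committed, depthwise_conv is a regular dense convolution: without group=in_channels, every output channel still mixes all input channels, so the block is separable in name only. The usual depthwise definition adds one argument (a sketch, same constructor parameters otherwise):

    self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
                                    out_channels=in_channels,
                                    kernel_size=kernel_size,
                                    stride=stride,
                                    pad_mode='pad',
                                    padding=padding,
                                    group=in_channels)  # one k x k filter per channel

With group=in_channels the pair costs roughly k*k*C + C*C_out weights instead of k*k*C*C_out for a dense convolution, which is the point of the depthwise-separable factorisation.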