* Update ci_pipeline.yaml
* update ci_pipeline
* update ci_pipeline
* update pylint
* A2Attention
* ACmixAttention
* CondConv
* DepthwiseSeparableConvolutionv1
@@ -161,7 +161,8 @@ disable=raw-checker-failed,
         too-few-public-methods,
         no-member,
         protected-access,
-        abstract-method
+        abstract-method,
+        C0103

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
@@ -7,12 +7,12 @@ on:
   pull_request:
     branches: [ "main" ]
     paths:
-      - 'External-Attention-MindSpore/**'
+      - 'model/**'
       - '.github/workflows/**'
   push:
     branches: [ "main" ]
     paths:
-      - 'External-Attention-MindSpore/**'
+      - 'model/**'

 permissions:
   contents: read
@@ -38,9 +38,9 @@ jobs:
      #   run: |
      #     python .github/install_mindspore.py
      #     pip install -r download.txt
-      - name: Analysing the External-Attention-MindSpore code with pylint
+      - name: Analysing the model code with pylint
        run: |
-          pylint External-Attention-MindSpore --rcfile=.github/pylint.conf
+          pylint model --rcfile=.github/pylint.conf
      # - name: Analysing the tests code with pylint
      #   run: |
      #     pylint tests --rcfile=.github/pylint.conf
@@ -1,9 +1,16 @@
+"""
+DoubleAttention
+"""
 import mindspore as ms
 from mindspore import nn
 from mindspore.common.initializer import initializer, HeNormal, Normal


 class DoubleAttention(nn.Cell):
+    """
+    Double Attention
+    """
     def __init__(self, in_channels, c_m, c_n, reconstruct=True):
         super().__init__()
         self.in_channels = in_channels
@@ -21,6 +28,7 @@ class DoubleAttention(nn.Cell):
         self.apply(self.init_weights)

     def init_weights(self, cell):
+        """ init weight """
         if isinstance(cell, nn.Conv2d):
             cell.weight.set_data(initializer(HeNormal(mode='fan_out'), cell.weight.shape, cell.weight.dtype))
             if cell.bias is not None:
@@ -34,28 +42,27 @@ class DoubleAttention(nn.Cell):
             cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))

     def construct(self, x):
-        B, C, H, W = x.shape
-        assert C == self.in_channels
+        # use `bs` for the batch size so it is not shadowed by the convB output below
+        bs, c, h, w = x.shape
+        assert c == self.in_channels
         a = self.convA(x)  # b, c_m, h, w
         b = self.convB(x)  # b, c_n, h, w
         v = self.convV(x)  # b, c_n, h, w
-        tmpA = a.view(B, self.c_m, -1)
-        attention_maps = ms.ops.softmax(b.view(B, self.c_n, -1))
-        attention_vectors = ms.ops.softmax(v.view(B, self.c_n, -1))
+        tmpA = a.view(bs, self.c_m, -1)
+        attention_maps = ms.ops.softmax(b.view(bs, self.c_n, -1))
+        attention_vectors = ms.ops.softmax(v.view(bs, self.c_n, -1))
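+        # gather step: aggregate the c_m-dim features into c_n global descriptors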
         global_descriptors = ms.ops.bmm(tmpA, attention_maps.permute(0, 2, 1))
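+        # distribute step: every position reads back its own mixture of the descriptors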
         tmpZ = ms.ops.matmul(global_descriptors, attention_vectors)
-        tmpZ = tmpZ.view(B, self.c_m, H, W)
+        tmpZ = tmpZ.view(bs, self.c_m, h, w)
         if self.reconstruct:
             tmpZ = self.conv_reconstruct(tmpZ)
         return tmpZ


 if __name__ == "__main__":
-    input = ms.ops.randn([12, 512, 7, 7])
+    in_tensor = ms.ops.randn([12, 512, 7, 7])
     a2 = DoubleAttention(512, 128, 128)
-    output = a2(input)
+    output = a2(in_tensor)
     print(output.shape)
@@ -0,0 +1,127 @@
+""" ACmix Attention """
+import mindspore as ms
+from mindspore import nn
+
+
+def position(H, W):
+    """ get position encode """
+    loc_w = ms.ops.linspace(-1., 1., W).unsqueeze(0).repeat(H, axis=0)
+    loc_h = ms.ops.linspace(-1., 1., H).unsqueeze(1).repeat(W, axis=1)
+    loc = ms.ops.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0)
+    loc = loc.reshape(-1, *loc.shape)
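+    # shape (1, 2, H, W): a normalized coordinate grid over the feature map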
+    return loc
+
+
+def stride(x, strides):
+    """ subsample x with the given stride over the last two dimensions """
+    return x[:, :, ::strides, ::strides]
+
+
+def init_rate_half(tensor):
+    """ fill data with 0.5 """
+    if isinstance(tensor, ms.Parameter):
+        fill_data = ms.ops.ones(tensor.shape) * 0.5
+        tensor.set_dtype(fill_data.dtype)
+        tensor.data.set_data(fill_data)
+    return tensor
+
+
+def init_rate_0(tensor):
+    """ fill data with 0 """
+    if isinstance(tensor, ms.Parameter):
+        fill_data = ms.ops.zeros(tensor.shape)
+        tensor.set_dtype(fill_data.dtype)
+        tensor.data.set_data(fill_data)
+    return tensor
+
+
+class ACmix(nn.Cell):
+    """ ACmix """
+    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, strides=1, dilation=1):
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.head = head
+        self.kernel_att = kernel_att
+        self.kernel_conv = kernel_conv
+        self.stride = strides
+        self.dilation = dilation
+        self.rate1 = ms.Parameter(ms.Tensor(1))
+        self.rate2 = ms.Parameter(ms.Tensor(1))
+        self.head_dim = self.out_planes // self.head
+        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
+        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
+        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
+        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)
+        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
+        self.pad_att = nn.ReflectionPad2d(self.padding_att)
+        self.unfold = nn.Unfold(ksizes=[1, self.kernel_att, self.kernel_att, 1],
+                                strides=[1, self.stride, self.stride, 1],
+                                rates=[1, 1, 1, 1])
+        self.softmax = nn.Softmax(axis=1)
+        self.fc = nn.Conv2d(3 * self.head, self.kernel_conv ** 2, kernel_size=1)
+        self.dep_conv = nn.Conv2d(self.kernel_conv ** 2 * self.head_dim, out_planes, kernel_size=self.kernel_conv,
+                                  stride=strides, pad_mode='pad', padding=1, has_bias=True, group=self.head_dim)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        """ reset parameters """
+        init_rate_half(self.rate1)
+        init_rate_half(self.rate2)
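+        # one-hot "shift" kernels: kernel i selects position (i // k, i % k)
+        # of a kernel_conv x kernel_conv window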
+        kernel = ms.ops.zeros((self.kernel_conv ** 2, self.kernel_conv, self.kernel_conv))
+        for i in range(self.kernel_conv ** 2):
+            kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
+        kernel = kernel.reshape(1, *kernel.shape).repeat(self.out_planes, axis=0)
+        self.dep_conv.weight = ms.Parameter(default_input=kernel, requires_grad=True)
+        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)
+
+    def construct(self, x):
+        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
+        scaling = float(self.head_dim) ** 0.5
+        B, _, H, W = q.shape
+        h_out, w_out = H // self.stride, W // self.stride
+        pe = self.conv_p(position(H, W))
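+        # pe: the coordinate grid projected to head_dim channels, shape (1, head_dim, H, W)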
+        q_att = q.view(B * self.head, self.head_dim, H, W) * scaling
+        k_att = k.view(B * self.head, self.head_dim, H, W)
+        v_att = v.view(B * self.head, self.head_dim, H, W)
+        if self.stride > 1:
+            q_att = stride(q_att, self.stride)
+            q_pe = stride(pe, self.stride)
+        else:
+            q_pe = pe
+        unfold_k = self.unfold(self.pad_att(k_att)).view(B * self.head, self.head_dim,
+                                                         self.kernel_att ** 2, h_out, w_out)
+        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att ** 2,
+                                                        h_out, w_out)
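+        # attention branch: each query attends over its kernel_att x kernel_att
+        # neighbourhood, with relative position encoded via q_pe - unfold_rpe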
+        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(1)
+        att = self.softmax(att)
+        out_att = self.unfold(self.pad_att(v_att)).view(B * self.head, self.head_dim, self.kernel_att ** 2,
+                                                        h_out, w_out)
+        out_att = (att.unsqueeze(1) * out_att).sum(2).view(B, self.out_planes, h_out, w_out)
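+        # convolution branch: a 1x1 conv mixes the q/k/v maps into kernel_conv^2
+        # feature groups, which dep_conv (initialized as shift kernels) aggregates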
+        f_all = self.fc(
+            ms.ops.cat([q.view(B, self.head, self.head_dim, H * W), k.view(B, self.head, self.head_dim, H * W),
+                        v.view(B, self.head, self.head_dim, H * W)], 1))
+        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])
+        out_conv = self.dep_conv(f_conv)
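+        # the two branches are mixed by the learned scalar gates rate1 and rate2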
+        return self.rate1 * out_att + self.rate2 * out_conv
+
+
+if __name__ == "__main__":
+    in_tensor = ms.ops.randn((50, 256, 7, 7))
+    acmix = ACmix(in_planes=256, out_planes=256)
+    out = acmix(in_tensor)
+    print(out.shape)
@@ -0,0 +1,90 @@
+""" CondConv """
+import mindspore as ms
+from mindspore import nn
+from mindspore.common.initializer import initializer, HeNormal, HeUniform
+
+
+class Attention(nn.Cell):
+    """ Attention """
+    def __init__(self, in_channels, K, init_weight=True):
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.net = nn.Conv2d(in_channels=in_channels, out_channels=K, kernel_size=1)
+        self.sigmoid = nn.Sigmoid()
+        if init_weight:
+            self.apply(self.init_weights)
+
+    def init_weights(self, cell):
+        """ initialize weights """
+        if isinstance(cell, nn.Conv2d):
+            cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
+                                             cell.weight.shape, cell.weight.dtype))
+            if cell.bias is not None:
+                cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
+        elif isinstance(cell, nn.BatchNorm2d):
+            cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
+            cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))
+
+    def construct(self, x):
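+        # global average pool -> 1x1 conv -> sigmoid gives per-sample routing
+        # weights over the K expert kernels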
+        att = self.avgpool(x)
+        att = self.net(att).view(x.shape[0], -1)
+        return self.sigmoid(att)
+
+
+class CondConv(nn.Cell):
+    """ CondConv """
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0,
+                 dilation=1, groups=1, bias=True, K=4, init_weight=True):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.K = K
+        self.attention = Attention(in_channels, K)
+        self.weight = ms.Parameter(ms.ops.randn(K, out_channels, in_channels // groups, kernel_size, kernel_size),
+                                   requires_grad=True)
+        if bias:
+            self.bias = ms.Parameter(ms.ops.randn(K, out_channels), requires_grad=True)
+        else:
+            self.bias = None
+        if init_weight:
+            self.init_weights()
+
+    def init_weights(self):
+        """ initialize weights """
+        for i in range(self.K):
+            self.weight[i] = initializer(HeUniform(), self.weight[i].shape, self.weight[i].dtype)
+
+    def construct(self, x):
+        B, _, H, W = x.shape
+        softmax_att = self.attention(x)
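+        # fold the batch into the channel axis so a single grouped conv applies a
+        # different per-sample mixture of the K kernels (groups = B * self.groups)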
+        x = x.view(1, -1, H, W)
+        weight = self.weight.view(self.K, -1)
+        aggregate_weight = ms.ops.mm(softmax_att, weight).view(B * self.out_channels, self.in_channels // self.groups,
+                                                               self.kernel_size, self.kernel_size)
+        if self.bias is not None:
+            bias = self.bias.view(self.K, -1)
+            aggregate_bias = ms.ops.mm(softmax_att, bias).view(-1)
+            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, pad_mode="pad",
+                                   padding=self.padding, dilation=self.dilation, groups=self.groups * B)
+        else:
+            output = ms.ops.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, pad_mode="pad",
+                                   padding=self.padding, dilation=self.dilation, groups=self.groups * B)
+        output = output.view(B, self.out_channels, H, W)
+        return output
+
+
+if __name__ == "__main__":
+    in_tensor = ms.ops.randn((2, 32, 64, 64))
+    cconv = CondConv(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
+    out = cconv(in_tensor)
+    print(out.shape)
@@ -0,0 +1,39 @@
+""" Depthwise and Separable Convolution """
+import mindspore as ms
+from mindspore import nn
+
+
+class DepthwiseSeparableConvolution(nn.Cell):
+    """ DepthwiseSeparableConvolution """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1):
+        super().__init__()
+        # depthwise: group=in_channels so each input channel gets its own filter
+        self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
+                                        out_channels=in_channels,
+                                        kernel_size=kernel_size,
+                                        stride=stride,
+                                        pad_mode='pad',
+                                        padding=padding,
+                                        group=in_channels)
+        # pointwise: a 1x1 convolution mixes information across channels
+        self.pointwise_conv = nn.Conv2d(in_channels=in_channels,
+                                        out_channels=out_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        group=1)
+
+    def construct(self, x):
+        x = self.depthwise_conv(x)
+        out = self.pointwise_conv(x)
+        return out
+
+
+if __name__ == '__main__':
+    in_tensor = ms.ops.randn((1, 3, 224, 224), dtype=ms.float32)
+    conv = DepthwiseSeparableConvolution(3, 64)
+    output = conv(in_tensor)
+    print(output.shape)