"""
MindSpore implementation of 'CoTAttention'
Refer to "Contextual Transformer Networks for Visual Recognition"
"""
import mindspore as ms
from mindspore import nn


class CoTAttention(nn.Cell):
    """Contextual Transformer (CoT) attention block.

    Combines a static context (grouped k x k convolution over the keys)
    with a dynamic context (per-position attention weights predicted from
    the concatenated keys and input), and returns their sum.
    """

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size

        # Static context branch: grouped k x k conv -> BN -> ReLU.
        self.key_embed = nn.SequentialCell(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, pad_mode='pad',
                      padding=kernel_size // 2, group=4),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
        )

        # Value branch: pointwise projection -> BN.
        self.value_embed = nn.SequentialCell(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
        )

        # Attention head: squeeze the concatenated (key, input) channels by
        # `factor`, then expand to kernel_size^2 weights per channel/position.
        factor = 4
        self.attention_embed = nn.SequentialCell(
            nn.Conv2d(2 * dim, 2 * dim // factor, kernel_size=1),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim,
                      kernel_size=1),
        )

    def construct(self, x):
        """Apply CoT attention to `x` of shape (B, C, H, W); output has the
        same shape."""
        batch, channels, height, width = x.shape

        static_ctx = self.key_embed(x)
        values = self.value_embed(x).view(batch, channels, -1)

        joint = ms.ops.cat([static_ctx, x], axis=1)          # (B, 2C, H, W)
        weights = self.attention_embed(joint)
        weights = weights.reshape(batch, channels,
                                  self.kernel_size * self.kernel_size,
                                  height, width)
        # Average the k*k candidate weights, then softmax over the flattened
        # spatial axis before modulating the values.
        weights = weights.mean(2, keep_dims=False).view(batch, channels, -1)
        dynamic_ctx = ms.ops.softmax(weights, axis=-1) * values
        dynamic_ctx = dynamic_ctx.view(batch, channels, height, width)

        return static_ctx + dynamic_ctx


if __name__ == "__main__":
    dummy_input = ms.ops.randn(50, 512, 7, 7)
    cot = CoTAttention()
    output = cot(dummy_input)
    print(output.shape)