|
|
|
@@ -1,30 +0,0 @@ |
|
|
|
""" |
|
|
|
MindSpore implementation of 'sMLP' |
|
|
|
Refer to "Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?" |
|
|
|
""" |
|
|
|
import mindspore as ms |
|
|
|
from mindspore import nn |
|
|
|
|
|
|
|
|
|
|
|
class sMLPBlock(nn.Cell): |
|
|
|
""" sMLPBlock """ |
|
|
|
def __init__(self, h=224, w=224, c=3): |
|
|
|
super().__init__() |
|
|
|
self.proj_h = nn.Dense(h, h) |
|
|
|
self.proj_w = nn.Dense(w, w) |
|
|
|
self.fuse = nn.Dense(3 * c, c) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x_h = self.proj_h(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2) |
|
|
|
x_w = self.proj_w(x) |
|
|
|
x_id = x |
|
|
|
x_fuse = ms.ops.cat([x_h, x_w, x_id], axis=1) |
|
|
|
out = self.fuse(x_fuse.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
dummy_input = ms.ops.randn((50, 3, 224, 224)) |
|
|
|
smlp = sMLPBlock(h=224, w=224) |
|
|
|
output = smlp(dummy_input) |
|
|
|
print(output.shape) |