Browse Source

resmlp

pull/4/head
Huyf9 2 years ago
parent
commit
7794e80cc8
1 changed files with 93 additions and 0 deletions
  1. +93
    -0
      model/mlp/resmlp.py

+ 93
- 0
model/mlp/resmlp.py View File

@@ -0,0 +1,93 @@
"""
MindSpore implementation of 'resmlp'
Refer to "ResMLP: Feedforward networks for image classification with data-efficient training"
"""
import mindspore as ms
from mindspore import nn


class Rearrange(nn.Cell):
    """Split a feature map into non-overlapping patches and flatten each patch.

    Maps (B, C, H, W) -> (B, num_patches, patch_size * patch_size * C),
    i.e. the einops pattern 'b c (nh h) (nw w) -> b (nh nw) (h w c)'.

    Args:
        image_size (int): spatial size of the (square) input feature map.
        patch_size (int): side length of each square patch; must divide image_size.
    """
    def __init__(self, image_size=14, patch_size=7):
        super().__init__()
        self.h = patch_size
        self.w = patch_size
        self.nw = image_size // patch_size
        self.nh = image_size // patch_size

    def construct(self, x):
        B, C, _, _ = x.shape
        # Split H into (nh, h) and W into (nw, w). The grid indices (nh, nw)
        # must come BEFORE the within-patch indices (h, w) so that each patch
        # is a contiguous spatial tile — the original
        # reshape(B, C, h, nh, w, nw) interleaved pixels across patches.
        out = x.reshape(B, C, self.nh, self.h, self.nw, self.w)
        # (B, C, nh, h, nw, w) -> (B, nh, nw, h, w, C)
        out = out.permute(0, 2, 4, 3, 5, 1)
        # Flatten the patch grid into the token axis and each tile into a vector.
        out = out.view(B, self.nh * self.nw, -1)
        return out


class Affine(nn.Cell):
    """Learnable per-channel affine transform: y = x * g + b.

    Used in ResMLP in place of normalization layers. `g` is initialized to
    ones and `b` to zeros, so the layer starts as the identity.

    Args:
        channels (int): size of the trailing channel dimension of the input.
    """
    def __init__(self, channels):
        super().__init__()
        self.g = ms.Parameter(ms.ops.ones((1, 1, channels)))
        self.b = ms.Parameter(ms.ops.zeros((1, 1, channels)))

    def construct(self, x):
        # BUG FIX: the bias must be ADDED, not multiplied; the original
        # `x * self.g * self.b` returns all zeros at initialization.
        return x * self.g + self.b


class PreAffinePostLayerScale(nn.Cell):
    """Residual wrapper: x + scale * fn(Affine(x)).

    Applies a learnable affine transform before `fn` and a learnable
    per-channel layer-scale after it, then adds the residual input.

    Args:
        dim (int): channel dimension of the input tokens.
        depth (int): depth index of this block; selects the layer-scale init.
        fn (nn.Cell): the wrapped sub-network (token- or channel-mixing).
    """
    def __init__(self, dim, depth, fn):
        super().__init__()
        # Depth-dependent layer-scale initialization (as in the ResMLP paper):
        # shallow nets start near 0.1, deeper ones much closer to zero.
        if depth <= 18:
            eps = 0.1
        elif depth <= 24:
            eps = 1e-5
        else:
            eps = 1e-6

        self.scale = ms.Parameter(ms.ops.fill(ms.float32, (1, 1, dim), eps))
        self.affine = Affine(dim)
        self.fn = fn

    def construct(self, x):
        out = self.affine(x)
        out = self.fn(out)
        return x + out * self.scale


class ResMLP(nn.Cell):
    """ResMLP image classifier.

    Pipeline: patchify -> linear embedding -> `depth` blocks of
    (token-mixing 1x1 conv, channel-mixing MLP), each in a
    PreAffinePostLayerScale residual wrapper -> affine -> mean pool ->
    linear classifier -> softmax.

    Args:
        dim (int): token embedding dimension.
        image_size (int): spatial size of the (square) RGB input.
        patch_size (int): patch side length; must divide image_size.
        expansion_factor (int): hidden expansion of the channel MLP.
        depth (int): number of mixing blocks.
        class_num (int): number of output classes.
    """
    def __init__(self, dim=128, image_size=14, patch_size=7, expansion_factor=4, depth=4, class_num=1000):
        super().__init__()
        self.flatten = Rearrange(image_size, patch_size)
        num_patches = (image_size // patch_size) ** 2

        def wrapper(idx, fn):
            return PreAffinePostLayerScale(dim, idx + 1, fn)

        # Each patch is patch_size*patch_size pixels with 3 (RGB) channels.
        self.embedding = nn.Dense((patch_size ** 2) * 3, dim)
        self.mlp = nn.SequentialCell()
        for i in range(depth):
            # Token-mixing: 1x1 conv mixing across the patch (token) axis.
            # BUG FIX: the conv operates on num_patches channels, not
            # patch_size ** 2 (input after embedding is (B, num_patches, dim)).
            self.mlp.insert_child_to_cell(
                f'fc1_{i}', wrapper(i, nn.Conv1d(num_patches, num_patches, 1)))
            # Channel-mixing MLP.
            # BUG FIX: the original reused the key f'fc1_{i}' here, which
            # silently replaced the token-mixing conv inserted just above.
            self.mlp.insert_child_to_cell(f'fc2_{i}', wrapper(i, nn.SequentialCell(
                nn.Dense(dim, dim * expansion_factor),
                nn.GELU(),
                nn.Dense(dim * expansion_factor, dim)
            )))
        self.aff = Affine(dim)
        self.classifier = nn.Dense(dim, class_num)
        self.softmax = nn.Softmax(1)

    def construct(self, x):
        y = self.flatten(x)          # (B, num_patches, patch^2 * 3)
        y = self.embedding(y)        # (B, num_patches, dim)
        y = self.mlp(y)              # mixing blocks
        y = self.aff(y)
        y = y.mean(axis=1)           # global average pool over tokens
        return self.softmax(self.classifier(y))


if __name__ == "__main__":
dummy_input = ms.ops.randn((50, 3, 14, 14))
resmlp = ResMLP(dim=128, image_size=14, patch_size=7, class_num=1000)
output = resmlp(dummy_input)
print(output.shape)

Loading…
Cancel
Save