Browse Source

add HorNet

pull/2/head
Huyf9 2 years ago
parent
commit
5ed00832b4
1 changed files with 400 additions and 0 deletions
  1. +400
    -0
      model/conv/HorNet.py

+ 400
- 0
model/conv/HorNet.py View File

@@ -0,0 +1,400 @@
""" HorNet """
from functools import partial
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import initializer, TruncatedNormal, XavierUniform


def trunc_norm(tensor, std=.02):
""" truncated normalization
"""
return initializer(TruncatedNormal(sigma=std), tensor.shape, tensor.dtype)


def get_dwconv(dim, kernel, bias):
""" get dwconv
"""
return nn.Conv2d(dim, dim, kernel_size=kernel, pad_mode='pad', padding=(kernel - 1) // 2, has_bias=bias, group=dim)


class DropPath(nn.Cell):
""" Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""

def __init__(self, drop_prob, ndim=2):
super().__init__()
self.drop = nn.Dropout(keep_prob=1 - drop_prob)
shape = (1,) + (1,) * (ndim + 1)
self.ndim = ndim
self.mask = ms.Tensor(ms.ops.ones(shape), dtype=ms.float32)

def construct(self, x):
if not self.training:
return x
mask = ms.ops.Tile()(self.mask, (x.shape[0],) + (1,) * (self.ndim + 1))
out = self.drop(mask)
out = out * x
return out


class GlobalLocalFilter(nn.Cell):
""" GlobalLocalFilter
https://arxiv.org/abs/2207.14284
"""
def __init__(self, dim, h=14, w=8):
super().__init__()
self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, pad_mode='pad', padding=1, group=dim // 2)
self.complex_weight = ms.Parameter(ms.ops.randn((dim // 2, h, w, 2), dtype=ms.float32) * 0.02)
self.complex_weight = trunc_norm(self.complex_weight)
self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')

self.rfft = ms.ops.FFTWithSize(signal_ndim=2, inverse=False, real=True, norm='ortho', onesided=True,
signal_sizes=(2, 3))
self.irfft = ms.ops.FFTWithSize(signal_ndim=2, inverse=True, real=True, norm='ortho', onesided=True,
signal_sizes=(2, 3))

def construct(self, x):
x = self.pre_norm(x)
x1, x2 = ms.ops.chunk(x, 2, axis=1)
x1 = self.dw(x1)

x2 = x2.to(ms.float32)
B, C, H, W = x2.shape
x2 = self.rfft(x2)

weight = self.complex_weight
if not weight.shape[1:3] == x2.shape[2:4]:
weight = ms.ops.interpolate(weight.permute(3, 0, 1, 2), size=x2.shape[2:4], mode='bilinear',
align_corners=True).permute(1, 2, 3, 0)

weight = weight.to(ms.complex64)
x2 *= weight
x2 = self.irfft(x2)

x = ms.ops.cat([x1.unsqueeze(2), x2.unsqueeze(2)], axis=2).reshape(B, 2 * C, H, W)
x = self.post_norm(x)
return x


class Gnconv(nn.Cell):
"""gnconv
"""
def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.):
super().__init__()
self.order = order
self.dims = [dim // 2 ** i for i in range(order)]
self.dims.reverse()
self.proj_in = nn.Conv2d(dim, 2 * dim, 1)

if gflayer is None:
self.dwconv = get_dwconv(sum(self.dims), 7, True)
else:
self.dwconv = gflayer(sum(self.dims), h=h, w=w)

self.proj_out = nn.Conv2d(dim, dim, 1)

self.pws = nn.SequentialCell(
[nn.Conv2d(self.dims[i], self.dims[i + 1], 1) for i in range(order - 1)]
)
self.scale = s

def construct(self, x):
fused_x = self.proj_in(x)
pwa, abc = ms.ops.split(fused_x, (self.dims[0], sum(self.dims)), axis=1)

dw_abc = self.dwconv(abc) * self.scale

dw_list = ms.ops.split(dw_abc, self.dims, axis=1)
x = pwa * dw_list[0]

for i in range(self.order - 1):
x = self.pws[i](x) * dw_list[i + 1]

return self.proj_out(x)


class Block(nn.Cell):
"""HorNet's Block
"""
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, gnconv=Gnconv):
super().__init__()
self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
self.gnconv = gnconv(dim)
self.norm2 = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Dense(dim, 4 * dim)
self.act = nn.GELU()
self.pwconv2 = nn.Dense(4 * dim, dim)

self.gamma1 = ms.Parameter(layer_scale_init_value * ms.ops.ones(dim),
requires_grad=True) if layer_scale_init_value > 0 else None
self.gamma2 = ms.Parameter(layer_scale_init_value * ms.ops.ones((dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def construct(self, x):
_, C, _, _ = x.shape
if self.gamma1 is not None:
gamma1 = self.gamma1.view(C, 1, 1)
else:
gamma1 = 1
x += self.drop_path(gamma1 * self.gnconv(self.norm1(x)))

inputs = x
x = x.permute(0, 2, 3, 1)
x = self.norm2(x)

x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)

if self.gamma2 is not None:
x *= self.gamma2
x = x.permute(0, 3, 1, 2)
x = inputs + self.drop_path(x)
return x


class HorNet(nn.Cell):
"""HorNet
"""
def __init__(self, depths, in_chans=3, num_classes=1000, base_dim=96, drop_path_rate=0., layer_scale_init_value=1e-6,
head_init_scale=1., gnconv=Gnconv, block=Block, uniform_init=False):
super().__init__()
dims = [base_dim, base_dim * 2, base_dim * 4, base_dim * 8]
self.downsample_layers = nn.SequentialCell()
stem = nn.SequentialCell(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format='channels_first')
)
self.downsample_layers.append(stem)
for i in range(3):
dowmsample_layer = nn.SequentialCell(
LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2)
)
self.downsample_layers.append(dowmsample_layer)

self.stages = nn.SequentialCell()
dp_rates = list(x for x in ms.ops.linspace(0, drop_path_rate, sum(depths)))

if not isinstance(gnconv, list):
gnconv = [gnconv, gnconv, gnconv, gnconv]
assert len(gnconv) == 4

cur = 0
for i in range(4):
stage = nn.SequentialCell(
*[block(dim=dims[i], drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_value,
gnconv=gnconv[i]) for j in range(depths[i])]
)
self.stages.append(stage)
cur += depths[i]

self.norm = nn.LayerNorm((dims[-1],), epsilon=1e-6)
self.head = nn.Dense(dims[-1], num_classes)
self.uniform_init = uniform_init

weight = self.head.weight * head_init_scale
self.head.weight.set_data(weight)

if self.head.bias is not None:
bias = self.head.bias * head_init_scale
self.head.bias.set_data(bias)

self.apply(self.init_weight)

def init_weight(self, cell):
""" init weight
"""
if not self.uniform_init:
if isinstance(cell, (nn.Conv2d, nn.Dense)):
cell.weight.set_data(trunc_norm(cell.weight))
if cell.bias is not None:
cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
else:
if isinstance(cell, (nn.Conv2d, nn.Dense)):
cell.weight.set_data(initializer(XavierUniform(), cell.weight.shape, cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))

def construct_features(self, x):
"""construct features
"""
for i in range(4):
x = self.downsample_layers[i](x)
for blk in self.stages[i]:
x = blk(x)
return self.norm(x.mean([-2, -1]))

def construct(self, x):
x = self.construct_features(x)
return self.head(x)


class LayerNorm(nn.Cell):
"""LayerNorm
"""
def __init__(self, normalized_shape, eps=1e-6, data_format='channels_last'):
super().__init__()
self.weight = ms.Parameter(ms.ops.ones(normalized_shape))
self.bias = ms.Parameter(ms.ops.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ['channels_first', 'channels_last']:
raise NotImplementedError

self.normalized_shape = (normalized_shape,)
self.layer_norm = nn.LayerNorm(self.normalized_shape, epsilon=self.eps)

def construct(self, x):
if self.data_format == 'channels_last':
self.layer_norm.gamma.set_data(self.weight)
self.layer_norm.beta.set_data(self.bias)
x = self.layer_norm(x)
if self.data_format == 'channels_first':
u = x.mean(1, keep_dims=True)
s = (x - u).pow(2).mean(1, keep_dims=True)
x = (x - u) / ms.ops.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x


def hornet_tiny_7x7(**kwargs):
"""hornet_tiny_7x7
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s),
partial(Gnconv, order=5, s=s)
], **kwargs)
return model


def hornet_tiny_gf(**kwargs):
"""hornet_tiny_gf
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter),
partial(Gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter)
], **kwargs)
return model


def hornet_small_7x7(**kwargs):
"""hornet_small_7x7
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=96, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s),
partial(Gnconv, order=5, s=s)
], **kwargs)
return model


def hornet_small_gf(**kwargs):
"""hornet_small_gf
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=96, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter),
partial(Gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter)
], **kwargs)
return model


def hornet_base_7x7(**kwargs):
"""hornet_base_7x7
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=128, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s),
partial(Gnconv, order=5, s=s)
], **kwargs)
return model


def hornet_base_gf(**kwargs):
"""hornet_base_gf
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=128, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter),
partial(Gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter)
], **kwargs)
return model


def hornet_large_7x7(**kwargs):
"""hornet_large_7x7
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s),
partial(Gnconv, order=5, s=s)
], **kwargs)
return model


def hornet_large_gf(**kwargs):
"""hornet_large_gf
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter),
partial(Gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter)
], **kwargs)
return model


def hornet_large_gf_img384(**kwargs):
"""hornet_large_gf_img384
"""
s = 1 / 3.
model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block,
gnconv=[
partial(Gnconv, order=2, s=s),
partial(Gnconv, order=3, s=s),
partial(Gnconv, order=4, s=s, h=24, w=13, gflayer=GlobalLocalFilter),
partial(Gnconv, order=5, s=s, h=12, w=7, gflayer=GlobalLocalFilter)
], **kwargs)
return model


if __name__ == "__main__":
in_tensor = ms.ops.randn((1, 3, 224, 224))
hornet = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block,
gnconv=[
partial(Gnconv, order=2, s=1/3),
partial(Gnconv, order=3, s=1/3),
partial(Gnconv, order=4, s=1/3),
partial(Gnconv, order=5, s=1/3)
])
output = hornet(in_tensor)
print(output.shape)

Loading…
Cancel
Save