diff --git a/model/conv/HorNet.py b/model/conv/HorNet.py new file mode 100644 index 0000000..8bf4ca7 --- /dev/null +++ b/model/conv/HorNet.py @@ -0,0 +1,400 @@ +""" HorNet """ +from functools import partial +import mindspore as ms +from mindspore import nn +from mindspore.common.initializer import initializer, TruncatedNormal, XavierUniform + + +def trunc_norm(tensor, std=.02): + """ truncated normalization + """ + return initializer(TruncatedNormal(sigma=std), tensor.shape, tensor.dtype) + + +def get_dwconv(dim, kernel, bias): + """ get dwconv + """ + return nn.Conv2d(dim, dim, kernel_size=kernel, pad_mode='pad', padding=(kernel - 1) // 2, has_bias=bias, group=dim) + + +class DropPath(nn.Cell): + """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob, ndim=2): + super().__init__() + self.drop = nn.Dropout(keep_prob=1 - drop_prob) + shape = (1,) + (1,) * (ndim + 1) + self.ndim = ndim + self.mask = ms.Tensor(ms.ops.ones(shape), dtype=ms.float32) + + def construct(self, x): + if not self.training: + return x + mask = ms.ops.Tile()(self.mask, (x.shape[0],) + (1,) * (self.ndim + 1)) + out = self.drop(mask) + out = out * x + return out + + +class GlobalLocalFilter(nn.Cell): + """ GlobalLocalFilter + https://arxiv.org/abs/2207.14284 + """ + def __init__(self, dim, h=14, w=8): + super().__init__() + self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, pad_mode='pad', padding=1, group=dim // 2) + self.complex_weight = ms.Parameter(ms.ops.randn((dim // 2, h, w, 2), dtype=ms.float32) * 0.02) + self.complex_weight = trunc_norm(self.complex_weight) + self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first') + self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first') + + self.rfft = ms.ops.FFTWithSize(signal_ndim=2, inverse=False, real=True, norm='ortho', onesided=True, + signal_sizes=(2, 3)) + self.irfft = ms.ops.FFTWithSize(signal_ndim=2, inverse=True, real=True, norm='ortho', onesided=True, + signal_sizes=(2, 3)) + + def construct(self, x): + x = self.pre_norm(x) + x1, x2 = ms.ops.chunk(x, 2, axis=1) + x1 = self.dw(x1) + + x2 = x2.to(ms.float32) + B, C, H, W = x2.shape + x2 = self.rfft(x2) + + weight = self.complex_weight + if not weight.shape[1:3] == x2.shape[2:4]: + weight = ms.ops.interpolate(weight.permute(3, 0, 1, 2), size=x2.shape[2:4], mode='bilinear', + align_corners=True).permute(1, 2, 3, 0) + + weight = weight.to(ms.complex64) + x2 *= weight + x2 = self.irfft(x2) + + x = ms.ops.cat([x1.unsqueeze(2), x2.unsqueeze(2)], axis=2).reshape(B, 2 * C, H, W) + x = self.post_norm(x) + return x + + +class Gnconv(nn.Cell): + """gnconv + """ + def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.): + super().__init__() + self.order = order + self.dims = [dim // 2 ** i for i in range(order)] + self.dims.reverse() + self.proj_in = nn.Conv2d(dim, 2 * dim, 1) + + if gflayer is None: + self.dwconv = get_dwconv(sum(self.dims), 7, True) + else: + self.dwconv = gflayer(sum(self.dims), h=h, w=w) + + self.proj_out = nn.Conv2d(dim, dim, 1) + + self.pws = nn.SequentialCell( + [nn.Conv2d(self.dims[i], self.dims[i + 1], 1) for i in range(order - 1)] + ) + self.scale = s + + def construct(self, x): + fused_x = self.proj_in(x) + pwa, abc = ms.ops.split(fused_x, (self.dims[0], sum(self.dims)), axis=1) + + dw_abc = self.dwconv(abc) * self.scale + + dw_list = ms.ops.split(dw_abc, self.dims, axis=1) + x = pwa * dw_list[0] + + for i in range(self.order - 1): + x = self.pws[i](x) * dw_list[i + 1] + + return self.proj_out(x) + + +class Block(nn.Cell): + """HorNet's Block + """ + def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, gnconv=Gnconv): + super().__init__() + self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first') + self.gnconv = gnconv(dim) + self.norm2 = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Dense(dim, 4 * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Dense(4 * dim, dim) + + self.gamma1 = ms.Parameter(layer_scale_init_value * ms.ops.ones(dim), + requires_grad=True) if layer_scale_init_value > 0 else None + self.gamma2 = ms.Parameter(layer_scale_init_value * ms.ops.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def construct(self, x): + _, C, _, _ = x.shape + if self.gamma1 is not None: + gamma1 = self.gamma1.view(C, 1, 1) + else: + gamma1 = 1 + x += self.drop_path(gamma1 * self.gnconv(self.norm1(x))) + + inputs = x + x = x.permute(0, 2, 3, 1) + x = self.norm2(x) + + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + + if self.gamma2 is not None: + x *= self.gamma2 + x = x.permute(0, 3, 1, 2) + x = inputs + self.drop_path(x) + return x + + +class HorNet(nn.Cell): + """HorNet + """ + def __init__(self, depths, in_chans=3, num_classes=1000, base_dim=96, drop_path_rate=0., layer_scale_init_value=1e-6, + head_init_scale=1., gnconv=Gnconv, block=Block, uniform_init=False): + super().__init__() + dims = [base_dim, base_dim * 2, base_dim * 4, base_dim * 8] + self.downsample_layers = nn.SequentialCell() + stem = nn.SequentialCell( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format='channels_first') + ) + self.downsample_layers.append(stem) + for i in range(3): + dowmsample_layer = nn.SequentialCell( + LayerNorm(dims[i], eps=1e-6, data_format='channels_first'), + nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2) + ) + self.downsample_layers.append(dowmsample_layer) + + self.stages = nn.SequentialCell() + dp_rates = list(x for x in ms.ops.linspace(0, drop_path_rate, sum(depths))) + + if not isinstance(gnconv, list): + gnconv = [gnconv, gnconv, gnconv, gnconv] + assert len(gnconv) == 4 + + cur = 0 + for i in range(4): + stage = nn.SequentialCell( + *[block(dim=dims[i], drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value, + gnconv=gnconv[i]) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = nn.LayerNorm((dims[-1],), epsilon=1e-6) + self.head = nn.Dense(dims[-1], num_classes) + self.uniform_init = uniform_init + + weight = self.head.weight * head_init_scale + self.head.weight.set_data(weight) + + if self.head.bias is not None: + bias = self.head.bias * head_init_scale + self.head.bias.set_data(bias) + + self.apply(self.init_weight) + + def init_weight(self, cell): + """ init weight + """ + if not self.uniform_init: + if isinstance(cell, (nn.Conv2d, nn.Dense)): + cell.weight.set_data(trunc_norm(cell.weight)) + if cell.bias is not None: + cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype)) + else: + if isinstance(cell, (nn.Conv2d, nn.Dense)): + cell.weight.set_data(initializer(XavierUniform(), cell.weight.shape, cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype)) + + def construct_features(self, x): + """construct features + """ + for i in range(4): + x = self.downsample_layers[i](x) + for blk in self.stages[i]: + x = blk(x) + return self.norm(x.mean([-2, -1])) + + def construct(self, x): + x = self.construct_features(x) + return self.head(x) + + +class LayerNorm(nn.Cell): + """LayerNorm + """ + def __init__(self, normalized_shape, eps=1e-6, data_format='channels_last'): + super().__init__() + self.weight = ms.Parameter(ms.ops.ones(normalized_shape)) + self.bias = ms.Parameter(ms.ops.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ['channels_first', 'channels_last']: + raise NotImplementedError + + self.normalized_shape = (normalized_shape,) + self.layer_norm = nn.LayerNorm(self.normalized_shape, epsilon=self.eps) + + def construct(self, x): + if self.data_format == 'channels_last': + self.layer_norm.gamma.set_data(self.weight) + self.layer_norm.beta.set_data(self.bias) + x = self.layer_norm(x) + if self.data_format == 'channels_first': + u = x.mean(1, keep_dims=True) + s = (x - u).pow(2).mean(1, keep_dims=True) + x = (x - u) / ms.ops.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +def hornet_tiny_7x7(**kwargs): + """hornet_tiny_7x7 + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s), + partial(Gnconv, order=5, s=s) + ], **kwargs) + return model + + +def hornet_tiny_gf(**kwargs): + """hornet_tiny_gf + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter), + partial(Gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter) + ], **kwargs) + return model + + +def hornet_small_7x7(**kwargs): + """hornet_small_7x7 + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=96, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s), + partial(Gnconv, order=5, s=s) + ], **kwargs) + return model + + +def hornet_small_gf(**kwargs): + """hornet_small_gf + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=96, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter), + partial(Gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter) + ], **kwargs) + return model + + +def hornet_base_7x7(**kwargs): + """hornet_base_7x7 + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=128, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s), + partial(Gnconv, order=5, s=s) + ], **kwargs) + return model + + +def hornet_base_gf(**kwargs): + """hornet_base_gf + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=128, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter), + partial(Gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter) + ], **kwargs) + return model + + +def hornet_large_7x7(**kwargs): + """hornet_large_7x7 + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s), + partial(Gnconv, order=5, s=s) + ], **kwargs) + return model + + +def hornet_large_gf(**kwargs): + """hornet_large_gf + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s, h=14, w=8, gflayer=GlobalLocalFilter), + partial(Gnconv, order=5, s=s, h=7, w=4, gflayer=GlobalLocalFilter) + ], **kwargs) + return model + + +def hornet_large_gf_img384(**kwargs): + """hornet_large_gf_img384 + """ + s = 1 / 3. + model = HorNet(depths=[2, 3, 18, 2], base_dim=192, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=s), + partial(Gnconv, order=3, s=s), + partial(Gnconv, order=4, s=s, h=24, w=13, gflayer=GlobalLocalFilter), + partial(Gnconv, order=5, s=s, h=12, w=7, gflayer=GlobalLocalFilter) + ], **kwargs) + return model + + +if __name__ == "__main__": + in_tensor = ms.ops.randn((1, 3, 224, 224)) + hornet = HorNet(depths=[2, 3, 18, 2], base_dim=64, block=Block, + gnconv=[ + partial(Gnconv, order=2, s=1/3), + partial(Gnconv, order=3, s=1/3), + partial(Gnconv, order=4, s=1/3), + partial(Gnconv, order=5, s=1/3) + ]) + output = hornet(in_tensor) + print(output.shape)