Delete 'model/backbone/ConTNet.py'

v1
limingjuan committed 8a7ee9b137 · 2 years ago
1 changed file with 0 additions and 357 deletions
  1. model/backbone/ConTNet.py (+0, -357)

model/backbone/ConTNet.py

@@ -1,357 +0,0 @@
# pylint: disable=E0401
# pylint: disable=W0201
"""
MindSpore implementation of `ConTNet`.
Refer to ConTNet: Why not use convolution and transformer at the same time?
"""
from collections import OrderedDict
import mindspore as ms
from mindspore import nn
from mindspore.common.initializer import initializer, HeNormal, TruncatedNormal, Constant

from model.layers import _trunc_normal_

__all__ = ['ConTBlock', 'ConTNet']


def fixed_padding(inputs, kernel_size, dilation):
""" fixed padding """
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
padded_inputs = ms.ops.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
return padded_inputs


class ConvBN(nn.SequentialCell):
""" Conv and BN """
def __init__(self, in_planes, out_planes, kernel_size, stride=1, group=1, bn=True):
padding = (kernel_size - 1) // 2
if bn:
super().__init__(OrderedDict([
('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
pad_mode='pad', padding=padding, group=group, has_bias=False)),
('bn', nn.BatchNorm2d(out_planes))
]))
else:
super().__init__(OrderedDict([
('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
padding=padding, group=group, has_bias=False))
]))


class MHSA(nn.Cell):
""" Multi-head Self Attention """
def __init__(self, planes, head_num, dropout, patch_size, qkv_bias, relative):
super().__init__()
self.head_num = head_num
head_dim = planes // head_num
self.qkv = nn.Dense(planes, 3 * planes, has_bias=qkv_bias)
self.relative = relative
self.patch_size = patch_size
self.scale = head_dim ** -0.5

if self.relative:
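# Learnable relative position bias (Swin-style): every ordered pair of positions
# inside a patch_size x patch_size window indexes one per-head entry of the table.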
# print('### relative position embedding ###')
self.relative_position_bias_table = ms.Parameter(
ms.ops.zeros(((2 * patch_size - 1) * (2 * patch_size - 1), head_num)))
coords_w = coords_h = ms.ops.arange(patch_size)
coords = ms.ops.stack(ms.ops.meshgrid(coords_h, coords_w))
coords_flatten = ms.ops.flatten(coords, start_dim=1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0)
relative_coords[:, :, 0] += patch_size - 1
relative_coords[:, :, 1] += patch_size - 1
relative_coords[:, :, 0] *= 2 * patch_size - 1
self.relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
_trunc_normal_(self.relative_position_bias_table, std=.02)

self.attn_drop = nn.Dropout(p=dropout)
self.proj = nn.Dense(planes, planes)
self.proj_drop = nn.Dropout(p=dropout)

def construct(self, x):
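# x: (B * num_patches, N, C) with N = patch_size ** 2 tokens per patch;
# standard scaled dot-product attention over the tokens of each patch.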
B, N, C, H = *x.shape, self.head_num
qkv = self.qkv(x).reshape(B, N, 3, H, C // H).permute(2, 0, 3, 1, 4) # qkv: (3, B, H, N, C//H)
q, k, v = qkv[0], qkv[1], qkv[2] # q, k, v: (B, H, N, C//H)

q = q * self.scale
attn = q @ k.transpose(0, 1, 3, 2) # attn: (B, H, N, N)

if self.relative:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.patch_size ** 2, self.patch_size ** 2, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
attn = attn + relative_position_bias.unsqueeze(0)

attn = ms.ops.softmax(attn, axis=-1)
# attn = attn.softmax(axis=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)

return x


class MLP(nn.Cell):
""" Build a Multi-Layer Perceptron """

def __init__(self, planes, mlp_dim, dropout):
super().__init__()
self.fc1 = nn.Dense(planes, mlp_dim)
self.act = nn.GELU()
self.fc2 = nn.Dense(mlp_dim, planes)
self.drop = nn.Dropout(p=dropout)

def construct(self, x):
x = self.drop(self.act(self.fc1(x)))
x = self.drop(self.fc2(x))
return x


class STE(nn.Cell):
""" Build a Standard Transformer Encoder(STE) """

def __init__(self, planes, mlp_dim, head_num, dropout, patch_size,
relative, qkv_bias, pre_norm):
super().__init__()
self.patch_size = patch_size
self.pre_norm = pre_norm
self.relative = relative

if not relative:
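# Without relative bias, learnable absolute position encodings are added instead:
# a per-row code and a per-column code, each covering half of the channels.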
self.pe = ms.ParameterTuple(
[ms.Parameter(ms.ops.zeros((1, patch_size, 1, planes // 2))),
ms.Parameter(ms.ops.zeros((1, 1, patch_size, planes // 2)))]
)
self.attn = MHSA(planes, head_num, dropout, patch_size, qkv_bias, relative)
self.mlp = MLP(planes, mlp_dim, dropout)
self.norm1 = nn.LayerNorm((planes,))
self.norm2 = nn.LayerNorm((planes,))

self.unfold1 = nn.Unfold((1, 1, self.patch_size, 1), strides=[1, 1, self.patch_size, 1], rates=[1, 1, 1, 1])
self.unfold2 = nn.Unfold((1, self.patch_size, 1, 1), strides=[1, self.patch_size, 1, 1], rates=[1, 1, 1, 1])

def construct(self, x):
B, C, H, W = x.shape
patch_size = self.patch_size
patch_num_h, patch_num_w = H // patch_size, W // patch_size

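# Partition the feature map into non-overlapping patch_size x patch_size windows;
# each window becomes one token sequence of length patch_size ** 2.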
x = x.unfold(kernel_size=(self.patch_size, 1), stride=(self.patch_size, 1)).reshape(B, C, self.patch_size, -1)
x = x.unfold(kernel_size=(1, self.patch_size), stride=(1, self.patch_size))
x = x.reshape(B, C, patch_num_h, patch_num_w, patch_size, patch_size)
x = x.permute(0, 2, 3, 4, 5, 1).reshape(-1, patch_size, patch_size, C)
if not self.relative:
x_h, x_w = x.split(C // 2, axis=3)
x = ms.ops.cat((x_h + self.pe[0], x_w + self.pe[1]), axis=3)

x = x.reshape(x.shape[0], -1, x.shape[-1])

if self.pre_norm:
x += self.attn(self.norm1(x))
x += self.mlp(self.norm2(x))
else:
x = self.norm1(x + self.attn(x))
x = self.norm2(x + self.mlp(x))

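# Fold the per-patch token sequences back into a (B, C, H, W) feature map.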
b_pnh_pnw, _, c = x.shape
b = b_pnh_pnw // (patch_num_h * patch_num_w)
x = x.reshape(b, patch_num_h, patch_num_w, patch_size, patch_size, c).permute(0, 5, 1, 3, 2, 4)
x = x.reshape(b, c, patch_num_h * patch_size, patch_num_w * patch_size)
return x


class ConTBlock(nn.Cell):
""" Build a ContBlock """

def __init__(self, planes, out_planes, mlp_dim, head_num, dropout, patch_size,
downsample=None, stride=1, last_dropout=.3, **kwargs):
super().__init__()
self.downsample = downsample
self.identity = nn.Identity()
self.dropout = nn.Identity()

self.bn = nn.BatchNorm2d(planes)
self.relu = nn.ReLU()
self.ste1 = STE(planes, mlp_dim, head_num, dropout, patch_size[0], **kwargs)
self.ste2 = STE(planes, mlp_dim, head_num, dropout, patch_size[1], **kwargs)

if stride == 1 and downsample is not None:
self.dropout = nn.Dropout(p=last_dropout)
kernel_size = 1
else:
kernel_size = 3

self.out_conv = ConvBN(planes, out_planes, kernel_size, stride, bn=False)

def construct(self, x):
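# Pre-activation residual: BN + ReLU first, then two STEs and a projection conv.
# The identity path is the raw input, or downsample(x_preact) when a downsample cell is given.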
x_preact = self.relu(self.bn(x))
identity = self.identity(x)

if self.downsample is not None:
identity = self.downsample(x_preact)

residual = self.ste1(x_preact)
residual = self.ste2(residual)
residual = self.out_conv(residual)
out = self.dropout(residual + identity)
return out


class ConTNet(nn.Cell):
""" Build a ConTNet backbone """

def __init__(self, block, layers, mlp_dim, head_num, dropout, in_channels=3,
inplanes=64, num_classes=1000, init_weights=True, first_embedding=False,
tweak_C=False, **kwargs):
super().__init__()
self.inplanes = inplanes
self.block = block

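# Stem variants: tweak_C stacks three 3x3 convs plus max-pool (ResNet-C style stem),
# first_embedding uses a 4x4 stride-4 patch-embedding conv, and the default is a
# 7x7 stride-2 conv followed by a stride-2 max-pool.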
if tweak_C:
self.layer0 = nn.SequentialCell(OrderedDict([
('conv_bn1', ConvBN(in_channels, inplanes//2, kernel_size=3, stride=2)),
('relu1', nn.ReLU()),
('conv_bn2', ConvBN(inplanes//2, inplanes//2, kernel_size=3, stride=1)),
('relu2', nn.ReLU()),
('conv_bn3', ConvBN(inplanes//2, inplanes, kernel_size=3, stride=1)),
('relu3', nn.ReLU()),
('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='pad', padding=1))
]))
elif first_embedding:
self.layer0 = nn.SequentialCell(OrderedDict([
('conv', nn.Conv2d(in_channels, inplanes, kernel_size=4, stride=4)),
('norm', nn.LayerNorm((inplanes, )))
]))
else:
self.layer0 = nn.SequentialCell(OrderedDict([
('conv', ConvBN(in_channels, inplanes, kernel_size=7, stride=2, bn=False)),
('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='pad', padding=1))
]))

self.cont_layers = []
self.out_channels = OrderedDict()

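# One stage per entry in `layers`: every stage doubles the channel count in its
# final block and halves the spatial size, except the last stage (stride 1).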
for i, layer in enumerate(layers):
stride = 2
patch_size = [7, 14]
if i == len(layers) - 1:
stride, patch_size[1] = 1, 7
cont_layer = self._make_layer(inplanes * 2**i, layer, stride=stride, mlp_dim=mlp_dim[i],
head_num=head_num[i], dropout=dropout[i], patch_size=patch_size, **kwargs)
layer_name = f'layer{i+1}'
self.insert_child_to_cell(layer_name, cont_layer)
self.cont_layers.append(layer_name)
self.out_channels[layer_name] = 2 * inplanes * 2**i

self.last_out_channels = next(reversed(self.out_channels.values()))
self.fc = nn.Dense(self.last_out_channels, num_classes)

if init_weights:
self.apply(self._init_weights)

def _make_layer(self, planes, blocks, stride, mlp_dim, head_num, dropout, patch_size, use_avg_down=False, **kwargs):
""" make layer """
layers = OrderedDict()
for i in range(0, blocks-1):
layers[f'{self.block.__name__}{i}'] = self.block(
planes, planes, mlp_dim, head_num, dropout, patch_size, **kwargs
)
# downsample = None
if stride != 1:
if use_avg_down:
downsample = nn.SequentialCell(OrderedDict([
('avgpool', nn.AvgPool2d(kernel_size=2, stride=2)),
('conv', ConvBN(planes, planes*2, kernel_size=1, stride=1, bn=False))
]))
else:
downsample = ConvBN(planes, planes*2, kernel_size=1, stride=2, bn=False)
else:
downsample = ConvBN(planes, planes*2, kernel_size=1, stride=1, bn=False)
layers[f'{self.block.__name__}{blocks-1}'] = self.block(
planes, planes*2, mlp_dim, head_num, dropout, patch_size, downsample, stride, **kwargs
)
return nn.SequentialCell(layers)

def _init_weights(self, cell):
""" initialize weight """
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(initializer(HeNormal(mode='fan_out', nonlinearity='relu'),
cell.weight.shape, cell.weight.dtype))
elif isinstance(cell, nn.Dense):
cell.weight.set_data(initializer(TruncatedNormal(sigma=.02),
cell.weight.shape, cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(initializer(Constant(0), cell.bias.shape, cell.bias.dtype))
elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm)):
cell.gamma.set_data(initializer(Constant(1), cell.gamma.shape, cell.gamma.dtype))
cell.beta.set_data(initializer(Constant(0), cell.beta.shape, cell.beta.dtype))

def construct(self, x):
x = self.layer0(x)

for layer_name in self.cont_layers:
cont_layer = getattr(self, layer_name)
x = cont_layer(x)
x = x.mean([2, 3])
x = self.fc(x)

return x


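# ConTNet variants: Ti/S/M/B differ in width (inplanes, mlp_dim), depth (layers) and dropout.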
def create_ConTNet_Ti(kwargs):
""" ConTNet-Ti """
return ConTNet(block=ConTBlock,
mlp_dim=[196, 392, 768, 768],
head_num=[1, 2, 4, 8],
dropout=[0, 0, 0, 0],
inplanes=48,
layers=[1, 1, 1, 1],
last_dropout=0,
**kwargs)

def create_ConTNet_S(kwargs):
""" ConTNet-S """
return ConTNet(block=ConTBlock,
mlp_dim=[256, 512, 1024, 1024],
head_num=[1, 2, 4, 8],
dropout=[0, 0, 0, 0],
inplanes=64,
layers=[1, 1, 1, 1],
last_dropout=0,
**kwargs)

def create_ConTNet_M(kwargs):
""" ConTNet-M """
return ConTNet(block=ConTBlock,
mlp_dim=[256, 512, 1024, 1024],
head_num=[1, 2, 4, 8],
dropout=[0, 0, 0, 0],
inplanes=64,
layers=[2, 2, 2, 2],
last_dropout=0,
**kwargs)

def create_ConTNet_B(kwargs):
""" ConTNet-B """
return ConTNet(block=ConTBlock,
mlp_dim=[256, 512, 1024, 1024],
head_num=[1, 2, 4, 8],
dropout=[0, 0, 0.1, 0.1],
inplanes=64,
layers=[3, 4, 6, 3],
last_dropout=0.2,
**kwargs)

def build_model(relative, qkv_bias, pre_norm):
""" build model """
kwargs = {"relative": relative, "qkv_bias": qkv_bias, 'pre_norm': pre_norm}
return create_ConTNet_Ti(kwargs)


if __name__ == "__main__":
model = build_model(relative=True, qkv_bias=True, pre_norm=True)
dummy_input = ms.ops.randn(1, 3, 224, 224)
output = model(dummy_input)
print(output.shape)
