Add image depth estimation model; new model, pipeline, and test. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10857764 (master^2)
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:3b230497f6ca10be42aed92b86db435d74fd7306746a059b4ad1e0d6b0652806 | |||
size 35694 |
@@ -36,6 +36,7 @@ class Models(object): | |||
swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
newcrfs_depth_estimation = 'newcrfs-depth-estimation' | |||
resnet50_bert = 'resnet50-bert' | |||
referring_video_object_segmentation = 'swinT-referring-video-object-segmentation' | |||
fer = 'fer' | |||
@@ -208,6 +209,7 @@ class Pipelines(object): | |||
video_summarization = 'googlenet_pgl_video_summarization' | |||
language_guided_video_summarization = 'clip-it-video-summarization' | |||
image_semantic_segmentation = 'image-semantic-segmentation' | |||
image_depth_estimation = 'image-depth-estimation' | |||
image_reid_person = 'passvitb-image-reid-person' | |||
image_inpainting = 'fft-inpainting' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
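Not part of the diff: a minimal usage sketch for the constants registered above, assuming the task string 'image-depth-estimation' is wired into modelscope's pipeline registry in the usual way; the model id below is a placeholder, not the real hub id.

from modelscope.pipelines import pipeline

# Pipelines.image_depth_estimation / Models.newcrfs_depth_estimation are the
# constants added in this hunk; the hub id here is hypothetical.
depth_estimator = pipeline(
    task='image-depth-estimation',
    model='damo/cv_newcrfs_image-depth-estimation')  # placeholder model id

result = depth_estimator('path/to/input.jpg')  # local path or image URL
print(list(result.keys()))  # inspect the returned keys (e.g. the predicted depth map)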
@@ -0,0 +1 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. |
@@ -0,0 +1 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. |
@@ -0,0 +1,215 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from .newcrf_layers import NewCRF | |||
from .swin_transformer import SwinTransformer | |||
from .uper_crf_head import PSP | |||
class NewCRFDepth(nn.Module): | |||
""" | |||
Depth network based on neural window FC-CRFs architecture. | |||
""" | |||
def __init__(self, | |||
version=None, | |||
inv_depth=False, | |||
pretrained=None, | |||
frozen_stages=-1, | |||
min_depth=0.1, | |||
max_depth=100.0, | |||
**kwargs): | |||
super().__init__() | |||
self.inv_depth = inv_depth | |||
self.with_auxiliary_head = False | |||
self.with_neck = False | |||
norm_cfg = dict(type='BN', requires_grad=True) | |||
# norm_cfg = dict(type='GN', requires_grad=True, num_groups=8) | |||
window_size = int(version[-2:]) | |||
if version[:-2] == 'base': | |||
embed_dim = 128 | |||
depths = [2, 2, 18, 2] | |||
num_heads = [4, 8, 16, 32] | |||
in_channels = [128, 256, 512, 1024] | |||
elif version[:-2] == 'large': | |||
embed_dim = 192 | |||
depths = [2, 2, 18, 2] | |||
num_heads = [6, 12, 24, 48] | |||
in_channels = [192, 384, 768, 1536] | |||
elif version[:-2] == 'tiny': | |||
embed_dim = 96 | |||
depths = [2, 2, 6, 2] | |||
num_heads = [3, 6, 12, 24] | |||
in_channels = [96, 192, 384, 768] | |||
backbone_cfg = dict( | |||
embed_dim=embed_dim, | |||
depths=depths, | |||
num_heads=num_heads, | |||
window_size=window_size, | |||
ape=False, | |||
drop_path_rate=0.3, | |||
patch_norm=True, | |||
use_checkpoint=False, | |||
frozen_stages=frozen_stages) | |||
embed_dim = 512 | |||
decoder_cfg = dict( | |||
in_channels=in_channels, | |||
in_index=[0, 1, 2, 3], | |||
pool_scales=(1, 2, 3, 6), | |||
channels=embed_dim, | |||
dropout_ratio=0.0, | |||
num_classes=32, | |||
norm_cfg=norm_cfg, | |||
align_corners=False) | |||
self.backbone = SwinTransformer(**backbone_cfg) | |||
# v_dim = decoder_cfg['num_classes'] * 4 | |||
win = 7 | |||
crf_dims = [128, 256, 512, 1024] | |||
v_dims = [64, 128, 256, embed_dim] | |||
self.crf3 = NewCRF( | |||
input_dim=in_channels[3], | |||
embed_dim=crf_dims[3], | |||
window_size=win, | |||
v_dim=v_dims[3], | |||
num_heads=32) | |||
self.crf2 = NewCRF( | |||
input_dim=in_channels[2], | |||
embed_dim=crf_dims[2], | |||
window_size=win, | |||
v_dim=v_dims[2], | |||
num_heads=16) | |||
self.crf1 = NewCRF( | |||
input_dim=in_channels[1], | |||
embed_dim=crf_dims[1], | |||
window_size=win, | |||
v_dim=v_dims[1], | |||
num_heads=8) | |||
self.crf0 = NewCRF( | |||
input_dim=in_channels[0], | |||
embed_dim=crf_dims[0], | |||
window_size=win, | |||
v_dim=v_dims[0], | |||
num_heads=4) | |||
self.decoder = PSP(**decoder_cfg) | |||
self.disp_head1 = DispHead(input_dim=crf_dims[0]) | |||
self.up_mode = 'bilinear' | |||
if self.up_mode == 'mask': | |||
self.mask_head = nn.Sequential( | |||
nn.Conv2d(crf_dims[0], 64, 3, padding=1), | |||
nn.ReLU(inplace=True), nn.Conv2d(64, 16 * 9, 1, padding=0)) | |||
self.min_depth = min_depth | |||
self.max_depth = max_depth | |||
self.init_weights(pretrained=pretrained) | |||
def init_weights(self, pretrained=None): | |||
"""Initialize the weights in backbone and heads. | |||
Args: | |||
pretrained (str, optional): Path to pre-trained weights. | |||
Defaults to None. | |||
""" | |||
# print(f'== Load encoder backbone from: {pretrained}') | |||
self.backbone.init_weights(pretrained=pretrained) | |||
self.decoder.init_weights() | |||
if self.with_auxiliary_head: | |||
if isinstance(self.auxiliary_head, nn.ModuleList): | |||
for aux_head in self.auxiliary_head: | |||
aux_head.init_weights() | |||
else: | |||
self.auxiliary_head.init_weights() | |||
def upsample_mask(self, disp, mask): | |||
""" Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """ | |||
N, _, H, W = disp.shape | |||
mask = mask.view(N, 1, 9, 4, 4, H, W) | |||
mask = torch.softmax(mask, dim=2) | |||
up_disp = F.unfold(disp, kernel_size=3, padding=1) | |||
up_disp = up_disp.view(N, 1, 9, 1, 1, H, W) | |||
up_disp = torch.sum(mask * up_disp, dim=2) | |||
up_disp = up_disp.permute(0, 1, 4, 2, 5, 3) | |||
return up_disp.reshape(N, 1, 4 * H, 4 * W) | |||
def forward(self, imgs): | |||
feats = self.backbone(imgs) | |||
if self.with_neck: | |||
feats = self.neck(feats) | |||
ppm_out = self.decoder(feats) | |||
e3 = self.crf3(feats[3], ppm_out) | |||
e3 = nn.PixelShuffle(2)(e3) | |||
e2 = self.crf2(feats[2], e3) | |||
e2 = nn.PixelShuffle(2)(e2) | |||
e1 = self.crf1(feats[1], e2) | |||
e1 = nn.PixelShuffle(2)(e1) | |||
e0 = self.crf0(feats[0], e1) | |||
if self.up_mode == 'mask': | |||
mask = self.mask_head(e0) | |||
d1 = self.disp_head1(e0, 1) | |||
d1 = self.upsample_mask(d1, mask) | |||
else: | |||
d1 = self.disp_head1(e0, 4) | |||
depth = d1 * self.max_depth | |||
return depth | |||
class DispHead(nn.Module): | |||
def __init__(self, input_dim=100): | |||
super(DispHead, self).__init__() | |||
# self.norm1 = nn.BatchNorm2d(input_dim) | |||
self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1) | |||
# self.relu = nn.ReLU(inplace=True) | |||
self.sigmoid = nn.Sigmoid() | |||
def forward(self, x, scale): | |||
# x = self.relu(self.norm1(x)) | |||
x = self.sigmoid(self.conv1(x)) | |||
if scale > 1: | |||
x = upsample(x, scale_factor=scale) | |||
return x | |||
class DispUnpack(nn.Module): | |||
def __init__(self, input_dim=100, hidden_dim=128): | |||
super(DispUnpack, self).__init__() | |||
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) | |||
self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.sigmoid = nn.Sigmoid() | |||
self.pixel_shuffle = nn.PixelShuffle(4) | |||
def forward(self, x, output_size): | |||
x = self.relu(self.conv1(x)) | |||
x = self.sigmoid(self.conv2(x)) # [b, 16, h/4, w/4] | |||
# x = torch.reshape(x, [x.shape[0], 1, x.shape[2]*4, x.shape[3]*4]) | |||
x = self.pixel_shuffle(x) | |||
return x | |||
def upsample(x, scale_factor=2, mode='bilinear', align_corners=False): | |||
"""Upsample input tensor by a factor of 2 | |||
""" | |||
return F.interpolate( | |||
x, scale_factor=scale_factor, mode=mode, align_corners=align_corners) |
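Not part of the diff: a quick shape-check sketch for NewCRFDepth, assuming the import path below (hypothetical) and that mmcv and timm are installed; 'large07' selects the Swin-Large backbone with window size 7.

import torch

# Hypothetical import path for the file added above.
from modelscope.models.cv.image_depth_estimation.networks.newcrf_depth import NewCRFDepth

model = NewCRFDepth(version='large07', max_depth=10.0, pretrained=None)
model.eval()

x = torch.randn(1, 3, 480, 640)  # H and W divisible by 32 so the four Swin stages line up
with torch.no_grad():
    depth = model(x)
print(depth.shape)  # expected: torch.Size([1, 1, 480, 640]), values in (0, max_depth)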
@@ -0,0 +1,504 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.utils.checkpoint as checkpoint | |||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |||
class Mlp(nn.Module): | |||
""" Multilayer perceptron.""" | |||
def __init__(self, | |||
in_features, | |||
hidden_features=None, | |||
out_features=None, | |||
act_layer=nn.GELU, | |||
drop=0.): | |||
super().__init__() | |||
out_features = out_features or in_features | |||
hidden_features = hidden_features or in_features | |||
self.fc1 = nn.Linear(in_features, hidden_features) | |||
self.act = act_layer() | |||
self.fc2 = nn.Linear(hidden_features, out_features) | |||
self.drop = nn.Dropout(drop) | |||
def forward(self, x): | |||
x = self.fc1(x) | |||
x = self.act(x) | |||
x = self.drop(x) | |||
x = self.fc2(x) | |||
x = self.drop(x) | |||
return x | |||
def window_partition(x, window_size): | |||
""" | |||
Args: | |||
x: (B, H, W, C) | |||
window_size (int): window size | |||
Returns: | |||
windows: (num_windows*B, window_size, window_size, C) | |||
""" | |||
B, H, W, C = x.shape | |||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, | |||
C) | |||
windows = x.permute(0, 1, 3, 2, 4, | |||
5).contiguous().view(-1, window_size, window_size, C) | |||
return windows | |||
def window_reverse(windows, window_size, H, W): | |||
""" | |||
Args: | |||
windows: (num_windows*B, window_size, window_size, C) | |||
window_size (int): Window size | |||
H (int): Height of image | |||
W (int): Width of image | |||
Returns: | |||
x: (B, H, W, C) | |||
""" | |||
B = int(windows.shape[0] / (H * W / window_size / window_size)) | |||
x = windows.view(B, H // window_size, W // window_size, window_size, | |||
window_size, -1) | |||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) | |||
return x | |||
class WindowAttention(nn.Module): | |||
""" Window based multi-head self attention (W-MSA) module with relative position bias. | |||
It supports both shifted and non-shifted windows. | |||
Args: | |||
dim (int): Number of input channels. | |||
window_size (tuple[int]): The height and width of the window. | |||
num_heads (int): Number of attention heads. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |||
""" | |||
def __init__(self, | |||
dim, | |||
window_size, | |||
num_heads, | |||
v_dim, | |||
qkv_bias=True, | |||
qk_scale=None, | |||
attn_drop=0., | |||
proj_drop=0.): | |||
super().__init__() | |||
self.dim = dim | |||
self.window_size = window_size # Wh, Ww | |||
self.num_heads = num_heads | |||
head_dim = dim // num_heads | |||
self.scale = qk_scale or head_dim**-0.5 | |||
# define a parameter table of relative position bias | |||
self.relative_position_bias_table = nn.Parameter( | |||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), | |||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||
# get pair-wise relative position index for each token inside the window | |||
coords_h = torch.arange(self.window_size[0]) | |||
coords_w = torch.arange(self.window_size[1]) | |||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
relative_coords = coords_flatten[:, :, | |||
None] - coords_flatten[:, | |||
None, :] # 2, Wh*Ww, Wh*Ww | |||
relative_coords = relative_coords.permute( | |||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||
relative_coords[:, :, | |||
0] += self.window_size[0] - 1 # shift to start from 0 | |||
relative_coords[:, :, 1] += self.window_size[1] - 1 | |||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 | |||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |||
self.register_buffer('relative_position_index', | |||
relative_position_index) | |||
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) | |||
self.attn_drop = nn.Dropout(attn_drop) | |||
self.proj = nn.Linear(v_dim, v_dim) | |||
self.proj_drop = nn.Dropout(proj_drop) | |||
trunc_normal_(self.relative_position_bias_table, std=.02) | |||
self.softmax = nn.Softmax(dim=-1) | |||
def forward(self, x, v, mask=None): | |||
""" Forward function. | |||
Args: | |||
x: input features with shape of (num_windows*B, N, C) | |||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None | |||
""" | |||
B_, N, C = x.shape | |||
qk = self.qk(x).reshape(B_, N, 2, self.num_heads, | |||
C // self.num_heads).permute(2, 0, 3, 1, 4) | |||
q, k = qk[0], qk[ | |||
1] # make torchscript happy (cannot use tensor as tuple) | |||
q = q * self.scale | |||
attn = (q @ k.transpose(-2, -1)) | |||
relative_position_bias = self.relative_position_bias_table[ | |||
self.relative_position_index.view(-1)].view( | |||
self.window_size[0] * self.window_size[1], | |||
self.window_size[0] * self.window_size[1], | |||
-1) # Wh*Ww,Wh*Ww,nH | |||
relative_position_bias = relative_position_bias.permute( | |||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||
attn = attn + relative_position_bias.unsqueeze(0) | |||
if mask is not None: | |||
nW = mask.shape[0] | |||
attn = attn.view(B_ // nW, nW, self.num_heads, N, | |||
N) + mask.unsqueeze(1).unsqueeze(0) | |||
attn = attn.view(-1, self.num_heads, N, N) | |||
attn = self.softmax(attn) | |||
else: | |||
attn = self.softmax(attn) | |||
attn = self.attn_drop(attn) | |||
# assert self.dim % v.shape[-1] == 0, "self.dim % v.shape[-1] != 0" | |||
# repeat_num = self.dim // v.shape[-1] | |||
# v = v.view(B_, N, self.num_heads // repeat_num, -1).transpose(1, 2).repeat(1, repeat_num, 1, 1) | |||
assert self.dim == v.shape[-1], 'self.dim != v.shape[-1]' | |||
v = v.view(B_, N, self.num_heads, -1).transpose(1, 2) | |||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) | |||
x = self.proj(x) | |||
x = self.proj_drop(x) | |||
return x | |||
class CRFBlock(nn.Module): | |||
""" CRF Block. | |||
Args: | |||
dim (int): Number of input channels. | |||
num_heads (int): Number of attention heads. | |||
window_size (int): Window size. | |||
shift_size (int): Shift size for SW-MSA. | |||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||
drop (float, optional): Dropout rate. Default: 0.0 | |||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||
drop_path (float, optional): Stochastic depth rate. Default: 0.0 | |||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU | |||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
""" | |||
def __init__(self, | |||
dim, | |||
num_heads, | |||
v_dim, | |||
window_size=7, | |||
shift_size=0, | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
qk_scale=None, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
act_layer=nn.GELU, | |||
norm_layer=nn.LayerNorm): | |||
super().__init__() | |||
self.dim = dim | |||
self.num_heads = num_heads | |||
self.v_dim = v_dim | |||
self.window_size = window_size | |||
self.shift_size = shift_size | |||
self.mlp_ratio = mlp_ratio | |||
assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)' | |||
self.norm1 = norm_layer(dim) | |||
self.attn = WindowAttention( | |||
dim, | |||
window_size=to_2tuple(self.window_size), | |||
num_heads=num_heads, | |||
v_dim=v_dim, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
attn_drop=attn_drop, | |||
proj_drop=drop) | |||
self.drop_path = DropPath( | |||
drop_path) if drop_path > 0. else nn.Identity() | |||
self.norm2 = norm_layer(v_dim) | |||
mlp_hidden_dim = int(v_dim * mlp_ratio) | |||
self.mlp = Mlp( | |||
in_features=v_dim, | |||
hidden_features=mlp_hidden_dim, | |||
act_layer=act_layer, | |||
drop=drop) | |||
self.H = None | |||
self.W = None | |||
def forward(self, x, v, mask_matrix): | |||
""" Forward function. | |||
Args: | |||
x: Input feature, tensor size (B, H*W, C). | |||
H, W: Spatial resolution of the input feature. | |||
mask_matrix: Attention mask for cyclic shift. | |||
""" | |||
B, L, C = x.shape | |||
H, W = self.H, self.W | |||
assert L == H * W, 'input feature has wrong size' | |||
shortcut = x | |||
x = self.norm1(x) | |||
x = x.view(B, H, W, C) | |||
# pad feature maps to multiples of window size | |||
pad_l = pad_t = 0 | |||
pad_r = (self.window_size - W % self.window_size) % self.window_size | |||
pad_b = (self.window_size - H % self.window_size) % self.window_size | |||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||
v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||
_, Hp, Wp, _ = x.shape | |||
# cyclic shift | |||
if self.shift_size > 0: | |||
shifted_x = torch.roll( | |||
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) | |||
shifted_v = torch.roll( | |||
v, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) | |||
attn_mask = mask_matrix | |||
else: | |||
shifted_x = x | |||
shifted_v = v | |||
attn_mask = None | |||
# partition windows | |||
x_windows = window_partition( | |||
shifted_x, self.window_size) # nW*B, window_size, window_size, C | |||
x_windows = x_windows.view(-1, self.window_size * self.window_size, | |||
C) # nW*B, window_size*window_size, C | |||
v_windows = window_partition( | |||
shifted_v, self.window_size) # nW*B, window_size, window_size, C | |||
v_windows = v_windows.view( | |||
-1, self.window_size * self.window_size, | |||
v_windows.shape[-1]) # nW*B, window_size*window_size, C | |||
# W-MSA/SW-MSA | |||
attn_windows = self.attn( | |||
x_windows, v_windows, | |||
mask=attn_mask) # nW*B, window_size*window_size, C | |||
# merge windows | |||
attn_windows = attn_windows.view(-1, self.window_size, | |||
self.window_size, self.v_dim) | |||
shifted_x = window_reverse(attn_windows, self.window_size, Hp, | |||
Wp) # B H' W' C | |||
# reverse cyclic shift | |||
if self.shift_size > 0: | |||
x = torch.roll( | |||
shifted_x, | |||
shifts=(self.shift_size, self.shift_size), | |||
dims=(1, 2)) | |||
else: | |||
x = shifted_x | |||
if pad_r > 0 or pad_b > 0: | |||
x = x[:, :H, :W, :].contiguous() | |||
x = x.view(B, H * W, self.v_dim) | |||
# FFN | |||
x = shortcut + self.drop_path(x) | |||
x = x + self.drop_path(self.mlp(self.norm2(x))) | |||
return x | |||
class BasicCRFLayer(nn.Module): | |||
""" A basic NeWCRFs layer for one stage. | |||
Args: | |||
dim (int): Number of feature channels | |||
depth (int): Depths of this stage. | |||
num_heads (int): Number of attention heads. | |||
window_size (int): Local window size. Default: 7. | |||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||
drop (float, optional): Dropout rate. Default: 0.0 | |||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 | |||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None | |||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||
""" | |||
def __init__(self, | |||
dim, | |||
depth, | |||
num_heads, | |||
v_dim, | |||
window_size=7, | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
qk_scale=None, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
norm_layer=nn.LayerNorm, | |||
downsample=None, | |||
use_checkpoint=False): | |||
super().__init__() | |||
self.window_size = window_size | |||
self.shift_size = window_size // 2 | |||
self.depth = depth | |||
self.use_checkpoint = use_checkpoint | |||
# build blocks | |||
self.blocks = nn.ModuleList([ | |||
CRFBlock( | |||
dim=dim, | |||
num_heads=num_heads, | |||
v_dim=v_dim, | |||
window_size=window_size, | |||
shift_size=0 if (i % 2 == 0) else window_size // 2, | |||
mlp_ratio=mlp_ratio, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
drop=drop, | |||
attn_drop=attn_drop, | |||
drop_path=drop_path[i] | |||
if isinstance(drop_path, list) else drop_path, | |||
norm_layer=norm_layer) for i in range(depth) | |||
]) | |||
# patch merging layer | |||
if downsample is not None: | |||
self.downsample = downsample(dim=dim, norm_layer=norm_layer) | |||
else: | |||
self.downsample = None | |||
def forward(self, x, v, H, W): | |||
""" Forward function. | |||
Args: | |||
x: Input feature, tensor size (B, H*W, C). | |||
H, W: Spatial resolution of the input feature. | |||
""" | |||
# calculate attention mask for SW-MSA | |||
Hp = int(np.ceil(H / self.window_size)) * self.window_size | |||
Wp = int(np.ceil(W / self.window_size)) * self.window_size | |||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 | |||
h_slices = (slice(0, -self.window_size), | |||
slice(-self.window_size, | |||
-self.shift_size), slice(-self.shift_size, None)) | |||
w_slices = (slice(0, -self.window_size), | |||
slice(-self.window_size, | |||
-self.shift_size), slice(-self.shift_size, None)) | |||
cnt = 0 | |||
for h in h_slices: | |||
for w in w_slices: | |||
img_mask[:, h, w, :] = cnt | |||
cnt += 1 | |||
mask_windows = window_partition( | |||
img_mask, self.window_size) # nW, window_size, window_size, 1 | |||
mask_windows = mask_windows.view(-1, | |||
self.window_size * self.window_size) | |||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) | |||
attn_mask = attn_mask.masked_fill(attn_mask != 0, | |||
float(-100.0)).masked_fill( | |||
attn_mask == 0, float(0.0)) | |||
for blk in self.blocks: | |||
blk.H, blk.W = H, W | |||
if self.use_checkpoint: | |||
x = checkpoint.checkpoint(blk, x, v, attn_mask)  # v must be passed through when checkpointing | |||
else: | |||
x = blk(x, v, attn_mask) | |||
if self.downsample is not None: | |||
x_down = self.downsample(x, H, W) | |||
Wh, Ww = (H + 1) // 2, (W + 1) // 2 | |||
return x, H, W, x_down, Wh, Ww | |||
else: | |||
return x, H, W, x, H, W | |||
class NewCRF(nn.Module): | |||
def __init__(self, | |||
input_dim=96, | |||
embed_dim=96, | |||
v_dim=64, | |||
window_size=7, | |||
num_heads=4, | |||
depth=2, | |||
patch_size=4, | |||
in_chans=3, | |||
norm_layer=nn.LayerNorm, | |||
patch_norm=True): | |||
super().__init__() | |||
self.embed_dim = embed_dim | |||
self.patch_norm = patch_norm | |||
if input_dim != embed_dim: | |||
self.proj_x = nn.Conv2d(input_dim, embed_dim, 3, padding=1) | |||
else: | |||
self.proj_x = None | |||
if v_dim != embed_dim: | |||
self.proj_v = nn.Conv2d(v_dim, embed_dim, 3, padding=1) | |||
elif embed_dim % v_dim == 0: | |||
self.proj_v = None | |||
v_dim = embed_dim | |||
assert v_dim == embed_dim | |||
self.crf_layer = BasicCRFLayer( | |||
dim=embed_dim, | |||
depth=depth, | |||
num_heads=num_heads, | |||
v_dim=v_dim, | |||
window_size=window_size, | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
qk_scale=None, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
norm_layer=norm_layer, | |||
downsample=None, | |||
use_checkpoint=False) | |||
layer = norm_layer(embed_dim) | |||
layer_name = 'norm_crf' | |||
self.add_module(layer_name, layer) | |||
def forward(self, x, v): | |||
if self.proj_x is not None: | |||
x = self.proj_x(x) | |||
if self.proj_v is not None: | |||
v = self.proj_v(v) | |||
Wh, Ww = x.size(2), x.size(3) | |||
x = x.flatten(2).transpose(1, 2) | |||
v = v.transpose(1, 2).transpose(2, 3) | |||
x_out, H, W, x, Wh, Ww = self.crf_layer(x, v, Wh, Ww) | |||
norm_layer = getattr(self, 'norm_crf') | |||
x_out = norm_layer(x_out) | |||
out = x_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, | |||
2).contiguous() | |||
return out |
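Not part of the diff: a standalone shape sketch of the NewCRF block defined above. x supplies the queries/keys of the window attention and v supplies the values; v_dim is set equal to embed_dim here to keep the sketch simple, and the import path is hypothetical.

import torch

# Hypothetical import path for the file added above.
from modelscope.models.cv.image_depth_estimation.networks.newcrf_layers import NewCRF

crf = NewCRF(input_dim=192, embed_dim=128, v_dim=128, window_size=7, num_heads=4)
crf.eval()

x = torch.randn(1, 192, 30, 40)  # backbone feature map (queries/keys)
v = torch.randn(1, 128, 30, 40)  # decoder/PSP feature map (values)
out = crf(x, v)
print(out.shape)  # expected: torch.Size([1, 128, 30, 40]), i.e. embed_dim channels at x's resolution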
@@ -0,0 +1,272 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import pkgutil | |||
import warnings | |||
from collections import OrderedDict | |||
from importlib import import_module | |||
import torch | |||
import torch.nn as nn | |||
import torchvision | |||
from torch import distributed as dist | |||
from torch.nn import functional as F | |||
from torch.nn.parallel import DataParallel, DistributedDataParallel | |||
from torch.utils import model_zoo | |||
TORCH_VERSION = torch.__version__ | |||
def resize(input, | |||
size=None, | |||
scale_factor=None, | |||
mode='nearest', | |||
align_corners=None, | |||
warning=True): | |||
if warning: | |||
if size is not None and align_corners: | |||
input_h, input_w = tuple(int(x) for x in input.shape[2:]) | |||
output_h, output_w = tuple(int(x) for x in size) | |||
if output_h > input_h or output_w > input_w: | |||
if ((output_h > 1 and output_w > 1 and input_h > 1 | |||
and input_w > 1) and (output_h - 1) % (input_h - 1) | |||
and (output_w - 1) % (input_w - 1)): | |||
warnings.warn( | |||
f'When align_corners={align_corners}, ' | |||
'the output would be more aligned if ' | |||
f'input size {(input_h, input_w)} is `x+1` and ' | |||
f'out size {(output_h, output_w)} is `nx+1`') | |||
if isinstance(size, torch.Size): | |||
size = tuple(int(x) for x in size) | |||
return F.interpolate(input, size, scale_factor, mode, align_corners) | |||
def normal_init(module, mean=0, std=1, bias=0): | |||
if hasattr(module, 'weight') and module.weight is not None: | |||
nn.init.normal_(module.weight, mean, std) | |||
if hasattr(module, 'bias') and module.bias is not None: | |||
nn.init.constant_(module.bias, bias) | |||
def is_module_wrapper(module): | |||
module_wrappers = (DataParallel, DistributedDataParallel) | |||
return isinstance(module, module_wrappers) | |||
def get_dist_info(): | |||
if TORCH_VERSION < '1.0': | |||
initialized = dist._initialized | |||
else: | |||
if dist.is_available(): | |||
initialized = dist.is_initialized() | |||
else: | |||
initialized = False | |||
if initialized: | |||
rank = dist.get_rank() | |||
world_size = dist.get_world_size() | |||
else: | |||
rank = 0 | |||
world_size = 1 | |||
return rank, world_size | |||
def load_state_dict(module, state_dict, strict=False, logger=None): | |||
"""Load state_dict to a module. | |||
This method is modified from :meth:`torch.nn.Module.load_state_dict`. | |||
Default value for ``strict`` is set to ``False`` and the message for | |||
param mismatch will be shown even if strict is False. | |||
Args: | |||
module (Module): Module that receives the state_dict. | |||
state_dict (OrderedDict): Weights. | |||
strict (bool): whether to strictly enforce that the keys | |||
in :attr:`state_dict` match the keys returned by this module's | |||
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``. | |||
logger (:obj:`logging.Logger`, optional): Logger to log the error | |||
message. If not specified, print function will be used. | |||
""" | |||
unexpected_keys = [] | |||
all_missing_keys = [] | |||
err_msg = [] | |||
metadata = getattr(state_dict, '_metadata', None) | |||
state_dict = state_dict.copy() | |||
if metadata is not None: | |||
state_dict._metadata = metadata | |||
# use _load_from_state_dict to enable checkpoint version control | |||
def load(module, prefix=''): | |||
# recursively check parallel module in case that the model has a | |||
# complicated structure, e.g., nn.Module(nn.Module(DDP)) | |||
if is_module_wrapper(module): | |||
module = module.module | |||
local_metadata = {} if metadata is None else metadata.get( | |||
prefix[:-1], {}) | |||
module._load_from_state_dict(state_dict, prefix, local_metadata, True, | |||
all_missing_keys, unexpected_keys, | |||
err_msg) | |||
for name, child in module._modules.items(): | |||
if child is not None: | |||
load(child, prefix + name + '.') | |||
load(module) | |||
load = None # break load->load reference cycle | |||
# ignore "num_batches_tracked" of BN layers | |||
missing_keys = [ | |||
key for key in all_missing_keys if 'num_batches_tracked' not in key | |||
] | |||
if unexpected_keys: | |||
err_msg.append('unexpected key in source ' | |||
f'state_dict: {", ".join(unexpected_keys)}\n') | |||
if missing_keys: | |||
err_msg.append( | |||
f'missing keys in source state_dict: {", ".join(missing_keys)}\n') | |||
rank, _ = get_dist_info() | |||
if len(err_msg) > 0 and rank == 0: | |||
err_msg.insert( | |||
0, 'The model and loaded state dict do not match exactly\n') | |||
err_msg = '\n'.join(err_msg) | |||
if strict: | |||
raise RuntimeError(err_msg) | |||
elif logger is not None: | |||
logger.warning(err_msg) | |||
else: | |||
print(err_msg) | |||
def load_url_dist(url, model_dir=None): | |||
"""In distributed setting, this function only download checkpoint at local | |||
rank 0.""" | |||
rank, world_size = get_dist_info() | |||
rank = int(os.environ.get('LOCAL_RANK', rank)) | |||
if rank == 0: | |||
checkpoint = model_zoo.load_url(url, model_dir=model_dir) | |||
if world_size > 1: | |||
torch.distributed.barrier() | |||
if rank > 0: | |||
checkpoint = model_zoo.load_url(url, model_dir=model_dir) | |||
return checkpoint | |||
def get_torchvision_models(): | |||
model_urls = dict() | |||
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): | |||
if ispkg: | |||
continue | |||
_zoo = import_module(f'torchvision.models.{name}') | |||
if hasattr(_zoo, 'model_urls'): | |||
_urls = getattr(_zoo, 'model_urls') | |||
model_urls.update(_urls) | |||
return model_urls | |||
def _load_checkpoint(filename, map_location=None): | |||
"""Load checkpoint from somewhere (modelzoo, file, url). | |||
Args: | |||
filename (str): Accept local filepath, URL, ``torchvision://xxx``, | |||
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for | |||
details. | |||
map_location (str | None): Same as :func:`torch.load`. Default: None. | |||
Returns: | |||
dict | OrderedDict: The loaded checkpoint. It can be either an | |||
OrderedDict storing model weights or a dict containing other | |||
information, which depends on the checkpoint. | |||
""" | |||
if filename.startswith('modelzoo://'): | |||
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' | |||
'use "torchvision://" instead') | |||
model_urls = get_torchvision_models() | |||
model_name = filename[11:] | |||
checkpoint = load_url_dist(model_urls[model_name]) | |||
else: | |||
if not osp.isfile(filename): | |||
raise IOError(f'{filename} is not a checkpoint file') | |||
checkpoint = torch.load(filename, map_location=map_location) | |||
return checkpoint | |||
def load_checkpoint(model, | |||
filename, | |||
map_location='cpu', | |||
strict=False, | |||
logger=None): | |||
"""Load checkpoint from a file or URI. | |||
Args: | |||
model (Module): Module to load checkpoint. | |||
filename (str): Accept local filepath, URL, ``torchvision://xxx``, | |||
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for | |||
details. | |||
map_location (str): Same as :func:`torch.load`. | |||
strict (bool): Whether to allow different params for the model and | |||
checkpoint. | |||
logger (:mod:`logging.Logger` or None): The logger for error message. | |||
Returns: | |||
dict or OrderedDict: The loaded checkpoint. | |||
""" | |||
checkpoint = _load_checkpoint(filename, map_location) | |||
# OrderedDict is a subclass of dict | |||
if not isinstance(checkpoint, dict): | |||
raise RuntimeError( | |||
f'No state_dict found in checkpoint file {filename}') | |||
# get state_dict from checkpoint | |||
if 'state_dict' in checkpoint: | |||
state_dict = checkpoint['state_dict'] | |||
elif 'model' in checkpoint: | |||
state_dict = checkpoint['model'] | |||
else: | |||
state_dict = checkpoint | |||
# strip prefix of state_dict | |||
if list(state_dict.keys())[0].startswith('module.'): | |||
state_dict = {k[7:]: v for k, v in state_dict.items()} | |||
# for MoBY, load model of online branch | |||
if sorted(list(state_dict.keys()))[0].startswith('encoder'): | |||
state_dict = { | |||
k.replace('encoder.', ''): v | |||
for k, v in state_dict.items() if k.startswith('encoder.') | |||
} | |||
# reshape absolute position embedding | |||
if state_dict.get('absolute_pos_embed') is not None: | |||
absolute_pos_embed = state_dict['absolute_pos_embed'] | |||
N1, L, C1 = absolute_pos_embed.size() | |||
N2, C2, H, W = model.absolute_pos_embed.size() | |||
if N1 != N2 or C1 != C2 or L != H * W: | |||
warnings.warn('Error in loading absolute_pos_embed, pass')  # logger may be None here | |||
else: | |||
state_dict['absolute_pos_embed'] = absolute_pos_embed.view( | |||
N2, H, W, C2).permute(0, 3, 1, 2) | |||
# interpolate position bias table if needed | |||
relative_position_bias_table_keys = [ | |||
k for k in state_dict.keys() if 'relative_position_bias_table' in k | |||
] | |||
for table_key in relative_position_bias_table_keys: | |||
table_pretrained = state_dict[table_key] | |||
table_current = model.state_dict()[table_key] | |||
L1, nH1 = table_pretrained.size() | |||
L2, nH2 = table_current.size() | |||
if nH1 != nH2: | |||
warnings.warn(f'Error in loading {table_key}, pass')  # logger may be None here | |||
else: | |||
if L1 != L2: | |||
S1 = int(L1**0.5) | |||
S2 = int(L2**0.5) | |||
table_pretrained_resized = F.interpolate( | |||
table_pretrained.permute(1, 0).view(1, nH1, S1, S1), | |||
size=(S2, S2), | |||
mode='bicubic') | |||
state_dict[table_key] = table_pretrained_resized.view( | |||
nH2, L2).permute(1, 0) | |||
# load state_dict | |||
load_state_dict(model, state_dict, strict, logger) | |||
return checkpoint |
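Not part of the diff: a tiny round-trip showing what load_checkpoint() above does with a DataParallel-style checkpoint (the 'module.' prefix is stripped before a non-strict load). The toy model, file path, and import path are made up for the demo.

import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical import path for the file added above.
from modelscope.models.cv.image_depth_estimation.networks.newcrf_utils import load_checkpoint

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
ckpt_path = os.path.join(tempfile.gettempdir(), 'toy_ckpt.pth')
# Save with the 'module.' prefix that DataParallel/DistributedDataParallel adds.
torch.save({'state_dict': {'module.' + k: v for k, v in net.state_dict().items()}}, ckpt_path)

fresh = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
load_checkpoint(fresh, ckpt_path, map_location='cpu', strict=False)  # prefix stripped, keys matched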
@@ -0,0 +1,706 @@ | |||
# The implementation is adopted from Swin Transformer | |||
# made publicly available under the MIT License at https://github.com/microsoft/Swin-Transformer | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.utils.checkpoint as checkpoint | |||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |||
from .newcrf_utils import load_checkpoint | |||
class Mlp(nn.Module): | |||
""" Multilayer perceptron.""" | |||
def __init__(self, | |||
in_features, | |||
hidden_features=None, | |||
out_features=None, | |||
act_layer=nn.GELU, | |||
drop=0.): | |||
super().__init__() | |||
out_features = out_features or in_features | |||
hidden_features = hidden_features or in_features | |||
self.fc1 = nn.Linear(in_features, hidden_features) | |||
self.act = act_layer() | |||
self.fc2 = nn.Linear(hidden_features, out_features) | |||
self.drop = nn.Dropout(drop) | |||
def forward(self, x): | |||
x = self.fc1(x) | |||
x = self.act(x) | |||
x = self.drop(x) | |||
x = self.fc2(x) | |||
x = self.drop(x) | |||
return x | |||
def window_partition(x, window_size): | |||
""" | |||
Args: | |||
x: (B, H, W, C) | |||
window_size (int): window size | |||
Returns: | |||
windows: (num_windows*B, window_size, window_size, C) | |||
""" | |||
B, H, W, C = x.shape | |||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, | |||
C) | |||
windows = x.permute(0, 1, 3, 2, 4, | |||
5).contiguous().view(-1, window_size, window_size, C) | |||
return windows | |||
def window_reverse(windows, window_size, H, W): | |||
""" | |||
Args: | |||
windows: (num_windows*B, window_size, window_size, C) | |||
window_size (int): Window size | |||
H (int): Height of image | |||
W (int): Width of image | |||
Returns: | |||
x: (B, H, W, C) | |||
""" | |||
B = int(windows.shape[0] / (H * W / window_size / window_size)) | |||
x = windows.view(B, H // window_size, W // window_size, window_size, | |||
window_size, -1) | |||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) | |||
return x | |||
class WindowAttention(nn.Module): | |||
""" Window based multi-head self attention (W-MSA) module with relative position bias. | |||
It supports both shifted and non-shifted windows. | |||
Args: | |||
dim (int): Number of input channels. | |||
window_size (tuple[int]): The height and width of the window. | |||
num_heads (int): Number of attention heads. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |||
""" | |||
def __init__(self, | |||
dim, | |||
window_size, | |||
num_heads, | |||
qkv_bias=True, | |||
qk_scale=None, | |||
attn_drop=0., | |||
proj_drop=0.): | |||
super().__init__() | |||
self.dim = dim | |||
self.window_size = window_size # Wh, Ww | |||
self.num_heads = num_heads | |||
head_dim = dim // num_heads | |||
self.scale = qk_scale or head_dim**-0.5 | |||
# define a parameter table of relative position bias | |||
self.relative_position_bias_table = nn.Parameter( | |||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), | |||
num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||
# get pair-wise relative position index for each token inside the window | |||
coords_h = torch.arange(self.window_size[0]) | |||
coords_w = torch.arange(self.window_size[1]) | |||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
relative_coords = coords_flatten[:, :, | |||
None] - coords_flatten[:, | |||
None, :] # 2, Wh*Ww, Wh*Ww | |||
relative_coords = relative_coords.permute( | |||
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||
relative_coords[:, :, | |||
0] += self.window_size[0] - 1 # shift to start from 0 | |||
relative_coords[:, :, 1] += self.window_size[1] - 1 | |||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 | |||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |||
self.register_buffer('relative_position_index', | |||
relative_position_index) | |||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |||
self.attn_drop = nn.Dropout(attn_drop) | |||
self.proj = nn.Linear(dim, dim) | |||
self.proj_drop = nn.Dropout(proj_drop) | |||
trunc_normal_(self.relative_position_bias_table, std=.02) | |||
self.softmax = nn.Softmax(dim=-1) | |||
def forward(self, x, mask=None): | |||
""" Forward function. | |||
Args: | |||
x: input features with shape of (num_windows*B, N, C) | |||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None | |||
""" | |||
B_, N, C = x.shape | |||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, | |||
C // self.num_heads).permute(2, 0, 3, 1, 4) | |||
q, k, v = qkv[0], qkv[1], qkv[ | |||
2] # make torchscript happy (cannot use tensor as tuple) | |||
q = q * self.scale | |||
attn = (q @ k.transpose(-2, -1)) | |||
relative_position_bias = self.relative_position_bias_table[ | |||
self.relative_position_index.view(-1)].view( | |||
self.window_size[0] * self.window_size[1], | |||
self.window_size[0] * self.window_size[1], | |||
-1) # Wh*Ww,Wh*Ww,nH | |||
relative_position_bias = relative_position_bias.permute( | |||
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||
attn = attn + relative_position_bias.unsqueeze(0) | |||
if mask is not None: | |||
nW = mask.shape[0] | |||
attn = attn.view(B_ // nW, nW, self.num_heads, N, | |||
N) + mask.unsqueeze(1).unsqueeze(0) | |||
attn = attn.view(-1, self.num_heads, N, N) | |||
attn = self.softmax(attn) | |||
else: | |||
attn = self.softmax(attn) | |||
attn = self.attn_drop(attn) | |||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) | |||
x = self.proj(x) | |||
x = self.proj_drop(x) | |||
return x | |||
class SwinTransformerBlock(nn.Module): | |||
""" Swin Transformer Block. | |||
Args: | |||
dim (int): Number of input channels. | |||
num_heads (int): Number of attention heads. | |||
window_size (int): Window size. | |||
shift_size (int): Shift size for SW-MSA. | |||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||
drop (float, optional): Dropout rate. Default: 0.0 | |||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||
drop_path (float, optional): Stochastic depth rate. Default: 0.0 | |||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU | |||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
""" | |||
def __init__(self, | |||
dim, | |||
num_heads, | |||
window_size=7, | |||
shift_size=0, | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
qk_scale=None, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
act_layer=nn.GELU, | |||
norm_layer=nn.LayerNorm): | |||
super().__init__() | |||
self.dim = dim | |||
self.num_heads = num_heads | |||
self.window_size = window_size | |||
self.shift_size = shift_size | |||
self.mlp_ratio = mlp_ratio | |||
assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)' | |||
self.norm1 = norm_layer(dim) | |||
self.attn = WindowAttention( | |||
dim, | |||
window_size=to_2tuple(self.window_size), | |||
num_heads=num_heads, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
attn_drop=attn_drop, | |||
proj_drop=drop) | |||
self.drop_path = DropPath( | |||
drop_path) if drop_path > 0. else nn.Identity() | |||
self.norm2 = norm_layer(dim) | |||
mlp_hidden_dim = int(dim * mlp_ratio) | |||
self.mlp = Mlp( | |||
in_features=dim, | |||
hidden_features=mlp_hidden_dim, | |||
act_layer=act_layer, | |||
drop=drop) | |||
self.H = None | |||
self.W = None | |||
def forward(self, x, mask_matrix): | |||
""" Forward function. | |||
Args: | |||
x: Input feature, tensor size (B, H*W, C). | |||
H, W: Spatial resolution of the input feature. | |||
mask_matrix: Attention mask for cyclic shift. | |||
""" | |||
B, L, C = x.shape | |||
H, W = self.H, self.W | |||
assert L == H * W, 'input feature has wrong size' | |||
shortcut = x | |||
x = self.norm1(x) | |||
x = x.view(B, H, W, C) | |||
# pad feature maps to multiples of window size | |||
pad_l = pad_t = 0 | |||
pad_r = (self.window_size - W % self.window_size) % self.window_size | |||
pad_b = (self.window_size - H % self.window_size) % self.window_size | |||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) | |||
_, Hp, Wp, _ = x.shape | |||
# cyclic shift | |||
if self.shift_size > 0: | |||
shifted_x = torch.roll( | |||
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) | |||
attn_mask = mask_matrix | |||
else: | |||
shifted_x = x | |||
attn_mask = None | |||
# partition windows | |||
x_windows = window_partition( | |||
shifted_x, self.window_size) # nW*B, window_size, window_size, C | |||
x_windows = x_windows.view(-1, self.window_size * self.window_size, | |||
C) # nW*B, window_size*window_size, C | |||
# W-MSA/SW-MSA | |||
attn_windows = self.attn( | |||
x_windows, mask=attn_mask) # nW*B, window_size*window_size, C | |||
# merge windows | |||
attn_windows = attn_windows.view(-1, self.window_size, | |||
self.window_size, C) | |||
shifted_x = window_reverse(attn_windows, self.window_size, Hp, | |||
Wp) # B H' W' C | |||
# reverse cyclic shift | |||
if self.shift_size > 0: | |||
x = torch.roll( | |||
shifted_x, | |||
shifts=(self.shift_size, self.shift_size), | |||
dims=(1, 2)) | |||
else: | |||
x = shifted_x | |||
if pad_r > 0 or pad_b > 0: | |||
x = x[:, :H, :W, :].contiguous() | |||
x = x.view(B, H * W, C) | |||
# FFN | |||
x = shortcut + self.drop_path(x) | |||
x = x + self.drop_path(self.mlp(self.norm2(x))) | |||
return x | |||
class PatchMerging(nn.Module): | |||
""" Patch Merging Layer | |||
Args: | |||
dim (int): Number of input channels. | |||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
""" | |||
def __init__(self, dim, norm_layer=nn.LayerNorm): | |||
super().__init__() | |||
self.dim = dim | |||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) | |||
self.norm = norm_layer(4 * dim) | |||
def forward(self, x, H, W): | |||
""" Forward function. | |||
Args: | |||
x: Input feature, tensor size (B, H*W, C). | |||
H, W: Spatial resolution of the input feature. | |||
""" | |||
B, L, C = x.shape | |||
assert L == H * W, 'input feature has wrong size' | |||
x = x.view(B, H, W, C) | |||
# padding | |||
pad_input = (H % 2 == 1) or (W % 2 == 1) | |||
if pad_input: | |||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) | |||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C | |||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C | |||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C | |||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C | |||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C | |||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C | |||
x = self.norm(x) | |||
x = self.reduction(x) | |||
return x | |||
class BasicLayer(nn.Module): | |||
""" A basic Swin Transformer layer for one stage. | |||
Args: | |||
dim (int): Number of feature channels | |||
depth (int): Depths of this stage. | |||
num_heads (int): Number of attention heads. | |||
window_size (int): Local window size. Default: 7. | |||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. | |||
drop (float, optional): Dropout rate. Default: 0.0 | |||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 | |||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 | |||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm | |||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None | |||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||
""" | |||
def __init__(self, | |||
dim, | |||
depth, | |||
num_heads, | |||
window_size=7, | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
qk_scale=None, | |||
drop=0., | |||
attn_drop=0., | |||
drop_path=0., | |||
norm_layer=nn.LayerNorm, | |||
downsample=None, | |||
use_checkpoint=False): | |||
super().__init__() | |||
self.window_size = window_size | |||
self.shift_size = window_size // 2 | |||
self.depth = depth | |||
self.use_checkpoint = use_checkpoint | |||
# build blocks | |||
self.blocks = nn.ModuleList([ | |||
SwinTransformerBlock( | |||
dim=dim, | |||
num_heads=num_heads, | |||
window_size=window_size, | |||
shift_size=0 if (i % 2 == 0) else window_size // 2, | |||
mlp_ratio=mlp_ratio, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
drop=drop, | |||
attn_drop=attn_drop, | |||
drop_path=drop_path[i] | |||
if isinstance(drop_path, list) else drop_path, | |||
norm_layer=norm_layer) for i in range(depth) | |||
]) | |||
# patch merging layer | |||
if downsample is not None: | |||
self.downsample = downsample(dim=dim, norm_layer=norm_layer) | |||
else: | |||
self.downsample = None | |||
def forward(self, x, H, W): | |||
""" Forward function. | |||
Args: | |||
x: Input feature, tensor size (B, H*W, C). | |||
H, W: Spatial resolution of the input feature. | |||
""" | |||
# calculate attention mask for SW-MSA | |||
Hp = int(np.ceil(H / self.window_size)) * self.window_size | |||
Wp = int(np.ceil(W / self.window_size)) * self.window_size | |||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 | |||
h_slices = (slice(0, -self.window_size), | |||
slice(-self.window_size, | |||
-self.shift_size), slice(-self.shift_size, None)) | |||
w_slices = (slice(0, -self.window_size), | |||
slice(-self.window_size, | |||
-self.shift_size), slice(-self.shift_size, None)) | |||
cnt = 0 | |||
for h in h_slices: | |||
for w in w_slices: | |||
img_mask[:, h, w, :] = cnt | |||
cnt += 1 | |||
mask_windows = window_partition( | |||
img_mask, self.window_size) # nW, window_size, window_size, 1 | |||
mask_windows = mask_windows.view(-1, | |||
self.window_size * self.window_size) | |||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) | |||
attn_mask = attn_mask.masked_fill(attn_mask != 0, | |||
float(-100.0)).masked_fill( | |||
attn_mask == 0, float(0.0)) | |||
for blk in self.blocks: | |||
blk.H, blk.W = H, W | |||
if self.use_checkpoint: | |||
x = checkpoint.checkpoint(blk, x, attn_mask) | |||
else: | |||
x = blk(x, attn_mask) | |||
if self.downsample is not None: | |||
x_down = self.downsample(x, H, W) | |||
Wh, Ww = (H + 1) // 2, (W + 1) // 2 | |||
return x, H, W, x_down, Wh, Ww | |||
else: | |||
return x, H, W, x, H, W | |||
class PatchEmbed(nn.Module): | |||
""" Image to Patch Embedding | |||
Args: | |||
patch_size (int): Patch token size. Default: 4. | |||
in_chans (int): Number of input image channels. Default: 3. | |||
embed_dim (int): Number of linear projection output channels. Default: 96. | |||
norm_layer (nn.Module, optional): Normalization layer. Default: None | |||
""" | |||
def __init__(self, | |||
patch_size=4, | |||
in_chans=3, | |||
embed_dim=96, | |||
norm_layer=None): | |||
super().__init__() | |||
patch_size = to_2tuple(patch_size) | |||
self.patch_size = patch_size | |||
self.in_chans = in_chans | |||
self.embed_dim = embed_dim | |||
self.proj = nn.Conv2d( | |||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||
if norm_layer is not None: | |||
self.norm = norm_layer(embed_dim) | |||
else: | |||
self.norm = None | |||
def forward(self, x): | |||
"""Forward function.""" | |||
# padding | |||
_, _, H, W = x.size() | |||
if W % self.patch_size[1] != 0: | |||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) | |||
if H % self.patch_size[0] != 0: | |||
x = F.pad(x, | |||
(0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) | |||
x = self.proj(x) # B C Wh Ww | |||
if self.norm is not None: | |||
Wh, Ww = x.size(2), x.size(3) | |||
x = x.flatten(2).transpose(1, 2) | |||
x = self.norm(x) | |||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) | |||
return x | |||
class SwinTransformer(nn.Module): | |||
""" Swin Transformer backbone. | |||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - | |||
https://arxiv.org/pdf/2103.14030 | |||
Args: | |||
pretrain_img_size (int): Input image size for training the pretrained model, | |||
used in absolute position embedding. Default: 224. | |||
patch_size (int | tuple(int)): Patch size. Default: 4. | |||
in_chans (int): Number of input image channels. Default: 3. | |||
embed_dim (int): Number of linear projection output channels. Default: 96. | |||
depths (tuple[int]): Depths of each Swin Transformer stage. | |||
num_heads (tuple[int]): Number of attention head of each stage. | |||
window_size (int): Window size. Default: 7. | |||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. | |||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True | |||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. | |||
drop_rate (float): Dropout rate. | |||
attn_drop_rate (float): Attention dropout rate. Default: 0. | |||
drop_path_rate (float): Stochastic depth rate. Default: 0.2. | |||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. | |||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. | |||
patch_norm (bool): If True, add normalization after patch embedding. Default: True. | |||
out_indices (Sequence[int]): Output from which stages. | |||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode). | |||
-1 means not freezing any parameters. | |||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. | |||
""" | |||
def __init__(self, | |||
pretrain_img_size=224, | |||
patch_size=4, | |||
in_chans=3, | |||
embed_dim=96, | |||
depths=[2, 2, 6, 2], | |||
num_heads=[3, 6, 12, 24], | |||
window_size=7, | |||
mlp_ratio=4., | |||
qkv_bias=True, | |||
qk_scale=None, | |||
drop_rate=0., | |||
attn_drop_rate=0., | |||
drop_path_rate=0.2, | |||
norm_layer=nn.LayerNorm, | |||
ape=False, | |||
patch_norm=True, | |||
out_indices=(0, 1, 2, 3), | |||
frozen_stages=-1, | |||
use_checkpoint=False): | |||
super().__init__() | |||
self.pretrain_img_size = pretrain_img_size | |||
self.num_layers = len(depths) | |||
self.embed_dim = embed_dim | |||
self.ape = ape | |||
self.patch_norm = patch_norm | |||
self.out_indices = out_indices | |||
self.frozen_stages = frozen_stages | |||
# split image into non-overlapping patches | |||
self.patch_embed = PatchEmbed( | |||
patch_size=patch_size, | |||
in_chans=in_chans, | |||
embed_dim=embed_dim, | |||
norm_layer=norm_layer if self.patch_norm else None) | |||
# absolute position embedding | |||
if self.ape: | |||
pretrain_img_size = to_2tuple(pretrain_img_size) | |||
patch_size = to_2tuple(patch_size) | |||
patches_resolution = [ | |||
pretrain_img_size[0] // patch_size[0], | |||
pretrain_img_size[1] // patch_size[1] | |||
] | |||
self.absolute_pos_embed = nn.Parameter( | |||
torch.zeros(1, embed_dim, patches_resolution[0], | |||
patches_resolution[1])) | |||
trunc_normal_(self.absolute_pos_embed, std=.02) | |||
self.pos_drop = nn.Dropout(p=drop_rate) | |||
# stochastic depth | |||
dpr = [ | |||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) | |||
] # stochastic depth decay rule | |||
# build layers | |||
self.layers = nn.ModuleList() | |||
for i_layer in range(self.num_layers): | |||
layer = BasicLayer( | |||
dim=int(embed_dim * 2**i_layer), | |||
depth=depths[i_layer], | |||
num_heads=num_heads[i_layer], | |||
window_size=window_size, | |||
mlp_ratio=mlp_ratio, | |||
qkv_bias=qkv_bias, | |||
qk_scale=qk_scale, | |||
drop=drop_rate, | |||
attn_drop=attn_drop_rate, | |||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], | |||
norm_layer=norm_layer, | |||
downsample=PatchMerging if | |||
(i_layer < self.num_layers - 1) else None, | |||
use_checkpoint=use_checkpoint) | |||
self.layers.append(layer) | |||
num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] | |||
self.num_features = num_features | |||
# add a norm layer for each output | |||
for i_layer in out_indices: | |||
layer = norm_layer(num_features[i_layer]) | |||
layer_name = f'norm{i_layer}' | |||
self.add_module(layer_name, layer) | |||
self._freeze_stages() | |||
def _freeze_stages(self): | |||
if self.frozen_stages >= 0: | |||
self.patch_embed.eval() | |||
for param in self.patch_embed.parameters(): | |||
param.requires_grad = False | |||
if self.frozen_stages >= 1 and self.ape: | |||
self.absolute_pos_embed.requires_grad = False | |||
if self.frozen_stages >= 2: | |||
self.pos_drop.eval() | |||
for i in range(0, self.frozen_stages - 1): | |||
m = self.layers[i] | |||
m.eval() | |||
for param in m.parameters(): | |||
param.requires_grad = False | |||
def init_weights(self, pretrained=None): | |||
"""Initialize the weights in backbone. | |||
Args: | |||
pretrained (str, optional): Path to pre-trained weights. | |||
Defaults to None. | |||
""" | |||
def _init_weights(m): | |||
if isinstance(m, nn.Linear): | |||
trunc_normal_(m.weight, std=.02) | |||
if isinstance(m, nn.Linear) and m.bias is not None: | |||
nn.init.constant_(m.bias, 0) | |||
elif isinstance(m, nn.LayerNorm): | |||
nn.init.constant_(m.bias, 0) | |||
nn.init.constant_(m.weight, 1.0) | |||
if isinstance(pretrained, str): | |||
self.apply(_init_weights) | |||
# logger = get_root_logger() | |||
load_checkpoint(self, pretrained, strict=False) | |||
elif pretrained is None: | |||
self.apply(_init_weights) | |||
else: | |||
raise TypeError('pretrained must be a str or None') | |||
def forward(self, x): | |||
"""Forward function.""" | |||
x = self.patch_embed(x) | |||
Wh, Ww = x.size(2), x.size(3) | |||
if self.ape: | |||
# interpolate the position embedding to the corresponding size | |||
absolute_pos_embed = F.interpolate( | |||
self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') | |||
x = (x + absolute_pos_embed).flatten(2).transpose(1, | |||
2) # B Wh*Ww C | |||
else: | |||
x = x.flatten(2).transpose(1, 2) | |||
x = self.pos_drop(x) | |||
outs = [] | |||
for i in range(self.num_layers): | |||
layer = self.layers[i] | |||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) | |||
if i in self.out_indices: | |||
norm_layer = getattr(self, f'norm{i}') | |||
x_out = norm_layer(x_out) | |||
out = x_out.view(-1, H, W, | |||
self.num_features[i]).permute(0, 3, 1, | |||
2).contiguous() | |||
outs.append(out) | |||
return tuple(outs) | |||
def train(self, mode=True): | |||
"""Convert the model into training mode while keep layers freezed.""" | |||
super(SwinTransformer, self).train(mode) | |||
self._freeze_stages() |
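Not part of the diff: a shape sketch of the backbone above, assuming SwinTransformer from this file is importable (path hypothetical). With Swin-Tiny defaults it returns a four-level pyramid at strides 4/8/16/32 with 96/192/384/768 channels, which is what NewCRFDepth consumes.

import torch

# Hypothetical import path for the file added above.
from modelscope.models.cv.image_depth_estimation.networks.swin_transformer import SwinTransformer

backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
backbone.init_weights(pretrained=None)
backbone.eval()

feats = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# expected: [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]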
@@ -0,0 +1,365 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from mmcv.cnn import ConvModule | |||
from .newcrf_utils import normal_init, resize | |||
class PPM(nn.ModuleList): | |||
"""Pooling Pyramid Module used in PSPNet. | |||
Args: | |||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid | |||
Module. | |||
in_channels (int): Input channels. | |||
channels (int): Channels after modules, before conv_seg. | |||
conv_cfg (dict|None): Config of conv layers. | |||
norm_cfg (dict|None): Config of norm layers. | |||
act_cfg (dict): Config of activation layers. | |||
align_corners (bool): align_corners argument of F.interpolate. | |||
""" | |||
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg, | |||
act_cfg, align_corners): | |||
super(PPM, self).__init__() | |||
self.pool_scales = pool_scales | |||
self.align_corners = align_corners | |||
self.in_channels = in_channels | |||
self.channels = channels | |||
self.conv_cfg = conv_cfg | |||
self.norm_cfg = norm_cfg | |||
self.act_cfg = act_cfg | |||
for pool_scale in pool_scales: | |||
            # with pool_scale == 1 the pooled map is 1x1, so BatchNorm cannot | |||
            # compute meaningful statistics (especially at batch size 1); use GroupNorm here | |||
if pool_scale == 1: | |||
norm_cfg = dict(type='GN', requires_grad=True, num_groups=256) | |||
self.append( | |||
nn.Sequential( | |||
nn.AdaptiveAvgPool2d(pool_scale), | |||
ConvModule( | |||
self.in_channels, | |||
self.channels, | |||
1, | |||
conv_cfg=self.conv_cfg, | |||
norm_cfg=norm_cfg, | |||
act_cfg=self.act_cfg))) | |||
def forward(self, x): | |||
"""Forward function.""" | |||
ppm_outs = [] | |||
for ppm in self: | |||
ppm_out = ppm(x) | |||
upsampled_ppm_out = resize( | |||
ppm_out, | |||
size=x.size()[2:], | |||
mode='bilinear', | |||
align_corners=self.align_corners) | |||
ppm_outs.append(upsampled_ppm_out) | |||
return ppm_outs | |||
class BaseDecodeHead(nn.Module): | |||
"""Base class for BaseDecodeHead. | |||
Args: | |||
in_channels (int|Sequence[int]): Input channels. | |||
channels (int): Channels after modules, before conv_seg. | |||
num_classes (int): Number of classes. | |||
dropout_ratio (float): Ratio of dropout layer. Default: 0.1. | |||
conv_cfg (dict|None): Config of conv layers. Default: None. | |||
norm_cfg (dict|None): Config of norm layers. Default: None. | |||
act_cfg (dict): Config of activation layers. | |||
Default: dict(type='ReLU') | |||
in_index (int|Sequence[int]): Input feature index. Default: -1 | |||
input_transform (str|None): Transformation type of input features. | |||
Options: 'resize_concat', 'multiple_select', None. | |||
            'resize_concat': Multiple feature maps will be resized to the | |||
                same size as the first one and then concatenated together. | |||
                Usually used in the FCN head of HRNet. | |||
            'multiple_select': Multiple feature maps will be bundled into | |||
                a list and passed into the decode head. | |||
            None: Only one selected feature map is allowed. | |||
Default: None. | |||
loss_decode (dict): Config of decode loss. | |||
Default: dict(type='CrossEntropyLoss'). | |||
ignore_index (int | None): The label index to be ignored. When using | |||
masked BCE loss, ignore_index should be set to None. Default: 255 | |||
sampler (dict|None): The config of segmentation map sampler. | |||
Default: None. | |||
align_corners (bool): align_corners argument of F.interpolate. | |||
Default: False. | |||
""" | |||
def __init__(self, | |||
in_channels, | |||
channels, | |||
*, | |||
num_classes, | |||
dropout_ratio=0.1, | |||
conv_cfg=None, | |||
norm_cfg=None, | |||
act_cfg=dict(type='ReLU'), | |||
in_index=-1, | |||
input_transform=None, | |||
loss_decode=dict( | |||
type='CrossEntropyLoss', | |||
use_sigmoid=False, | |||
loss_weight=1.0), | |||
ignore_index=255, | |||
sampler=None, | |||
align_corners=False): | |||
super(BaseDecodeHead, self).__init__() | |||
self._init_inputs(in_channels, in_index, input_transform) | |||
self.channels = channels | |||
self.num_classes = num_classes | |||
self.dropout_ratio = dropout_ratio | |||
self.conv_cfg = conv_cfg | |||
self.norm_cfg = norm_cfg | |||
self.act_cfg = act_cfg | |||
self.in_index = in_index | |||
# self.loss_decode = build_loss(loss_decode) | |||
self.ignore_index = ignore_index | |||
self.align_corners = align_corners | |||
# if sampler is not None: | |||
# self.sampler = build_pixel_sampler(sampler, context=self) | |||
# else: | |||
# self.sampler = None | |||
# self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) | |||
# self.conv1 = nn.Conv2d(channels, num_classes, 3, padding=1) | |||
if dropout_ratio > 0: | |||
self.dropout = nn.Dropout2d(dropout_ratio) | |||
else: | |||
self.dropout = None | |||
self.fp16_enabled = False | |||
def extra_repr(self): | |||
"""Extra repr.""" | |||
s = f'input_transform={self.input_transform}, ' \ | |||
f'ignore_index={self.ignore_index}, ' \ | |||
f'align_corners={self.align_corners}' | |||
return s | |||
def _init_inputs(self, in_channels, in_index, input_transform): | |||
"""Check and initialize input transforms. | |||
The in_channels, in_index and input_transform must match. | |||
        Specifically, when input_transform is None, only a single feature map | |||
        will be selected, so in_channels and in_index must be of type int. | |||
        When input_transform is not None, in_channels and in_index must be a | |||
        list or tuple of the same length. | |||
Args: | |||
in_channels (int|Sequence[int]): Input channels. | |||
in_index (int|Sequence[int]): Input feature index. | |||
input_transform (str|None): Transformation type of input features. | |||
Options: 'resize_concat', 'multiple_select', None. | |||
                'resize_concat': Multiple feature maps will be resized to the | |||
                    same size as the first one and then concatenated together. | |||
                    Usually used in the FCN head of HRNet. | |||
                'multiple_select': Multiple feature maps will be bundled into | |||
                    a list and passed into the decode head. | |||
                None: Only one selected feature map is allowed. | |||
""" | |||
if input_transform is not None: | |||
assert input_transform in ['resize_concat', 'multiple_select'] | |||
self.input_transform = input_transform | |||
self.in_index = in_index | |||
if input_transform is not None: | |||
assert isinstance(in_channels, (list, tuple)) | |||
assert isinstance(in_index, (list, tuple)) | |||
assert len(in_channels) == len(in_index) | |||
if input_transform == 'resize_concat': | |||
self.in_channels = sum(in_channels) | |||
else: | |||
self.in_channels = in_channels | |||
else: | |||
assert isinstance(in_channels, int) | |||
assert isinstance(in_index, int) | |||
self.in_channels = in_channels | |||
def init_weights(self): | |||
"""Initialize weights of classification layer.""" | |||
# normal_init(self.conv_seg, mean=0, std=0.01) | |||
# normal_init(self.conv1, mean=0, std=0.01) | |||
def _transform_inputs(self, inputs): | |||
"""Transform inputs for decoder. | |||
Args: | |||
inputs (list[Tensor]): List of multi-level img features. | |||
Returns: | |||
Tensor: The transformed inputs | |||
""" | |||
if self.input_transform == 'resize_concat': | |||
inputs = [inputs[i] for i in self.in_index] | |||
upsampled_inputs = [ | |||
resize( | |||
input=x, | |||
size=inputs[0].shape[2:], | |||
mode='bilinear', | |||
align_corners=self.align_corners) for x in inputs | |||
] | |||
inputs = torch.cat(upsampled_inputs, dim=1) | |||
elif self.input_transform == 'multiple_select': | |||
inputs = [inputs[i] for i in self.in_index] | |||
else: | |||
inputs = inputs[self.in_index] | |||
return inputs | |||
def forward(self, inputs): | |||
"""Placeholder of forward function.""" | |||
pass | |||
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): | |||
"""Forward function for training. | |||
Args: | |||
inputs (list[Tensor]): List of multi-level img features. | |||
img_metas (list[dict]): List of image info dict where each dict | |||
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
For details on the values of these keys see | |||
`mmseg/datasets/pipelines/formatting.py:Collect`. | |||
gt_semantic_seg (Tensor): Semantic segmentation masks | |||
used if the architecture supports semantic segmentation task. | |||
train_cfg (dict): The training config. | |||
Returns: | |||
dict[str, Tensor]: a dictionary of loss components | |||
""" | |||
        # Note: training is not exercised in this inference port; ``self.losses`` | |||
        # depends on the loss_decode construction that is commented out above. | |||
        seg_logits = self.forward(inputs) | |||
        losses = self.losses(seg_logits, gt_semantic_seg) | |||
        return losses | |||
def forward_test(self, inputs, img_metas, test_cfg): | |||
"""Forward function for testing. | |||
Args: | |||
inputs (list[Tensor]): List of multi-level img features. | |||
img_metas (list[dict]): List of image info dict where each dict | |||
has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
For details on the values of these keys see | |||
`mmseg/datasets/pipelines/formatting.py:Collect`. | |||
test_cfg (dict): The testing config. | |||
Returns: | |||
Tensor: Output segmentation map. | |||
""" | |||
return self.forward(inputs) | |||
class UPerHead(BaseDecodeHead): | |||
    """FPN branch of UPerNet (https://arxiv.org/abs/1807.10221). | |||
    Only the finest-resolution FPN output is returned in this port. | |||
    """ | |||
    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): | |||
super(UPerHead, self).__init__( | |||
input_transform='multiple_select', **kwargs) | |||
# FPN Module | |||
self.lateral_convs = nn.ModuleList() | |||
self.fpn_convs = nn.ModuleList() | |||
        for in_channels in self.in_channels:  # one lateral conv and one FPN conv per input level | |||
l_conv = ConvModule( | |||
in_channels, | |||
self.channels, | |||
1, | |||
conv_cfg=self.conv_cfg, | |||
norm_cfg=self.norm_cfg, | |||
act_cfg=self.act_cfg, | |||
inplace=True) | |||
fpn_conv = ConvModule( | |||
self.channels, | |||
self.channels, | |||
3, | |||
padding=1, | |||
conv_cfg=self.conv_cfg, | |||
norm_cfg=self.norm_cfg, | |||
act_cfg=self.act_cfg, | |||
inplace=True) | |||
self.lateral_convs.append(l_conv) | |||
self.fpn_convs.append(fpn_conv) | |||
def forward(self, inputs): | |||
"""Forward function.""" | |||
inputs = self._transform_inputs(inputs) | |||
# build laterals | |||
laterals = [ | |||
lateral_conv(inputs[i]) | |||
for i, lateral_conv in enumerate(self.lateral_convs) | |||
] | |||
# laterals.append(self.psp_forward(inputs)) | |||
# build top-down path | |||
used_backbone_levels = len(laterals) | |||
for i in range(used_backbone_levels - 1, 0, -1): | |||
prev_shape = laterals[i - 1].shape[2:] | |||
laterals[i - 1] += resize( | |||
laterals[i], | |||
size=prev_shape, | |||
mode='bilinear', | |||
align_corners=self.align_corners) | |||
# build outputs | |||
fpn_outs = [ | |||
self.fpn_convs[i](laterals[i]) | |||
for i in range(used_backbone_levels - 1) | |||
] | |||
        # append the top-level lateral (the PSP branch is computed by the separate PSP head) | |||
fpn_outs.append(laterals[-1]) | |||
return fpn_outs[0] | |||
class PSP(BaseDecodeHead): | |||
"""Unified Perceptual Parsing for Scene Understanding. | |||
This head is the implementation of `UPerNet | |||
<https://arxiv.org/abs/1807.10221>`_. | |||
Args: | |||
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid | |||
Module applied on the last feature. Default: (1, 2, 3, 6). | |||
""" | |||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): | |||
super(PSP, self).__init__(input_transform='multiple_select', **kwargs) | |||
# PSP Module | |||
self.psp_modules = PPM( | |||
pool_scales, | |||
self.in_channels[-1], | |||
self.channels, | |||
conv_cfg=self.conv_cfg, | |||
norm_cfg=self.norm_cfg, | |||
act_cfg=self.act_cfg, | |||
align_corners=self.align_corners) | |||
self.bottleneck = ConvModule( | |||
self.in_channels[-1] + len(pool_scales) * self.channels, | |||
self.channels, | |||
3, | |||
padding=1, | |||
conv_cfg=self.conv_cfg, | |||
norm_cfg=self.norm_cfg, | |||
act_cfg=self.act_cfg) | |||
def psp_forward(self, inputs): | |||
"""Forward function of PSP module.""" | |||
x = inputs[-1] | |||
psp_outs = [x] | |||
psp_outs.extend(self.psp_modules(x)) | |||
psp_outs = torch.cat(psp_outs, dim=1) | |||
output = self.bottleneck(psp_outs) | |||
return output | |||
def forward(self, inputs): | |||
"""Forward function.""" | |||
inputs = self._transform_inputs(inputs) | |||
return self.psp_forward(inputs) |
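Reviewer sketch (not part of the patch): a quick shape check for the PSP head. The channel, class, and norm settings below are illustrative placeholders, not values required by this file; the import path is inferred from the package layout of this PR.

import torch
from modelscope.models.cv.image_depth_estimation.networks.uper_crf_head import PSP

head = PSP(
    in_channels=[96, 192, 384, 768],
    in_index=[0, 1, 2, 3],
    pool_scales=(1, 2, 3, 6),
    channels=512,
    num_classes=32,
    norm_cfg=dict(type='BN', requires_grad=True),
    align_corners=False)
head.eval()
feats = [
    torch.randn(1, 96, 120, 160),
    torch.randn(1, 192, 60, 80),
    torch.randn(1, 384, 30, 40),
    torch.randn(1, 768, 15, 20),
]
with torch.no_grad():
    out = head(feats)
print(out.shape)  # only the coarsest level feeds the pyramid: (1, 512, 15, 20)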
@@ -0,0 +1,53 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import numpy as np | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base.base_torch_model import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.cv.image_depth_estimation.networks.newcrf_depth import \ | |||
NewCRFDepth | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
@MODELS.register_module( | |||
Tasks.image_depth_estimation, module_name=Models.newcrfs_depth_estimation) | |||
class DepthEstimation(TorchModel): | |||
def __init__(self, model_dir: str, **kwargs): | |||
"""str -- model file root.""" | |||
super().__init__(model_dir, **kwargs) | |||
# build model | |||
self.model = NewCRFDepth( | |||
version='large07', inv_depth=False, max_depth=10) | |||
# load model | |||
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
        # load on CPU first so CPU-only hosts can deserialize GPU-saved checkpoints | |||
        checkpoint = torch.load(model_path, map_location='cpu') | |||
        # strip the 'module.' prefix left by (Distributed)DataParallel training | |||
        state_dict = {} | |||
for k in checkpoint['model'].keys(): | |||
if k.startswith('module.'): | |||
state_dict[k[7:]] = checkpoint['model'][k] | |||
else: | |||
state_dict[k] = checkpoint['model'][k] | |||
self.model.load_state_dict(state_dict) | |||
self.model.eval() | |||
    def forward(self, inputs): | |||
        return self.model(inputs['imgs']) | |||
    def postprocess(self, inputs): | |||
        results = {OutputKeys.DEPTHS: inputs} | |||
        return results | |||
def inference(self, data): | |||
results = self.forward(data) | |||
return results |
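Reviewer sketch (not part of the patch): a minimal way to exercise the registered model outside the pipeline, assuming the hub id above resolves via `snapshot_download` and the checkpoint deserializes on the current host.

import torch
from modelscope.hub.snapshot_download import snapshot_download

model_dir = snapshot_download('damo/cv_newcrfs_image-depth-estimation_indoor')
model = DepthEstimation(model_dir)  # class defined above
with torch.no_grad():
    pred = model.inference({'imgs': torch.randn(1, 3, 480, 640)})
depths = model.postprocess(pred)[OutputKeys.DEPTHS]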
@@ -19,6 +19,7 @@ class OutputKeys(object): | |||
BOXES = 'boxes' | |||
KEYPOINTS = 'keypoints' | |||
MASKS = 'masks' | |||
DEPTHS = 'depths' | |||
TEXT = 'text' | |||
POLYGONS = 'polygons' | |||
OUTPUT = 'output' | |||
@@ -147,6 +147,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.image_segmentation: | |||
(Pipelines.image_instance_segmentation, | |||
'damo/cv_swin-b_image-instance-segmentation_coco'), | |||
Tasks.image_depth_estimation: | |||
(Pipelines.image_depth_estimation, | |||
'damo/cv_newcrfs_image-depth-estimation_indoor'), | |||
Tasks.image_style_transfer: (Pipelines.image_style_transfer, | |||
'damo/cv_aams_style-transfer_damo'), | |||
Tasks.face_image_generation: (Pipelines.face_image_generation, | |||
@@ -0,0 +1,52 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict, Union | |||
import cv2 | |||
import numpy as np | |||
import PIL | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Model, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_depth_estimation, module_name=Pipelines.image_depth_estimation) | |||
class ImageDepthEstimationPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a image depth estimation pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
        logger.info('image depth estimation pipeline initialized.') | |||
    def preprocess(self, input: Input) -> Dict[str, Any]: | |||
        img = LoadImage.convert_to_ndarray(input).astype(np.float32) | |||
        # resize to the fixed network input resolution, scale to [0, 1] | |||
        # and convert HWC -> NCHW with a leading batch dimension | |||
        H, W = 480, 640 | |||
        img = cv2.resize(img, (W, H)) | |||
        img = img.transpose(2, 0, 1) / 255.0 | |||
        imgs = img[None, ...] | |||
        data = {'imgs': imgs} | |||
        return data | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
results = self.model.inference(input) | |||
return results | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
results = self.model.postprocess(inputs) | |||
outputs = {OutputKeys.DEPTHS: results[OutputKeys.DEPTHS]} | |||
return outputs |
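Side note (not part of the patch): because preprocessing goes through LoadImage.convert_to_ndarray, the pipeline should accept in-memory images as well as file paths or URLs; the exact set of accepted input types follows LoadImage's behavior. A hedged example using the test image shipped with this PR:

from PIL import Image
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

estimator = pipeline(
    Tasks.image_depth_estimation,
    model='damo/cv_newcrfs_image-depth-estimation_indoor')
img = Image.open('data/test/images/image_depth_estimation.jpg')  # path from the unit test below
result = estimator(img)
depths = result[OutputKeys.DEPTHS]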
@@ -44,6 +44,7 @@ class CVTasks(object): | |||
image_segmentation = 'image-segmentation' | |||
semantic_segmentation = 'semantic-segmentation' | |||
image_depth_estimation = 'image-depth-estimation' | |||
portrait_matting = 'portrait-matting' | |||
text_driven_segmentation = 'text-driven-segmentation' | |||
shop_segmentation = 'shop-segmentation' | |||
@@ -1,6 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import cv2 | |||
import matplotlib.pyplot as plt | |||
import numpy as np | |||
from modelscope.outputs import OutputKeys | |||
@@ -439,3 +440,11 @@ def show_image_object_detection_auto_result(img_path, | |||
if save_path is not None: | |||
cv2.imwrite(save_path, img) | |||
return img | |||
def depth_to_color(depth): | |||
    """Map a depth array of shape (H, W) to a BGR color image for visualization.""" | |||
    colormap = plt.get_cmap('plasma') | |||
    # invert and normalize so near pixels are bright, then apply the colormap; | |||
    # scale by 255 (not 256) to avoid uint8 overflow at the extremes | |||
    depth_color = (colormap( | |||
        (depth.max() - depth) / depth.max()) * 255).astype(np.uint8)[:, :, :3] | |||
    depth_color = cv2.cvtColor(depth_color, cv2.COLOR_RGB2BGR) | |||
    return depth_color |
@@ -0,0 +1,35 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
import cv2 | |||
import numpy as np | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.cv.image_utils import depth_to_color | |||
from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
from modelscope.utils.test_utils import test_level | |||
class ImageDepthEstimationTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def setUp(self) -> None: | |||
self.task = 'image-depth-estimation' | |||
self.model_id = 'damo/cv_newcrfs_image-depth-estimation_indoor' | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_image_depth_estimation(self): | |||
input_location = 'data/test/images/image_depth_estimation.jpg' | |||
estimator = pipeline(Tasks.image_depth_estimation, model=self.model_id) | |||
result = estimator(input_location) | |||
depths = result[OutputKeys.DEPTHS] | |||
depth_viz = depth_to_color(depths[0].squeeze().cpu().numpy()) | |||
cv2.imwrite('result.jpg', depth_viz) | |||
print('test_image_depth_estimation DONE') | |||
if __name__ == '__main__': | |||
unittest.main() |