Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10834768master^2
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c713215f7fb4da5382c9137347ee52956a7a44d5979c4cffd3c9b6d1d7e878f
size 19445
@@ -165,6 +165,8 @@ class Pipelines(object):
     easycv_segmentation = 'easycv-segmentation'
     face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment'
     salient_detection = 'u2net-salient-detection'
+    salient_boundary_detection = 'res2net-salient-detection'
+    camouflaged_detection = 'res2net-camouflaged-detection'
     image_classification = 'image-classification'
     face_detection = 'resnet-face-detection-scrfd10gkps'
     card_detection = 'resnet-card-detection-scrfd34gkps'
@@ -1,3 +1,4 @@
 # The implementation is adopted from U-2-Net, made publicly available under the Apache 2.0 License
 # source code available via https://github.com/xuebinqin/U-2-Net
+from .senet import SENet
 from .u2net import U2NET
@@ -0,0 +1,187 @@
# Implementation in this file is modified based on Res2Net-PretrainedModels
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
# publicly available at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py
import math

import torch
import torch.nn as nn

__all__ = ['Res2Net', 'res2net50_v1b_26w_4s']
class Bottle2neck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 baseWidth=26,
                 scale=4,
                 stype='normal'):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality
            stride: conv stride. Replaces pooling layer.
            downsample: None when stride = 1
            baseWidth: basic width of conv3x3
            scale: number of scales.
            stype: 'normal' for a normal block; 'stage' for the first block
                of a new stage.
        """
        super(Bottle2neck, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(
            inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        if scale == 1:
            self.nums = 1
        else:
            self.nums = scale - 1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(
                nn.Conv2d(
                    width,
                    width,
                    kernel_size=3,
                    stride=stride,
                    padding=1,
                    bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv2d(
            width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width
    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0 or self.stype == 'stage':
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        if self.scale != 1 and self.stype == 'normal':
            out = torch.cat((out, spx[self.nums]), 1)
        elif self.scale != 1 and self.stype == 'stage':
            out = torch.cat((out, self.pool(spx[self.nums])), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
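
Reviewer note: a minimal standalone sanity check for Bottle2neck (a sketch, not part of the patch). The downsample branch below is my own assumption, sized to match the expanded output channels.

import torch
import torch.nn as nn

block = Bottle2neck(
    64, 64,
    downsample=nn.Sequential(
        nn.Conv2d(64, 64 * Bottle2neck.expansion, kernel_size=1, bias=False),
        nn.BatchNorm2d(64 * Bottle2neck.expansion)),
    stype='stage').eval()
with torch.no_grad():
    y = block(torch.randn(1, 64, 56, 56))
assert y.shape == (1, 256, 56, 56)  # expansion=4 widens 64 -> 256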
class Res2Net(nn.Module):

    def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
        self.inplanes = 64
        super(Res2Net, self).__init__()
        self.baseWidth = baseWidth
        self.scale = scale
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1, bias=False), nn.BatchNorm2d(32),
            nn.ReLU(inplace=True), nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False))
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.AvgPool2d(
                    kernel_size=stride,
                    stride=stride,
                    ceil_mode=True,
                    count_include_pad=False),
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=1,
                    bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample=downsample,
                stype='stage',
                baseWidth=self.baseWidth,
                scale=self.scale))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    baseWidth=self.baseWidth,
                    scale=self.scale))

        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
def res2net50_v1b_26w_4s(backbone_path, pretrained=False, **kwargs):
    """Constructs a Res2Net-50_v1b_26w_4s model.

    Args:
        backbone_path (str): path to the pretrained backbone weights.
        pretrained (bool): If True, returns a model pre-trained on ImageNet,
            loaded from ``backbone_path``.
    """
    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model_state = torch.load(backbone_path)
        model.load_state_dict(model_state)
    return model
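
Reviewer note: a quick smoke test of the factory (my own sketch; random weights, no pretrained checkpoint involved).

import torch

model = res2net50_v1b_26w_4s(backbone_path=None, pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
assert logits.shape == (2, 1000)  # default num_classes=1000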
@@ -0,0 +1,6 @@
# Implementation in this file is modified based on Res2Net-PretrainedModels
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
# publicly available at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py
from .Res2Net_v1b import res2net50_v1b_26w_4s

__all__ = ['res2net50_v1b_26w_4s']
@@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import ConvBNReLU


class AreaLayer(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(AreaLayer, self).__init__()
        self.lbody = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
        self.hbody = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True))
        self.body = nn.Sequential(
            nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, 1, 1))

    def forward(self, xl, xh):
        xl1 = self.lbody(xl)
        xl1 = F.interpolate(
            xl1, size=xh.size()[2:], mode='bilinear', align_corners=True)
        xh1 = self.hbody(xh)
        x = torch.cat((xl1, xh1), dim=1)
        x_out = self.body(x)
        return x_out
class EdgeLayer(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(EdgeLayer, self).__init__()
        self.lbody = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
        self.hbody = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True))
        self.bodye = nn.Sequential(
            nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, 1, 1))

    def forward(self, xl, xh):
        xl1 = self.lbody(xl)
        xh1 = self.hbody(xh)
        xh1 = F.interpolate(
            xh1, size=xl.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((xl1, xh1), dim=1)
        x_out = self.bodye(x)
        return x_out
class EBlock(nn.Module):

    def __init__(self, inchs, outchs):
        super(EBlock, self).__init__()
        self.elayer = nn.Sequential(
            ConvBNReLU(inchs + 1, outchs, kernel_size=3, padding=1, stride=1),
            ConvBNReLU(outchs, outchs, 1))
        self.salayer = nn.Sequential(
            nn.Conv2d(2, 1, 3, 1, 1, bias=False),
            nn.BatchNorm2d(1, momentum=0.01), nn.Sigmoid())

    def forward(self, x, edgeAtten):
        x = torch.cat((x, edgeAtten), dim=1)
        ex = self.elayer(x)
        ex_max = torch.max(ex, 1, keepdim=True)[0]
        ex_mean = torch.mean(ex, dim=1, keepdim=True)
        xei_compress = torch.cat((ex_max, ex_mean), dim=1)
        scale = self.salayer(xei_compress)
        x_out = ex * scale
        return x_out
class StructureE(nn.Module):

    def __init__(self, inchs, outchs, EM):
        super(StructureE, self).__init__()
        self.ne_modules = int(inchs / EM)
        NM = int(outchs / self.ne_modules)
        elayers = []
        for i in range(self.ne_modules):
            emblock = EBlock(EM, NM)
            elayers.append(emblock)
        self.emlayes = nn.ModuleList(elayers)
        self.body = nn.Sequential(
            ConvBNReLU(outchs, outchs, 3, 1, 1), ConvBNReLU(outchs, outchs, 1))

    def forward(self, x, edgeAtten):
        if edgeAtten.size() != x.size():
            edgeAtten = F.interpolate(
                edgeAtten, x.size()[2:], mode='bilinear', align_corners=False)
        xx = torch.chunk(x, self.ne_modules, dim=1)
        efeas = []
        for i in range(self.ne_modules):
            xei = self.emlayes[i](xx[i], edgeAtten)
            efeas.append(xei)
        efeas = torch.cat(efeas, dim=1)
        x_out = self.body(efeas)
        return x_out
class ABlock(nn.Module):

    def __init__(self, inchs, outchs, k):
        super(ABlock, self).__init__()
        self.alayer = nn.Sequential(
            ConvBNReLU(inchs, outchs, k, 1, k // 2),
            ConvBNReLU(outchs, outchs, 1))
        self.arlayer = nn.Sequential(
            ConvBNReLU(inchs, outchs, k, 1, k // 2),
            ConvBNReLU(outchs, outchs, 1))
        self.fusion = ConvBNReLU(2 * outchs, outchs, 1)

    def forward(self, x, areaAtten):
        # Split into area-attended and reverse-attended parts, then fuse.
        # NOTE: self.alayer / self.arlayer are constructed above but never
        # applied here; is that intentional?
        xa = x * areaAtten
        xra = x * (1 - areaAtten)
        xout = self.fusion(torch.cat((xa, xra), dim=1))
        return xout
class AMFusion(nn.Module):

    def __init__(self, inchs, outchs, AM):
        super(AMFusion, self).__init__()
        self.k = [3, 3, 5, 5]
        self.conv_up = ConvBNReLU(inchs, outchs, 3, 1, 1)
        self.up = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=True)
        self.na_modules = int(outchs / AM)
        alayers = []
        for i in range(self.na_modules):
            layer = ABlock(AM, AM, self.k[i])
            alayers.append(layer)
        self.alayers = nn.ModuleList(alayers)
        self.fusion_0 = ConvBNReLU(outchs, outchs, 3, 1, 1)
        self.fusion_e = nn.Sequential(
            nn.Conv2d(
                outchs, outchs, kernel_size=(3, 1), padding=(1, 0),
                bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True),
            nn.Conv2d(
                outchs, outchs, kernel_size=(1, 3), padding=(0, 1),
                bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True))
        self.fusion_e1 = nn.Sequential(
            nn.Conv2d(
                outchs, outchs, kernel_size=(5, 1), padding=(2, 0),
                bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True),
            nn.Conv2d(
                outchs, outchs, kernel_size=(1, 5), padding=(0, 2),
                bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True))
        self.fusion = ConvBNReLU(3 * outchs, outchs, 1)

    def forward(self, xl, xh, xhm):
        xh1 = self.up(self.conv_up(xh))
        x = xh1 + xl
        xm = self.up(torch.sigmoid(xhm))
        xx = torch.chunk(x, self.na_modules, dim=1)
        xxmids = []
        for i in range(self.na_modules):
            xi = self.alayers[i](xx[i], xm)
            xxmids.append(xi)
        xfea = torch.cat(xxmids, dim=1)
        x0 = self.fusion_0(xfea)
        x1 = self.fusion_e(xfea)
        x2 = self.fusion_e1(xfea)
        x_out = self.fusion(torch.cat((x0, x1, x2), dim=1))
        return x_out
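
Reviewer note: a shape check for AMFusion as it is wired in senet.py below (layer3_decoder); the tensor sizes are my own assumptions for a 320x320 input.

import torch

fuse = AMFusion(512, 256, 128).eval()
xl = torch.randn(1, 256, 40, 40)   # enhanced layer2 features (1/8 scale)
xh = torch.randn(1, 512, 20, 20)   # fused layer3/4 features (1/16 scale)
xhm = torch.randn(1, 1, 20, 20)    # coarse area logits (1/16 scale)
with torch.no_grad():
    out = fuse(xl, xh, xhm)
assert out.shape == (1, 256, 40, 40)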
@@ -0,0 +1,74 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import res2net50_v1b_26w_4s as res2net
from .modules import AMFusion, AreaLayer, EdgeLayer, StructureE
from .utils import ASPP, CBAM, ConvBNReLU


class SENet(nn.Module):

    def __init__(self, backbone_path=None, pretrained=False):
        super(SENet, self).__init__()
        resnet50 = res2net(backbone_path, pretrained)
        self.layer0_1 = nn.Sequential(resnet50.conv1, resnet50.bn1,
                                      resnet50.relu)
        self.maxpool = resnet50.maxpool
        self.layer1 = resnet50.layer1
        self.layer2 = resnet50.layer2
        self.layer3 = resnet50.layer3
        self.layer4 = resnet50.layer4
        self.aspp3 = ASPP(1024, 256)
        self.aspp4 = ASPP(2048, 256)
        self.cbblock3 = CBAM(inchs=256, kernel_size=5)
        self.cbblock4 = CBAM(inchs=256, kernel_size=5)
        self.up = nn.Upsample(
            mode='bilinear', scale_factor=2, align_corners=False)
        self.conv_up = ConvBNReLU(512, 512, 1)
        self.aux_edge = EdgeLayer(512, 256)
        self.aux_area = AreaLayer(512, 256)
        self.layer1_enhance = StructureE(256, 128, 128)
        self.layer2_enhance = StructureE(512, 256, 128)
        self.layer3_decoder = AMFusion(512, 256, 128)
        self.layer2_decoder = AMFusion(256, 128, 128)
        self.out_conv_8 = nn.Conv2d(256, 1, 1)
        self.out_conv_4 = nn.Conv2d(128, 1, 1)
    def forward(self, x):
        layer0 = self.layer0_1(x)
        layer0s = self.maxpool(layer0)
        layer1 = self.layer1(layer0s)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        layer3_eh = self.cbblock3(self.aspp3(layer3))
        layer4_eh = self.cbblock4(self.aspp4(layer4))
        layer34 = self.conv_up(
            torch.cat((self.up(layer4_eh), layer3_eh), dim=1))

        edge_atten = self.aux_edge(layer1, layer34)
        area_atten = self.aux_area(layer1, layer34)
        edge_atten_ = torch.sigmoid(edge_atten)

        layer1_eh = self.layer1_enhance(layer1, edge_atten_)
        layer2_eh = self.layer2_enhance(layer2, edge_atten_)

        layer2_fu = self.layer3_decoder(layer2_eh, layer34, area_atten)
        out_8 = self.out_conv_8(layer2_fu)
        layer1_fu = self.layer2_decoder(layer1_eh, layer2_fu, out_8)
        out_4 = self.out_conv_4(layer1_fu)

        out_16 = F.interpolate(
            area_atten,
            size=x.size()[2:],
            mode='bilinear',
            align_corners=False)
        out_8 = F.interpolate(
            out_8, size=x.size()[2:], mode='bilinear', align_corners=False)
        out_4 = F.interpolate(
            out_4, size=x.size()[2:], mode='bilinear', align_corners=False)
        edge_out = F.interpolate(
            edge_atten_,
            size=x.size()[2:],
            mode='bilinear',
            align_corners=False)
        return out_4.sigmoid(), out_8.sigmoid(), out_16.sigmoid(), edge_out
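
Reviewer note: an end-to-end shape check for SENet (my own sketch; random weights, no backbone checkpoint). eval() matters here because ASPP's global-pooling branch feeds a 1x1 map into BatchNorm, which errors in training mode with batch size 1.

import torch

net = SENet(backbone_path=None, pretrained=False).eval()
with torch.no_grad():
    out_4, out_8, out_16, edge = net(torch.randn(1, 3, 320, 320))
for t in (out_4, out_8, out_16, edge):
    assert t.shape == (1, 1, 320, 320)  # all heads upsampled to input size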
@@ -0,0 +1,105 @@
# Implementation in this file is modified based on deeplabv3
# Originally MIT license, publicly available at https://github.com/fregu856/deeplabv3/blob/master/model/aspp.py
# Implementation in this file is modified based on attention-module
# Originally MIT license, publicly available at https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
import torch
import torch.nn as nn
class ConvBNReLU(nn.Module):

    def __init__(self,
                 inplanes,
                 planes,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False):
        super(ConvBNReLU, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                inplanes,
                planes,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=bias), nn.BatchNorm2d(planes), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.block(x)
class ASPP(nn.Module):

    def __init__(self, in_dim, out_dim):
        super(ASPP, self).__init__()
        mid_dim = 128
        self.conv1 = ConvBNReLU(in_dim, mid_dim, kernel_size=1)
        self.conv2 = ConvBNReLU(
            in_dim, mid_dim, kernel_size=3, padding=2, dilation=2)
        self.conv3 = ConvBNReLU(
            in_dim, mid_dim, kernel_size=3, padding=5, dilation=5)
        self.conv4 = ConvBNReLU(
            in_dim, mid_dim, kernel_size=3, padding=7, dilation=7)
        self.conv5 = ConvBNReLU(in_dim, mid_dim, kernel_size=1, padding=0)
        self.fuse = ConvBNReLU(5 * mid_dim, out_dim, 3, 1, 1)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(x)
        conv3 = self.conv3(x)
        conv4 = self.conv4(x)
        xg = self.conv5(self.global_pooling(x))
        conv5 = nn.Upsample((x.shape[2], x.shape[3]), mode='nearest')(xg)
        return self.fuse(torch.cat((conv1, conv2, conv3, conv4, conv5), 1))
class ChannelAttention(nn.Module):

    def __init__(self, inchs, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # use the ratio argument instead of a hardcoded 16
        self.fc = nn.Sequential(
            nn.Conv2d(inchs, inchs // ratio, 1, bias=False), nn.ReLU(),
            nn.Conv2d(inchs // ratio, inchs, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)
class SpatialAttention(nn.Module):

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
class CBAM(nn.Module):

    def __init__(self, inchs, kernel_size=7):
        super().__init__()
        self.calayer = ChannelAttention(inchs=inchs)
        self.saLayer = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        xca = self.calayer(x) * x
        xsa = self.saLayer(xca) * xca
        return xsa
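
Reviewer note: shape expectations for the two utility blocks, assuming the channel sizes used in senet.py. eval() avoids BatchNorm's single-value-per-channel error on ASPP's pooled 1x1 branch.

import torch

x = torch.randn(1, 1024, 20, 20)
aspp = ASPP(1024, 256).eval()
cbam = CBAM(inchs=256, kernel_size=5).eval()
with torch.no_grad():
    y = aspp(x)
    z = cbam(y)
assert y.shape == (1, 256, 20, 20)  # channels remapped, spatial size kept
assert z.shape == y.shape           # CBAM is shape-preserving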
@@ -2,7 +2,6 @@
import os.path as osp

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
@@ -10,8 +9,9 @@ from torchvision import transforms

 from modelscope.metainfo import Models
 from modelscope.models.base.base_torch_model import TorchModel
 from modelscope.models.builder import MODELS
+from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
-from .models import U2NET
+from .models import U2NET, SENet


 @MODELS.register_module(
@@ -22,13 +22,25 @@ class SalientDetection(TorchModel):
         """str -- model file root."""
         super().__init__(model_dir, *args, **kwargs)
         model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
-        self.model = U2NET(3, 1)
+        self.norm_mean = [0.485, 0.456, 0.406]
+        self.norm_std = [0.229, 0.224, 0.225]
+        self.norm_size = (320, 320)
+        config_path = osp.join(model_dir, 'config.py')
+        if not osp.exists(config_path):
+            self.model = U2NET(3, 1)
+        else:
+            self.model = SENet(backbone_path=None, pretrained=False)
+            config = Config.from_file(config_path)
+            self.norm_mean = config.norm_mean
+            self.norm_std = config.norm_std
+            self.norm_size = config.norm_size
         checkpoint = torch.load(model_path, map_location='cpu')
         self.transform_input = transforms.Compose([
-            transforms.Resize((320, 320)),
+            transforms.Resize(self.norm_size),
             transforms.ToTensor(),
-            transforms.Normalize(
-                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+            transforms.Normalize(mean=self.norm_mean, std=self.norm_std)
         ])
         self.model.load_state_dict(checkpoint)
         self.model.eval()
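
Reviewer note: the SENet branch expects a config.py next to the checkpoint. A hypothetical example with the three fields the code reads (values illustrative, not taken from this patch):

# config.py (hypothetical)
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
norm_size = (352, 352)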
@@ -12,6 +12,11 @@ from modelscope.utils.constant import Tasks

 @PIPELINES.register_module(
     Tasks.semantic_segmentation, module_name=Pipelines.salient_detection)
+@PIPELINES.register_module(
+    Tasks.semantic_segmentation,
+    module_name=Pipelines.salient_boundary_detection)
+@PIPELINES.register_module(
+    Tasks.semantic_segmentation, module_name=Pipelines.camouflaged_detection)
 class ImageSalientDetectionPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
@@ -23,6 +23,27 @@ class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
         import cv2
         cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS])

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_salient_boundary_detection(self):
+        input_location = 'data/test/images/image_salient_detection.jpg'
+        model_id = 'damo/cv_res2net_salient-detection'
+        salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id)
+        result = salient_detect(input_location)
+        import cv2
+        cv2.imwrite(input_location + '_boundary_salient.jpg',
+                    result[OutputKeys.MASKS])
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_camouflaged_detection(self):
+        input_location = 'data/test/images/image_camouflag_detection.jpg'
+        model_id = 'damo/cv_res2net_camouflaged-detection'
+        camouflaged_detect = pipeline(
+            Tasks.semantic_segmentation, model=model_id)
+        result = camouflaged_detect(input_location)
+        import cv2
+        cv2.imwrite(input_location + '_camouflaged.jpg',
+                    result[OutputKeys.MASKS])
+
     @unittest.skip('demo compatibility test is only enabled on a needed-basis')
     def test_demo_compatibility(self):
         self.compatibility_check()