wendi.hwd yingda.chen 2 years ago
parent
commit
ff55bd9436
11 changed files with 600 additions and 6 deletions
  1. +3
    -0
      data/test/images/image_camouflag_detection.jpg
  2. +2
    -0
      modelscope/metainfo.py
  3. +1
    -0
      modelscope/models/cv/salient_detection/models/__init__.py
  4. +187
    -0
      modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py
  5. +6
    -0
      modelscope/models/cv/salient_detection/models/backbone/__init__.py
  6. +178
    -0
      modelscope/models/cv/salient_detection/models/modules.py
  7. +74
    -0
      modelscope/models/cv/salient_detection/models/senet.py
  8. +105
    -0
      modelscope/models/cv/salient_detection/models/utils.py
  9. +18
    -6
      modelscope/models/cv/salient_detection/salient_model.py
  10. +5
    -0
      modelscope/pipelines/cv/image_salient_detection_pipeline.py
  11. +21
    -0
      tests/pipelines/test_salient_detection.py

+ 3
- 0
data/test/images/image_camouflag_detection.jpg View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c713215f7fb4da5382c9137347ee52956a7a44d5979c4cffd3c9b6d1d7e878f
size 19445

+ 2
- 0
modelscope/metainfo.py View File

@@ -165,6 +165,8 @@ class Pipelines(object):
easycv_segmentation = 'easycv-segmentation'
face_2d_keypoints = 'mobilenet_face-2d-keypoints_alignment'
salient_detection = 'u2net-salient-detection'
salient_boudary_detection = 'res2net-salient-detection'
camouflaged_detection = 'res2net-camouflaged-detection'
image_classification = 'image-classification'
face_detection = 'resnet-face-detection-scrfd10gkps'
card_detection = 'resnet-card-detection-scrfd34gkps'


+ 1
- 0
modelscope/models/cv/salient_detection/models/__init__.py View File

@@ -1,3 +1,4 @@
# The implementation is adopted from U-2-Net, made publicly available under the Apache 2.0 License
# source code avaiable via https://github.com/xuebinqin/U-2-Net
from .senet import SENet
from .u2net import U2NET

+ 187
- 0
modelscope/models/cv/salient_detection/models/backbone/Res2Net_v1b.py View File

@@ -0,0 +1,187 @@
# Implementation in this file is modified based on Res2Net-PretrainedModels
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
# publicly avaialbe at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py
import math

import torch
import torch.nn as nn

__all__ = ['Res2Net', 'res2net50_v1b_26w_4s']


class Bottle2neck(nn.Module):
expansion = 4

def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
baseWidth=26,
scale=4,
stype='normal'):
""" Constructor
Args:
inplanes: input channel dimensionality
planes: output channel dimensionality
stride: conv stride. Replaces pooling layer.
downsample: None when stride = 1
baseWidth: basic width of conv3x3
scale: number of scale.
type: 'normal': normal set. 'stage': first block of a new stage.
"""
super(Bottle2neck, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(
inplanes, width * scale, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
if scale == 1:
self.nums = 1
else:
self.nums = scale - 1
if stype == 'stage':
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
convs = []
bns = []
for i in range(self.nums):
convs.append(
nn.Conv2d(
width,
width,
kernel_size=3,
stride=stride,
padding=1,
bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.conv3 = nn.Conv2d(
width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stype = stype
self.scale = scale
self.width = width

def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i == 0 or self.stype == 'stage':
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)
if self.scale != 1 and self.stype == 'normal':
out = torch.cat((out, spx[self.nums]), 1)
elif self.scale != 1 and self.stype == 'stage':
out = torch.cat((out, self.pool(spx[self.nums])), 1)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out


class Res2Net(nn.Module):

def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
self.inplanes = 64
super(Res2Net, self).__init__()
self.baseWidth = baseWidth
self.scale = scale
self.conv1 = nn.Sequential(
nn.Conv2d(3, 32, 3, 2, 1, bias=False), nn.BatchNorm2d(32),
nn.ReLU(inplace=True), nn.Conv2d(32, 32, 3, 1, 1, bias=False),
nn.BatchNorm2d(32), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, 3, 1, 1, bias=False))
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=True,
count_include_pad=False),
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=1,
bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample=downsample,
stype='stage',
baseWidth=self.baseWidth,
scale=self.scale))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
baseWidth=self.baseWidth,
scale=self.scale))
return nn.Sequential(*layers)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x


def res2net50_v1b_26w_4s(backbone_path, pretrained=False, **kwargs):
"""Constructs a Res2Net-50_v1b_26w_4s lib.
Args:
pretrained (bool): If True, returns a lib pre-trained on ImageNet
"""
model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
if pretrained:
model_state = torch.load(backbone_path)
model.load_state_dict(model_state)
return model

+ 6
- 0
modelscope/models/cv/salient_detection/models/backbone/__init__.py View File

@@ -0,0 +1,6 @@
# Implementation in this file is modified based on Res2Net-PretrainedModels
# Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
# publicly avaialbe at https://github.com/Res2Net/Res2Net-PretrainedModels/blob/master/res2net_v1b.py
from .Res2Net_v1b import res2net50_v1b_26w_4s

__all__ = ['res2net50_v1b_26w_4s']

+ 178
- 0
modelscope/models/cv/salient_detection/models/modules.py View File

@@ -0,0 +1,178 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import ConvBNReLU


class AreaLayer(nn.Module):

def __init__(self, in_channel, out_channel):
super(AreaLayer, self).__init__()
self.lbody = nn.Sequential(
nn.Conv2d(out_channel, out_channel, 1),
nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
self.hbody = nn.Sequential(
nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True))
self.body = nn.Sequential(
nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1),
nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
nn.Conv2d(out_channel, out_channel, 3, 1, 1),
nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
nn.Conv2d(out_channel, 1, 1))

def forward(self, xl, xh):
xl1 = self.lbody(xl)
xl1 = F.interpolate(
xl1, size=xh.size()[2:], mode='bilinear', align_corners=True)
xh1 = self.hbody(xh)
x = torch.cat((xl1, xh1), dim=1)
x_out = self.body(x)
return x_out


class EdgeLayer(nn.Module):

def __init__(self, in_channel, out_channel):
super(EdgeLayer, self).__init__()
self.lbody = nn.Sequential(
nn.Conv2d(out_channel, out_channel, 1),
nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True))
self.hbody = nn.Sequential(
nn.Conv2d(in_channel, out_channel, 1), nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True))
self.bodye = nn.Sequential(
nn.Conv2d(2 * out_channel, out_channel, 3, 1, 1),
nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
nn.Conv2d(out_channel, out_channel, 3, 1, 1),
nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True),
nn.Conv2d(out_channel, 1, 1))

def forward(self, xl, xh):
xl1 = self.lbody(xl)
xh1 = self.hbody(xh)
xh1 = F.interpolate(
xh1, size=xl.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((xl1, xh1), dim=1)
x_out = self.bodye(x)
return x_out


class EBlock(nn.Module):

def __init__(self, inchs, outchs):
super(EBlock, self).__init__()
self.elayer = nn.Sequential(
ConvBNReLU(inchs + 1, outchs, kernel_size=3, padding=1, stride=1),
ConvBNReLU(outchs, outchs, 1))
self.salayer = nn.Sequential(
nn.Conv2d(2, 1, 3, 1, 1, bias=False),
nn.BatchNorm2d(1, momentum=0.01), nn.Sigmoid())

def forward(self, x, edgeAtten):
x = torch.cat((x, edgeAtten), dim=1)
ex = self.elayer(x)
ex_max = torch.max(ex, 1, keepdim=True)[0]
ex_mean = torch.mean(ex, dim=1, keepdim=True)
xei_compress = torch.cat((ex_max, ex_mean), dim=1)

scale = self.salayer(xei_compress)
x_out = ex * scale
return x_out


class StructureE(nn.Module):

def __init__(self, inchs, outchs, EM):
super(StructureE, self).__init__()
self.ne_modules = int(inchs / EM)
NM = int(outchs / self.ne_modules)
elayes = []
for i in range(self.ne_modules):
emblock = EBlock(EM, NM)
elayes.append(emblock)
self.emlayes = nn.ModuleList(elayes)
self.body = nn.Sequential(
ConvBNReLU(outchs, outchs, 3, 1, 1), ConvBNReLU(outchs, outchs, 1))

def forward(self, x, edgeAtten):
if edgeAtten.size() != x.size():
edgeAtten = F.interpolate(
edgeAtten, x.size()[2:], mode='bilinear', align_corners=False)
xx = torch.chunk(x, self.ne_modules, dim=1)
efeas = []
for i in range(self.ne_modules):
xei = self.emlayes[i](xx[i], edgeAtten)
efeas.append(xei)
efeas = torch.cat(efeas, dim=1)
x_out = self.body(efeas)
return x_out


class ABlock(nn.Module):

def __init__(self, inchs, outchs, k):
super(ABlock, self).__init__()
self.alayer = nn.Sequential(
ConvBNReLU(inchs, outchs, k, 1, k // 2),
ConvBNReLU(outchs, outchs, 1))
self.arlayer = nn.Sequential(
ConvBNReLU(inchs, outchs, k, 1, k // 2),
ConvBNReLU(outchs, outchs, 1))
self.fusion = ConvBNReLU(2 * outchs, outchs, 1)

def forward(self, x, areaAtten):
xa = x * areaAtten
xra = x * (1 - areaAtten)
xout = self.fusion(torch.cat((xa, xra), dim=1))
return xout


class AMFusion(nn.Module):

def __init__(self, inchs, outchs, AM):
super(AMFusion, self).__init__()
self.k = [3, 3, 5, 5]
self.conv_up = ConvBNReLU(inchs, outchs, 3, 1, 1)
self.up = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=True)
self.na_modules = int(outchs / AM)
alayers = []
for i in range(self.na_modules):
layer = ABlock(AM, AM, self.k[i])
alayers.append(layer)
self.alayers = nn.ModuleList(alayers)
self.fusion_0 = ConvBNReLU(outchs, outchs, 3, 1, 1)
self.fusion_e = nn.Sequential(
nn.Conv2d(
outchs, outchs, kernel_size=(3, 1), padding=(1, 0),
bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True),
nn.Conv2d(
outchs, outchs, kernel_size=(1, 3), padding=(0, 1),
bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True))
self.fusion_e1 = nn.Sequential(
nn.Conv2d(
outchs, outchs, kernel_size=(5, 1), padding=(2, 0),
bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True),
nn.Conv2d(
outchs, outchs, kernel_size=(1, 5), padding=(0, 2),
bias=False), nn.BatchNorm2d(outchs), nn.ReLU(inplace=True))
self.fusion = ConvBNReLU(3 * outchs, outchs, 1)

def forward(self, xl, xh, xhm):
xh1 = self.up(self.conv_up(xh))
x = xh1 + xl
xm = self.up(torch.sigmoid(xhm))
xx = torch.chunk(x, self.na_modules, dim=1)
xxmids = []
for i in range(self.na_modules):
xi = self.alayers[i](xx[i], xm)
xxmids.append(xi)
xfea = torch.cat(xxmids, dim=1)
x0 = self.fusion_0(xfea)
x1 = self.fusion_e(xfea)
x2 = self.fusion_e1(xfea)
x_out = self.fusion(torch.cat((x0, x1, x2), dim=1))
return x_out

+ 74
- 0
modelscope/models/cv/salient_detection/models/senet.py View File

@@ -0,0 +1,74 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .backbone import res2net50_v1b_26w_4s as res2net
from .modules import AMFusion, AreaLayer, EdgeLayer, StructureE
from .utils import ASPP, CBAM, ConvBNReLU


class SENet(nn.Module):

def __init__(self, backbone_path=None, pretrained=False):
super(SENet, self).__init__()
resnet50 = res2net(backbone_path, pretrained)
self.layer0_1 = nn.Sequential(resnet50.conv1, resnet50.bn1,
resnet50.relu)
self.maxpool = resnet50.maxpool
self.layer1 = resnet50.layer1
self.layer2 = resnet50.layer2
self.layer3 = resnet50.layer3
self.layer4 = resnet50.layer4
self.aspp3 = ASPP(1024, 256)
self.aspp4 = ASPP(2048, 256)
self.cbblock3 = CBAM(inchs=256, kernel_size=5)
self.cbblock4 = CBAM(inchs=256, kernel_size=5)
self.up = nn.Upsample(
mode='bilinear', scale_factor=2, align_corners=False)
self.conv_up = ConvBNReLU(512, 512, 1)
self.aux_edge = EdgeLayer(512, 256)
self.aux_area = AreaLayer(512, 256)
self.layer1_enhance = StructureE(256, 128, 128)
self.layer2_enhance = StructureE(512, 256, 128)
self.layer3_decoder = AMFusion(512, 256, 128)
self.layer2_decoder = AMFusion(256, 128, 128)
self.out_conv_8 = nn.Conv2d(256, 1, 1)
self.out_conv_4 = nn.Conv2d(128, 1, 1)

def forward(self, x):
layer0 = self.layer0_1(x)
layer0s = self.maxpool(layer0)
layer1 = self.layer1(layer0s)
layer2 = self.layer2(layer1)
layer3 = self.layer3(layer2)
layer4 = self.layer4(layer3)
layer3_eh = self.cbblock3(self.aspp3(layer3))
layer4_eh = self.cbblock4(self.aspp4(layer4))
layer34 = self.conv_up(
torch.cat((self.up(layer4_eh), layer3_eh), dim=1))
edge_atten = self.aux_edge(layer1, layer34)
area_atten = self.aux_area(layer1, layer34)
edge_atten_ = torch.sigmoid(edge_atten)
layer1_eh = self.layer1_enhance(layer1, edge_atten_)
layer2_eh = self.layer2_enhance(layer2, edge_atten_)
layer2_fu = self.layer3_decoder(layer2_eh, layer34, area_atten)
out_8 = self.out_conv_8(layer2_fu)
layer1_fu = self.layer2_decoder(layer1_eh, layer2_fu, out_8)
out_4 = self.out_conv_4(layer1_fu)
out_16 = F.interpolate(
area_atten,
size=x.size()[2:],
mode='bilinear',
align_corners=False)
out_8 = F.interpolate(
out_8, size=x.size()[2:], mode='bilinear', align_corners=False)
out_4 = F.interpolate(
out_4, size=x.size()[2:], mode='bilinear', align_corners=False)
edge_out = F.interpolate(
edge_atten_,
size=x.size()[2:],
mode='bilinear',
align_corners=False)

return out_4.sigmoid(), out_8.sigmoid(), out_16.sigmoid(), edge_out

+ 105
- 0
modelscope/models/cv/salient_detection/models/utils.py View File

@@ -0,0 +1,105 @@
# Implementation in this file is modified based on deeplabv3
# Originally MIT license,publicly avaialbe at https://github.com/fregu856/deeplabv3/blob/master/model/aspp.py
# Implementation in this file is modified based on attention-module
# Originally MIT license,publicly avaialbe at https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
import torch
import torch.nn as nn


class ConvBNReLU(nn.Module):

def __init__(self,
inplanes,
planes,
kernel_size=3,
stride=1,
padding=0,
dilation=1,
bias=False):
super(ConvBNReLU, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(
inplanes,
planes,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias), nn.BatchNorm2d(planes), nn.ReLU(inplace=True))

def forward(self, x):
return self.block(x)


class ASPP(nn.Module):

def __init__(self, in_dim, out_dim):
super(ASPP, self).__init__()
mid_dim = 128
self.conv1 = ConvBNReLU(in_dim, mid_dim, kernel_size=1)
self.conv2 = ConvBNReLU(
in_dim, mid_dim, kernel_size=3, padding=2, dilation=2)
self.conv3 = ConvBNReLU(
in_dim, mid_dim, kernel_size=3, padding=5, dilation=5)
self.conv4 = ConvBNReLU(
in_dim, mid_dim, kernel_size=3, padding=7, dilation=7)
self.conv5 = ConvBNReLU(in_dim, mid_dim, kernel_size=1, padding=0)
self.fuse = ConvBNReLU(5 * mid_dim, out_dim, 3, 1, 1)
self.global_pooling = nn.AdaptiveAvgPool2d(1)

def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(x)
conv3 = self.conv3(x)
conv4 = self.conv4(x)
xg = self.conv5(self.global_pooling(x))
conv5 = nn.Upsample((x.shape[2], x.shape[3]), mode='nearest')(xg)
return self.fuse(torch.cat((conv1, conv2, conv3, conv4, conv5), 1))


class ChannelAttention(nn.Module):

def __init__(self, inchs, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(inchs, inchs // 16, 1, bias=False), nn.ReLU(),
nn.Conv2d(inchs // 16, inchs, 1, bias=False))
self.sigmoid = nn.Sigmoid()

def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
return self.sigmoid(out)


class SpatialAttention(nn.Module):

def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()

self.conv1 = nn.Conv2d(
2, 1, kernel_size, padding=kernel_size // 2, bias=False)
self.sigmoid = nn.Sigmoid()

def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)


class CBAM(nn.Module):

def __init__(self, inchs, kernel_size=7):
super().__init__()
self.calayer = ChannelAttention(inchs=inchs)
self.saLayer = SpatialAttention(kernel_size=kernel_size)

def forward(self, x):
xca = self.calayer(x) * x
xsa = self.saLayer(xca) * xca
return xsa

+ 18
- 6
modelscope/models/cv/salient_detection/salient_model.py View File

@@ -2,7 +2,6 @@
import os.path as osp

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
@@ -10,8 +9,9 @@ from torchvision import transforms
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .models import U2NET
from .models import U2NET, SENet


@MODELS.register_module(
@@ -22,13 +22,25 @@ class SalientDetection(TorchModel):
"""str -- model file root."""
super().__init__(model_dir, *args, **kwargs)
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
self.model = U2NET(3, 1)

self.norm_mean = [0.485, 0.456, 0.406]
self.norm_std = [0.229, 0.224, 0.225]
self.norm_size = (320, 320)

config_path = osp.join(model_dir, 'config.py')
if osp.exists(config_path) is False:
self.model = U2NET(3, 1)
else:
self.model = SENet(backbone_path=None, pretrained=False)
config = Config.from_file(config_path)
self.norm_mean = config.norm_mean
self.norm_std = config.norm_std
self.norm_size = config.norm_size
checkpoint = torch.load(model_path, map_location='cpu')
self.transform_input = transforms.Compose([
transforms.Resize((320, 320)),
transforms.Resize(self.norm_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transforms.Normalize(mean=self.norm_mean, std=self.norm_std)
])
self.model.load_state_dict(checkpoint)
self.model.eval()


+ 5
- 0
modelscope/pipelines/cv/image_salient_detection_pipeline.py View File

@@ -12,6 +12,11 @@ from modelscope.utils.constant import Tasks

@PIPELINES.register_module(
Tasks.semantic_segmentation, module_name=Pipelines.salient_detection)
@PIPELINES.register_module(
Tasks.semantic_segmentation,
module_name=Pipelines.salient_boudary_detection)
@PIPELINES.register_module(
Tasks.semantic_segmentation, module_name=Pipelines.camouflaged_detection)
class ImageSalientDetectionPipeline(Pipeline):

def __init__(self, model: str, **kwargs):


+ 21
- 0
tests/pipelines/test_salient_detection.py View File

@@ -23,6 +23,27 @@ class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
import cv2
cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_salient_boudary_detection(self):
input_location = 'data/test/images/image_salient_detection.jpg'
model_id = 'damo/cv_res2net_salient-detection'
salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id)
result = salient_detect(input_location)
import cv2
cv2.imwrite(input_location + '_boudary_salient.jpg',
result[OutputKeys.MASKS])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_camouflag_detection(self):
input_location = 'data/test/images/image_camouflag_detection.jpg'
model_id = 'damo/cv_res2net_camouflaged-detection'
camouflag_detect = pipeline(
Tasks.semantic_segmentation, model=model_id)
result = camouflag_detect(input_location)
import cv2
cv2.imwrite(input_location + '_camouflag.jpg',
result[OutputKeys.MASKS])

@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()


Loading…
Cancel
Save