tingwei.gtw, yingda.chen · 3 years ago
commit a079ab922f
10 changed files with 393 additions and 1 deletions
  1. data/test/images/product_segmentation.jpg  +3 -0
  2. modelscope/metainfo.py  +2 -0
  3. modelscope/models/cv/product_segmentation/__init__.py  +20 -0
  4. modelscope/models/cv/product_segmentation/net.py  +197 -0
  5. modelscope/models/cv/product_segmentation/seg_infer.py  +77 -0
  6. modelscope/outputs.py  +8 -1
  7. modelscope/pipelines/builder.py  +2 -0
  8. modelscope/pipelines/cv/product_segmentation_pipeline.py  +40 -0
  9. modelscope/utils/constant.py  +1 -0
  10. tests/pipelines/test_product_segmentation.py  +43 -0

+3 -0   data/test/images/product_segmentation.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a16038f7809127eb3e03cbae049592d193707e095309daca78f7d108d67fe4ec
size 108357

+2 -0   modelscope/metainfo.py

@@ -42,6 +42,7 @@ class Models(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'

    # EasyCV models
    yolox = 'YOLOX'
@@ -185,6 +186,7 @@ class Pipelines(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'

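The two new 'product-segmentation' constants are the registry keys the rest of this commit hangs off: seg_infer.py registers the model class under Models.product_segmentation and product_segmentation_pipeline.py registers the pipeline under Pipelines.product_segmentation, while builder.py maps the task to a default model id. An illustrative sketch of that wiring (not part of the diff; the decorators quoted in the comments are the ones used later in this commit):

from modelscope.metainfo import Models, Pipelines

# One string constant ties together metainfo.py, the registry decorators and builder.py.
assert Models.product_segmentation == 'product-segmentation'
assert Pipelines.product_segmentation == 'product-segmentation'

# Registration happens further down in this commit via:
#   @MODELS.register_module(Tasks.product_segmentation,
#                           module_name=Models.product_segmentation)
#   @PIPELINES.register_module(Tasks.product_segmentation,
#                              module_name=Pipelines.product_segmentation)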

+20 -0   modelscope/models/cv/product_segmentation/__init__.py

@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .seg_infer import F3NetForProductSegmentation

else:
    _import_structure = {'seg_infer': ['F3NetForProductSegmentation']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

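LazyImportModule replaces the package's module object so that seg_infer (and with it torch, cv2 and PIL) is only imported when the class is actually accessed; the TYPE_CHECKING branch keeps static type checkers happy without paying that cost at runtime. A minimal sketch of what this means for callers (assumes a working ModelScope install; not part of the diff):

from modelscope.models.cv import product_segmentation

# Importing the package is cheap; the heavy seg_infer module is only loaded
# lazily when the attribute is first resolved through LazyImportModule.
cls = product_segmentation.F3NetForProductSegmentation
print(cls.__module__)  # modelscope.models.cv.product_segmentation.seg_infer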
+197 -0   modelscope/models/cv/product_segmentation/net.py

@@ -0,0 +1,197 @@
# The implementation here is modified based on F3Net,
# originally Apache 2.0 License and publicly available at https://github.com/weijun88/F3Net

import torch
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=(3 * dilation - 1) // 2,
            bias=False,
            dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.downsample = downsample

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            x = self.downsample(x)
        return F.relu(out + x, inplace=True)


class ResNet(nn.Module):

    def __init__(self):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self.make_layer(64, 3, stride=1, dilation=1)
        self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
        self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
        self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)

    def make_layer(self, planes, blocks, stride, dilation):
        downsample = nn.Sequential(
            nn.Conv2d(
                self.inplanes,
                planes * 4,
                kernel_size=1,
                stride=stride,
                bias=False), nn.BatchNorm2d(planes * 4))
        layers = [
            Bottleneck(
                self.inplanes, planes, stride, downsample, dilation=dilation)
        ]
        self.inplanes = planes * 4
        for _ in range(1, blocks):
            layers.append(Bottleneck(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.reshape(1, 3, 448, 448)
        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
        out2 = self.layer1(out1)
        out3 = self.layer2(out2)
        out4 = self.layer3(out3)
        out5 = self.layer4(out4)
        return out2, out3, out4, out5


class CFM(nn.Module):

    def __init__(self):
        super(CFM, self).__init__()
        self.conv1h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1h = nn.BatchNorm2d(64)
        self.conv2h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2h = nn.BatchNorm2d(64)
        self.conv3h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3h = nn.BatchNorm2d(64)
        self.conv4h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4h = nn.BatchNorm2d(64)

        self.conv1v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1v = nn.BatchNorm2d(64)
        self.conv2v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2v = nn.BatchNorm2d(64)
        self.conv3v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3v = nn.BatchNorm2d(64)
        self.conv4v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4v = nn.BatchNorm2d(64)

    def forward(self, left, down):
        if down.size()[2:] != left.size()[2:]:
            down = F.interpolate(down, size=left.size()[2:], mode='bilinear')
        out1h = F.relu(self.bn1h(self.conv1h(left)), inplace=True)
        out2h = F.relu(self.bn2h(self.conv2h(out1h)), inplace=True)
        out1v = F.relu(self.bn1v(self.conv1v(down)), inplace=True)
        out2v = F.relu(self.bn2v(self.conv2v(out1v)), inplace=True)
        fuse = out2h * out2v
        out3h = F.relu(self.bn3h(self.conv3h(fuse)), inplace=True) + out1h
        out4h = F.relu(self.bn4h(self.conv4h(out3h)), inplace=True)
        out3v = F.relu(self.bn3v(self.conv3v(fuse)), inplace=True) + out1v
        out4v = F.relu(self.bn4v(self.conv4v(out3v)), inplace=True)
        return out4h, out4v


class Decoder(nn.Module):

    def __init__(self):
        super(Decoder, self).__init__()
        self.cfm45 = CFM()
        self.cfm34 = CFM()
        self.cfm23 = CFM()

    def forward(self, out2h, out3h, out4h, out5v, fback=None):
        if fback is not None:
            refine5 = F.interpolate(
                fback, size=out5v.size()[2:], mode='bilinear')
            refine4 = F.interpolate(
                fback, size=out4h.size()[2:], mode='bilinear')
            refine3 = F.interpolate(
                fback, size=out3h.size()[2:], mode='bilinear')
            refine2 = F.interpolate(
                fback, size=out2h.size()[2:], mode='bilinear')
            out5v = out5v + refine5
            out4h, out4v = self.cfm45(out4h + refine4, out5v)
            out3h, out3v = self.cfm34(out3h + refine3, out4v)
            out2h, pred = self.cfm23(out2h + refine2, out3v)
        else:
            out4h, out4v = self.cfm45(out4h, out5v)
            out3h, out3v = self.cfm34(out3h, out4v)
            out2h, pred = self.cfm23(out2h, out3v)
        return out2h, out3h, out4h, out5v, pred


class F3Net(nn.Module):

    def __init__(self):
        super(F3Net, self).__init__()
        self.bkbone = ResNet()
        self.squeeze5 = nn.Sequential(
            nn.Conv2d(2048, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze4 = nn.Sequential(
            nn.Conv2d(1024, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze3 = nn.Sequential(
            nn.Conv2d(512, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze2 = nn.Sequential(
            nn.Conv2d(256, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        self.decoder1 = Decoder()
        self.decoder2 = Decoder()
        self.linearp1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearp2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

        self.linearr2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr5 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x, shape=None):
        x = x.reshape(1, 3, 448, 448)
        out2h, out3h, out4h, out5v = self.bkbone(x)
        out2h, out3h, out4h, out5v = self.squeeze2(out2h), self.squeeze3(
            out3h), self.squeeze4(out4h), self.squeeze5(out5v)
        out2h, out3h, out4h, out5v, pred1 = self.decoder1(
            out2h, out3h, out4h, out5v)
        out2h, out3h, out4h, out5v, pred2 = self.decoder2(
            out2h, out3h, out4h, out5v, pred1)

        shape = x.size()[2:] if shape is None else shape
        pred1 = F.interpolate(
            self.linearp1(pred1), size=shape, mode='bilinear')
        pred2 = F.interpolate(
            self.linearp2(pred2), size=shape, mode='bilinear')

        out2h = F.interpolate(
            self.linearr2(out2h), size=shape, mode='bilinear')
        out3h = F.interpolate(
            self.linearr3(out3h), size=shape, mode='bilinear')
        out4h = F.interpolate(
            self.linearr4(out4h), size=shape, mode='bilinear')
        out5h = F.interpolate(
            self.linearr5(out5v), size=shape, mode='bilinear')
        return pred1, pred2, out2h, out3h, out4h, out5h

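The forward paths above hard-code the input to a single 448x448 RGB image (both ResNet.forward and F3Net.forward reshape to (1, 3, 448, 448)) and return two cascaded predictions plus four side outputs, all upsampled back to the input size. A quick shape check with random weights (illustrative sketch, not part of the commit):

import torch

from modelscope.models.cv.product_segmentation.net import F3Net

net = F3Net().eval()             # random weights; this only checks tensor shapes
x = torch.randn(1, 3, 448, 448)  # batch size is effectively fixed to 1 by the reshape
with torch.no_grad():
    pred1, pred2, out2h, out3h, out4h, out5h = net(x)
print(pred2.shape)               # torch.Size([1, 1, 448, 448]); pred2 is the refined map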
+77 -0   modelscope/models/cv/product_segmentation/seg_infer.py

@@ -0,0 +1,77 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import cv2
import numpy as np
import torch
from PIL import Image

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .net import F3Net

logger = get_logger()


def load_state_dict(model_dir, device):
    _dict = torch.load(
        '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
        map_location=device)
    state_dict = {}
    for k, v in _dict.items():
        if k.startswith('module'):
            k = k[7:]
        state_dict[k] = v
    return state_dict


@MODELS.register_module(
    Tasks.product_segmentation, module_name=Models.product_segmentation)
class F3NetForProductSegmentation(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):

        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)

        self.model = F3Net()
        if torch.cuda.is_available():
            self.device = 'cuda'
            logger.info('Use GPU')
        else:
            self.device = 'cpu'
            logger.info('Use CPU')

        self.params = load_state_dict(model_dir, self.device)
        self.model.load_state_dict(self.params)
        self.model.to(self.device)
        self.model.eval()
        self.model.to(self.device)

    def forward(self, x):
        pred_result = self.model(x)
        return pred_result


mean, std = np.array([[[124.55, 118.90,
                        102.94]]]), np.array([[[56.77, 55.97, 57.50]]])


def inference(model, device, input_path):
    img = Image.open(input_path)
    img = np.array(img.convert('RGB')).astype(np.float32)
    img = (img - mean) / std
    img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR)
    img = torch.from_numpy(img)
    img = img.permute(2, 0, 1)
    img = img.to(device).float()
    outputs = model(img)
    out = outputs[0]
    pred = (torch.sigmoid(out[0, 0]) * 255).cpu().numpy()
    pred[pred < 20] = 0
    pred = pred[:, :, np.newaxis]
    pred = np.round(pred)
    logger.info('Inference Done')
    return pred

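inference() normalizes the image with fixed per-channel mean/std (stated in the 0-255 range), resizes it to 448x448, feeds the CHW tensor to the model, then takes the first output map, squashes it through a sigmoid into 0-255, zeroes values below 20 and returns a (448, 448, 1) soft mask. Callers that want a hard {0, 255} mask at the original resolution could post-process it along these lines (sketch only; the helper name and the 128 threshold are illustrative assumptions, not part of this commit):

import cv2
import numpy as np


def to_binary_mask(pred, orig_h, orig_w, thresh=128):
    # pred: (448, 448, 1) soft mask from inference(); thresh is an assumed cut-off.
    mask = cv2.resize(pred.squeeze(-1), (orig_w, orig_h),
                      interpolation=cv2.INTER_LINEAR)
    return np.where(mask >= thresh, 255, 0).astype(np.uint8)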
+8 -1   modelscope/outputs.py

@@ -674,5 +674,12 @@ TASK_OUTPUTS = {
    # {
    #   {'output': 'Happiness', 'boxes': (203, 104, 663, 564)}
    # }
    Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES]
    Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES],

    # {
    #   "masks": [
    #       np.array # 2D array containing only 0, 255
    #   ]
    # }
    Tasks.product_segmentation: [OutputKeys.MASKS],
}

+2 -0   modelscope/pipelines/builder.py

@@ -187,6 +187,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    (Pipelines.face_human_hand_detection,
     'damo/cv_nanodet_face-human-hand-detection'),
    Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'),
    Tasks.product_segmentation: (Pipelines.product_segmentation,
                                 'damo/cv_F3Net_product-segmentation'),
}




+40 -0   modelscope/pipelines/cv/product_segmentation_pipeline.py

@@ -0,0 +1,40 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.product_segmentation import seg_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.product_segmentation, module_name=Pipelines.product_segmentation)
class F3NetForProductSegmentationPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create a product segmentation pipeline for prediction.
        Args:
            model: model id on the ModelScope hub.
        """

        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:

        mask = seg_infer.inference(self.model, self.device,
                                   input['input_path'])
        return {OutputKeys.MASKS: mask}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

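Since preprocess() passes its argument straight through, the pipeline has to be called with the dict that forward() expects, i.e. an 'input_path' key pointing at a local image file. A minimal call sketch (not part of the diff; model id as registered in builder.py above):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

seg = pipeline(Tasks.product_segmentation,
               model='damo/cv_F3Net_product-segmentation')
result = seg({'input_path': 'data/test/images/product_segmentation.jpg'})
mask = result[OutputKeys.MASKS]  # (448, 448, 1) soft mask in the 0-255 range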
+1 -0   modelscope/utils/constant.py

@@ -45,6 +45,7 @@ class CVTasks(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'

    # image editing
    skin_retouching = 'skin-retouching'


+43 -0   tests/pipelines/test_product_segmentation.py

@@ -0,0 +1,43 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class ProductSegmentationTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_F3Net_product-segmentation'
        self.input = {
            'input_path': 'data/test/images/product_segmentation.jpg'
        }

    def pipeline_inference(self, pipeline: Pipeline, input: str):
        result = pipeline(input)
        cv2.imwrite('test_product_segmentation_mask.jpg',
                    result[OutputKeys.MASKS])
        logger.info('test done')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        product_segmentation = pipeline(
            Tasks.product_segmentation, model=self.model_id)
        self.pipeline_inference(product_segmentation, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        product_segmentation = pipeline(Tasks.product_segmentation)
        self.pipeline_inference(product_segmentation, self.input)


if __name__ == '__main__':
    unittest.main()
