Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10252583 (target branch: master)
data/test/images/product_segmentation.jpg (Git LFS pointer)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a16038f7809127eb3e03cbae049592d193707e095309daca78f7d108d67fe4ec
size 108357
modelscope/metainfo.py
@@ -42,6 +42,7 @@ class Models(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'

    # EasyCV models
    yolox = 'YOLOX'

@@ -185,6 +186,7 @@ class Pipelines(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'
modelscope/models/cv/product_segmentation/__init__.py
@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .seg_infer import F3NetForProductSegmentation
else:
    _import_structure = {'seg_infer': ['F3NetForProductSegmentation']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
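Reviewer note: the exported name was fixed to F3NetForProductSegmentation to match the class actually defined in seg_infer.py below. For anyone unfamiliar with the pattern, LazyImportModule replaces this package in sys.modules, so seg_infer (and its torch dependency) is only imported on first attribute access. A minimal sketch of the intended behavior, not part of the diff:

# Sketch: importing the package does not yet import .seg_infer.
import sys

import modelscope.models.cv.product_segmentation as ps
assert 'modelscope.models.cv.product_segmentation.seg_infer' not in sys.modules

# First attribute access triggers the real import via LazyImportModule.
model_cls = ps.F3NetForProductSegmentation
assert 'modelscope.models.cv.product_segmentation.seg_infer' in sys.modules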
modelscope/models/cv/product_segmentation/net.py
@@ -0,0 +1,197 @@
# The implementation here is modified based on F3Net,
# originally Apache 2.0 License and publicly available at https://github.com/weijun88/F3Net
import torch
import torch.nn as nn
import torch.nn.functional as F

class Bottleneck(nn.Module):

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=(3 * dilation - 1) // 2,
            bias=False,
            dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.downsample = downsample

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            x = self.downsample(x)
        return F.relu(out + x, inplace=True)

class ResNet(nn.Module):

    def __init__(self):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self.make_layer(64, 3, stride=1, dilation=1)
        self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
        self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
        self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)

    def make_layer(self, planes, blocks, stride, dilation):
        downsample = nn.Sequential(
            nn.Conv2d(
                self.inplanes,
                planes * 4,
                kernel_size=1,
                stride=stride,
                bias=False), nn.BatchNorm2d(planes * 4))
        layers = [
            Bottleneck(
                self.inplanes, planes, stride, downsample, dilation=dilation)
        ]
        self.inplanes = planes * 4
        for _ in range(1, blocks):
            layers.append(Bottleneck(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        # The network assumes a single fixed-size 448 x 448 RGB input; this
        # reshape is redundant with the identical one in F3Net.forward.
        x = x.reshape(1, 3, 448, 448)
        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
        out2 = self.layer1(out1)
        out3 = self.layer2(out2)
        out4 = self.layer3(out3)
        out5 = self.layer4(out4)
        return out2, out3, out4, out5

class CFM(nn.Module):
    """Cross feature module (CFM) from F3Net: fuses a lateral backbone
    feature ('left') with a top-down decoder feature ('down')."""

    def __init__(self):
        super(CFM, self).__init__()
        self.conv1h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1h = nn.BatchNorm2d(64)
        self.conv2h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2h = nn.BatchNorm2d(64)
        self.conv3h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3h = nn.BatchNorm2d(64)
        self.conv4h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4h = nn.BatchNorm2d(64)
        self.conv1v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1v = nn.BatchNorm2d(64)
        self.conv2v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2v = nn.BatchNorm2d(64)
        self.conv3v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3v = nn.BatchNorm2d(64)
        self.conv4v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4v = nn.BatchNorm2d(64)

    def forward(self, left, down):
        if down.size()[2:] != left.size()[2:]:
            down = F.interpolate(down, size=left.size()[2:], mode='bilinear')
        out1h = F.relu(self.bn1h(self.conv1h(left)), inplace=True)
        out2h = F.relu(self.bn2h(self.conv2h(out1h)), inplace=True)
        out1v = F.relu(self.bn1v(self.conv1v(down)), inplace=True)
        out2v = F.relu(self.bn2v(self.conv2v(out1v)), inplace=True)
        fuse = out2h * out2v
        out3h = F.relu(self.bn3h(self.conv3h(fuse)), inplace=True) + out1h
        out4h = F.relu(self.bn4h(self.conv4h(out3h)), inplace=True)
        out3v = F.relu(self.bn3v(self.conv3v(fuse)), inplace=True) + out1v
        out4v = F.relu(self.bn4v(self.conv4v(out3v)), inplace=True)
        return out4h, out4v

class Decoder(nn.Module):

    def __init__(self):
        super(Decoder, self).__init__()
        self.cfm45 = CFM()
        self.cfm34 = CFM()
        self.cfm23 = CFM()

    def forward(self, out2h, out3h, out4h, out5v, fback=None):
        if fback is not None:
            refine5 = F.interpolate(
                fback, size=out5v.size()[2:], mode='bilinear')
            refine4 = F.interpolate(
                fback, size=out4h.size()[2:], mode='bilinear')
            refine3 = F.interpolate(
                fback, size=out3h.size()[2:], mode='bilinear')
            refine2 = F.interpolate(
                fback, size=out2h.size()[2:], mode='bilinear')
            out5v = out5v + refine5
            out4h, out4v = self.cfm45(out4h + refine4, out5v)
            out3h, out3v = self.cfm34(out3h + refine3, out4v)
            out2h, pred = self.cfm23(out2h + refine2, out3v)
        else:
            out4h, out4v = self.cfm45(out4h, out5v)
            out3h, out3v = self.cfm34(out3h, out4v)
            out2h, pred = self.cfm23(out2h, out3v)
        return out2h, out3h, out4h, out5v, pred

class F3Net(nn.Module):

    def __init__(self):
        super(F3Net, self).__init__()
        self.bkbone = ResNet()
        self.squeeze5 = nn.Sequential(
            nn.Conv2d(2048, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze4 = nn.Sequential(
            nn.Conv2d(1024, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze3 = nn.Sequential(
            nn.Conv2d(512, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze2 = nn.Sequential(
            nn.Conv2d(256, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.decoder1 = Decoder()
        self.decoder2 = Decoder()
        self.linearp1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearp2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr5 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x, shape=None):
        x = x.reshape(1, 3, 448, 448)
        out2h, out3h, out4h, out5v = self.bkbone(x)
        out2h, out3h, out4h, out5v = self.squeeze2(out2h), self.squeeze3(
            out3h), self.squeeze4(out4h), self.squeeze5(out5v)
        out2h, out3h, out4h, out5v, pred1 = self.decoder1(
            out2h, out3h, out4h, out5v)
        out2h, out3h, out4h, out5v, pred2 = self.decoder2(
            out2h, out3h, out4h, out5v, pred1)
        shape = x.size()[2:] if shape is None else shape
        pred1 = F.interpolate(
            self.linearp1(pred1), size=shape, mode='bilinear')
        pred2 = F.interpolate(
            self.linearp2(pred2), size=shape, mode='bilinear')
        out2h = F.interpolate(
            self.linearr2(out2h), size=shape, mode='bilinear')
        out3h = F.interpolate(
            self.linearr3(out3h), size=shape, mode='bilinear')
        out4h = F.interpolate(
            self.linearr4(out4h), size=shape, mode='bilinear')
        out5h = F.interpolate(
            self.linearr5(out5v), size=shape, mode='bilinear')
        return pred1, pred2, out2h, out3h, out4h, out5h
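Reviewer note: a quick shape check for the network above; with random weights, a forward pass on a 448 x 448 input should return six single-channel maps upsampled back to input resolution. A self-contained sketch, not part of the diff:

# Smoke test: F3Net on a dummy 448 x 448 RGB tensor (eval mode, random weights).
import torch

net = F3Net().eval()
dummy = torch.randn(1, 3, 448, 448)
with torch.no_grad():
    pred1, pred2, out2h, out3h, out4h, out5h = net(dummy)
print(pred1.shape)  # torch.Size([1, 1, 448, 448]); same for the other five maps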
modelscope/models/cv/product_segmentation/seg_infer.py
@@ -0,0 +1,77 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import cv2
import numpy as np
import torch
from PIL import Image

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

from .net import F3Net

logger = get_logger()

def load_state_dict(model_dir, device):
    _dict = torch.load(
        '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
        map_location=device)
    state_dict = {}
    for k, v in _dict.items():
        # Strip the 'module.' prefix left over from DataParallel checkpoints.
        if k.startswith('module'):
            k = k[7:]
        state_dict[k] = v
    return state_dict

@MODELS.register_module(
    Tasks.product_segmentation, module_name=Models.product_segmentation)
class F3NetForProductSegmentation(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)
        self.model = F3Net()
        if torch.cuda.is_available():
            self.device = 'cuda'
            logger.info('Use GPU')
        else:
            self.device = 'cpu'
            logger.info('Use CPU')
        self.params = load_state_dict(model_dir, self.device)
        self.model.load_state_dict(self.params)
        self.model.to(self.device)
        self.model.eval()

    def forward(self, x):
        pred_result = self.model(x)
        return pred_result

# Per-channel RGB mean / std used for input normalization.
mean = np.array([[[124.55, 118.90, 102.94]]])
std = np.array([[[56.77, 55.97, 57.50]]])


def inference(model, device, input_path):
    img = Image.open(input_path)
    img = np.array(img.convert('RGB')).astype(np.float32)
    img = (img - mean) / std
    img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR)
    img = torch.from_numpy(img)
    img = img.permute(2, 0, 1)
    img = img.to(device).float()
    outputs = model(img)
    out = outputs[0]
    pred = (torch.sigmoid(out[0, 0]) * 255).cpu().numpy()
    # Zero out low-confidence pixels, then round to integer gray levels.
    pred[pred < 20] = 0
    pred = pred[:, :, np.newaxis]
    pred = np.round(pred)
    logger.info('Inference Done')
    return pred
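Reviewer note: for local verification outside the pipeline, the model can be driven directly. A sketch assuming model_dir contains pytorch_model.bin (ModelFile.TORCH_MODEL_BIN_FILE); both paths here are hypothetical:

# Sketch: standalone inference without the pipeline wrapper.
import cv2

model = F3NetForProductSegmentation(model_dir='/path/to/model_dir')
mask = inference(model, model.device, '/path/to/product.jpg')
cv2.imwrite('mask.jpg', mask)  # 448 x 448 x 1 gray mask, values in [0, 255]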
modelscope/outputs.py
@@ -674,5 +674,12 @@ TASK_OUTPUTS = {
    # {
    #     {'output': 'Happiness', 'boxes': (203, 104, 663, 564)}
    # }
    Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES],

    # {
    #     "masks": np.ndarray  # 448 x 448 x 1 gray mask, values in [0, 255];
    #                          # pixels below confidence threshold 20 are zeroed
    # }
    Tasks.product_segmentation: [OutputKeys.MASKS],
}
modelscope/pipelines/builder.py
@@ -187,6 +187,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    (Pipelines.face_human_hand_detection,
     'damo/cv_nanodet_face-human-hand-detection'),
    Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'),
    Tasks.product_segmentation: (Pipelines.product_segmentation,
                                 'damo/cv_F3Net_product-segmentation'),
}
modelscope/pipelines/cv/product_segmentation_pipeline.py
@@ -0,0 +1,40 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.product_segmentation import seg_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.product_segmentation, module_name=Pipelines.product_segmentation)
class F3NetForProductSegmentationPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a product segmentation pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        # All preprocessing happens inside seg_infer.inference; pass through.
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        mask = seg_infer.inference(self.model, self.device,
                                   input['input_path'])
        return {OutputKeys.MASKS: mask}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
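Reviewer note: end-to-end usage mirrors the unit test below; a minimal sketch using the model id, input key, and test asset from this diff:

# Sketch: invoke the registered pipeline by task name.
import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

seg = pipeline(Tasks.product_segmentation, model='damo/cv_F3Net_product-segmentation')
result = seg({'input_path': 'data/test/images/product_segmentation.jpg'})
cv2.imwrite('product_mask.jpg', result[OutputKeys.MASKS])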
modelscope/utils/constant.py
@@ -45,6 +45,7 @@ class CVTasks(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'

    # image editing
    skin_retouching = 'skin-retouching'
tests/pipelines/test_product_segmentation.py
@@ -0,0 +1,43 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest
from typing import Any, Dict

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class ProductSegmentationTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_F3Net_product-segmentation'
        self.input = {
            'input_path': 'data/test/images/product_segmentation.jpg'
        }

    def pipeline_inference(self, pipeline: Pipeline, input: Dict[str, Any]):
        result = pipeline(input)
        cv2.imwrite('test_product_segmentation_mask.jpg',
                    result[OutputKeys.MASKS])
        logger.info('test done')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        product_segmentation = pipeline(
            Tasks.product_segmentation, model=self.model_id)
        self.pipeline_inference(product_segmentation, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        product_segmentation = pipeline(Tasks.product_segmentation)
        self.pipeline_inference(product_segmentation, self.input)


if __name__ == '__main__':
    unittest.main()