Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10244616master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94b8e281d77ee6d3ea2a8a0c9408ecdbd29fe75f33ea5399b6ea00070ba77bd6
size 13090
@@ -39,6 +39,7 @@ class Models(object):
    mtcnn = 'mtcnn'
    ulfd = 'ulfd'
    video_inpainting = 'video-inpainting'
    hand_static = 'hand-static'

    # EasyCV models
    yolox = 'YOLOX'
@@ -173,6 +174,7 @@ class Pipelines(object):
    movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
    shop_segmentation = 'shop-segmentation'
    video_inpainting = 'video-inpainting'
    hand_static = 'hand-static'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'
@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .hand_model import HandStatic

else:
    _import_structure = {'hand_model': ['HandStatic']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
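With this __init__.py in place, hand_model.py is only imported when HandStatic is first accessed. A minimal sketch of the import path it exposes (assuming the package lands at modelscope.models.cv.hand_static, which is the path the pipeline file below imports from):

# Lazy path: hand_model.py is loaded on first attribute access.
from modelscope.models.cv.hand_static import HandStatic

# Equivalent eager path, as used by the pipeline added in this PR:
# from modelscope.models.cv.hand_static.hand_model import HandStatic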
@@ -0,0 +1,93 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision.transforms import transforms

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .networks import StaticGestureNet

logger = get_logger()

map_idx = {
    0: 'unrecog',
    1: 'one',
    2: 'two',
    3: 'bixin',
    4: 'yaogun',
    5: 'zan',
    6: 'fist',
    7: 'ok',
    8: 'tuoju',
    9: 'd_bixin',
    10: 'd_fist_left',
    11: 'd_fist_right',
    12: 'd_hand',
    13: 'fashe',
    14: 'five',
    15: 'nohand'
}

img_size = [112, 112]
spatial_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


@MODELS.register_module(Tasks.hand_static, module_name=Models.hand_static)
class HandStatic(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)
        self.model = StaticGestureNet()
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.params = torch.load(
            '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
            map_location=self.device)
        self.model.load_state_dict(self.params)
        self.model.to(self.device)
        self.model.eval()

        self.device_id = device_id
        if self.device_id >= 0 and self.device == 'cuda':
            self.model.to('cuda:{}'.format(self.device_id))
            logger.info('Use GPU: {}'.format(self.device_id))
        else:
            self.device_id = -1
            logger.info('Use CPU for inference')

    def forward(self, x):
        pred_result = self.model(x)
        return pred_result
def infer(img_path, model, device):
    # Force a 3-channel image so the 3-channel Normalize transform always applies.
    img = Image.open(img_path).convert('RGB')
    clip = spatial_transform(img)
    clip = clip.unsqueeze(0).to(device).float()
    with torch.no_grad():
        outputs = model(clip)
    predicted = int(outputs.max(1)[1])
    pred_result = map_idx.get(predicted)
    logger.info('pred result: {}'.format(pred_result))
    return pred_result
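For reference, a minimal sketch of driving hand_model.py directly rather than through the pipeline; the model_dir value is a hypothetical local snapshot directory, while the image path is the LFS test asset added in this PR:

from modelscope.models.cv.hand_static.hand_model import HandStatic, infer

model_dir = '/path/to/cv_mobileface_hand-static'  # hypothetical local snapshot directory
model = HandStatic(model_dir=model_dir)           # loads ModelFile.TORCH_MODEL_BIN_FILE from model_dir
label = infer('data/test/images/hand_static.jpg', model, model.device)
print(label)                                      # one of the map_idx labels, e.g. 'bixin'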
@@ -0,0 +1,358 @@ | |||||
""" HandStatic | |||||
The implementation here is modified based on MobileFaceNet, | |||||
originally Apache 2.0 License and publicly avaialbe at https://github.com/xuexingyu24/MobileFaceNet_Tutorial_Pytorch | |||||
""" | |||||
import os | |||||
import torch | |||||
import torch.nn as nn | |||||
import torchvision | |||||
import torchvision.models as models | |||||
from torch.nn import (AdaptiveAvgPool2d, BatchNorm1d, BatchNorm2d, Conv2d, | |||||
Dropout, Linear, MaxPool2d, Module, PReLU, ReLU, | |||||
Sequential, Sigmoid) | |||||
class StaticGestureNet(torch.nn.Module):

    def __init__(self, train=True):
        super().__init__()
        model = MobileFaceNet(512)
        self.feature_extractor = model
        self.fc_layer = torch.nn.Sequential(
            nn.Linear(512, 128), nn.Softplus(), nn.Linear(128, 15))
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        out = self.feature_extractor(inputs)
        out = self.fc_layer(out)
        out = self.sigmoid(out)
        return out


class Flatten(Module):

    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class SEModule(Module):

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels,
            channels // reduction,
            kernel_size=1,
            padding=0,
            bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction,
            channels,
            kernel_size=1,
            padding=0,
            bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class BottleneckIR(Module):

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class BottleneckIRSE(Module):

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIRSE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth), SEModule(depth, 16))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """Named tuple describing a residual block; Backbone reads its
    in_channel, depth and stride fields when instantiating the units."""


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)
            ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    return blocks
class Backbone(Module):

    def __init__(self, num_layers, drop_ratio, mode='ir'):
        super(Backbone, self).__init__()
        assert num_layers in [50, 100,
                              152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = BottleneckIR
        elif mode == 'ir_se':
            unit_module = BottleneckIRSE
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        self.output_layer = Sequential(
            BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
            Linear(512 * 7 * 7, 512), BatchNorm1d(512))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


class ConvBlock(Module):

    def __init__(self,
                 in_c,
                 out_c,
                 kernel=(1, 1),
                 stride=(1, 1),
                 padding=(0, 0),
                 groups=1):
        super(ConvBlock, self).__init__()
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x


class LinearBlock(Module):

    def __init__(self,
                 in_c,
                 out_c,
                 kernel=(1, 1),
                 stride=(1, 1),
                 padding=(0, 0),
                 groups=1):
        super(LinearBlock, self).__init__()
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class DepthWise(Module):

    def __init__(self,
                 in_c,
                 out_c,
                 residual=False,
                 kernel=(3, 3),
                 stride=(2, 2),
                 padding=(1, 1),
                 groups=1):
        super(DepthWise, self).__init__()
        self.conv = ConvBlock(
            in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = ConvBlock(
            groups,
            groups,
            groups=groups,
            kernel=kernel,
            padding=padding,
            stride=stride)
        self.project = LinearBlock(
            groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):

    def __init__(self,
                 c,
                 num_block,
                 groups,
                 kernel=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(
                DepthWise(
                    c,
                    c,
                    residual=True,
                    kernel=kernel,
                    padding=padding,
                    stride=stride,
                    groups=groups))
        self.model = Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileFaceNet(Module):

    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        self.conv1 = ConvBlock(
            3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = ConvBlock(
            64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.conv_23 = DepthWise(
            64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(
            64,
            num_block=4,
            groups=128,
            kernel=(3, 3),
            stride=(1, 1),
            padding=(1, 1))
        self.conv_34 = DepthWise(
            64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(
            128,
            num_block=6,
            groups=256,
            kernel=(3, 3),
            stride=(1, 1),
            padding=(1, 1))
        self.conv_45 = DepthWise(
            128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(
            128,
            num_block=2,
            groups=256,
            kernel=(3, 3),
            stride=(1, 1),
            padding=(1, 1))
        self.conv_6_sep = ConvBlock(
            128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.conv_6_dw = LinearBlock(
            512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2_dw(out)
        out = self.conv_23(out)
        out = self.conv_3(out)
        out = self.conv_34(out)
        out = self.conv_4(out)
        out = self.conv_45(out)
        out = self.conv_5(out)
        out = self.conv_6_sep(out)
        out = self.conv_6_dw(out)
        out = self.conv_6_flatten(out)
        out = self.linear(out)
        return l2_norm(out)
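As a quick shape check (a sketch, not part of the diff; it assumes this file lands as networks.py inside the hand_static package, which is what the relative import in hand_model.py suggests), the network accepts the 112x112 crops produced by spatial_transform and returns one sigmoid score per gesture class:

import torch
from modelscope.models.cv.hand_static.networks import StaticGestureNet

net = StaticGestureNet()
net.eval()
with torch.no_grad():
    scores = net(torch.randn(1, 3, 112, 112))  # one RGB crop at the 112x112 inference size
print(scores.shape)  # torch.Size([1, 15]), one sigmoid score per gesture class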
@@ -632,5 +632,9 @@ TASK_OUTPUTS = {
    # {
    #   'output': ['Done' / 'Decode_Error']
    # }
    Tasks.video_inpainting: [OutputKeys.OUTPUT],

    # {
    #   'output': ['bixin']
    # }
    Tasks.hand_static: [OutputKeys.OUTPUT]
}
@@ -178,6 +178,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                              'damo/cv_vitb16_segmentation_shop-seg'),
    Tasks.video_inpainting: (Pipelines.video_inpainting,
                             'damo/cv_video-inpainting'),
    Tasks.hand_static: (Pipelines.hand_static,
                        'damo/cv_mobileface_hand-static'),
}
@@ -52,7 +52,8 @@ if TYPE_CHECKING:
    from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
    from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
    from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
    from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline
    from .hand_static_pipeline import HandStaticPipeline

else:
    _import_structure = {
@@ -119,6 +120,7 @@ else: | |||||
'facial_expression_recognition_pipelin': | 'facial_expression_recognition_pipelin': | ||||
['FacialExpressionRecognitionPipeline'], | ['FacialExpressionRecognitionPipeline'], | ||||
'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], | 'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'], | ||||
'hand_static_pipeline': ['HandStaticPipeline'], | |||||
} | } | ||||
import sys | import sys | ||||
@@ -0,0 +1,37 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.hand_static import hand_model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.hand_static, module_name=Pipelines.hand_static)
class HandStaticPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a hand static pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        result = hand_model.infer(input['img_path'], self.model, self.device)
        return {OutputKeys.OUTPUT: result}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
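A minimal end-to-end usage sketch, mirroring the unit test below; the model id and input dict come from the test, and the label is returned under OutputKeys.OUTPUT:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

hand_static = pipeline(Tasks.hand_static, model='damo/cv_mobileface_hand-static')
result = hand_static({'img_path': 'data/test/images/hand_static.jpg'})
print(result[OutputKeys.OUTPUT])  # gesture label, e.g. 'bixin'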
@@ -42,6 +42,7 @@ class CVTasks(object):
    portrait_matting = 'portrait-matting'
    text_driven_segmentation = 'text-driven-segmentation'
    shop_segmentation = 'shop-segmentation'
    hand_static = 'hand-static'

    # image editing
    skin_retouching = 'skin-retouching'
@@ -0,0 +1,32 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class HandStaticTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model = 'damo/cv_mobileface_hand-static'
        self.input = {'img_path': 'data/test/images/hand_static.jpg'}

    def pipeline_inference(self, pipeline: Pipeline, input: str):
        result = pipeline(input)
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        hand_static = pipeline(Tasks.hand_static, model=self.model)
        self.pipeline_inference(hand_static, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        hand_static = pipeline(Tasks.hand_static)
        self.pipeline_inference(hand_static, self.input)


if __name__ == '__main__':
    unittest.main()