tingwei.gtw · yingda.chen · 3 years ago
commit 372adb3936
11 changed files with 556 additions and 2 deletions
  1. data/test/images/hand_static.jpg (+3, -0)
  2. modelscope/metainfo.py (+2, -0)
  3. modelscope/models/cv/hand_static/__init__.py (+20, -0)
  4. modelscope/models/cv/hand_static/hand_model.py (+93, -0)
  5. modelscope/models/cv/hand_static/networks.py (+358, -0)
  6. modelscope/outputs.py (+5, -1)
  7. modelscope/pipelines/builder.py (+2, -0)
  8. modelscope/pipelines/cv/__init__.py (+3, -1)
  9. modelscope/pipelines/cv/hand_static_pipeline.py (+37, -0)
  10. modelscope/utils/constant.py (+1, -0)
  11. tests/pipelines/test_hand_static.py (+32, -0)

data/test/images/hand_static.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94b8e281d77ee6d3ea2a8a0c9408ecdbd29fe75f33ea5399b6ea00070ba77bd6
size 13090

modelscope/metainfo.py (+2, -0)

@@ -39,6 +39,7 @@ class Models(object):
mtcnn = 'mtcnn'
ulfd = 'ulfd'
video_inpainting = 'video-inpainting'
hand_static = 'hand-static'

# EasyCV models
yolox = 'YOLOX'
@@ -173,6 +174,7 @@ class Pipelines(object):
movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
shop_segmentation = 'shop-segmentation'
video_inpainting = 'video-inpainting'
hand_static = 'hand-static'

# nlp tasks
sentence_similarity = 'sentence-similarity'
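
The model type, pipeline name and task name introduced by this commit all happen to use the same 'hand-static' string (see also constant.py below). A trivial check of the new identifiers, as a sketch rather than part of the commit:

# Sketch: the three identifiers added in this commit share one string value.
from modelscope.metainfo import Models, Pipelines
from modelscope.utils.constant import Tasks

assert Models.hand_static == Pipelines.hand_static == Tasks.hand_static == 'hand-static'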


modelscope/models/cv/hand_static/__init__.py (+20, -0)

@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .hand_model import HandStatic

else:
_import_structure = {'hand_model': ['HandStatic']}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
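
The LazyImportModule wiring follows the pattern used by the other cv model packages: the hand_model submodule (and its torch/cv2 dependencies) is only imported when the name is first resolved. A minimal sketch of the effect, assuming the package is installed:

# Sketch (not part of the commit): resolving HandStatic through the package
# goes via LazyImportModule, which imports .hand_model on first access.
from modelscope.models.cv.hand_static import HandStatic

print(HandStatic)  # <class 'modelscope.models.cv.hand_static.hand_model.HandStatic'>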

modelscope/models/cv/hand_static/hand_model.py (+93, -0)

@@ -0,0 +1,93 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision.transforms import transforms

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .networks import StaticGestureNet

logger = get_logger()

map_idx = {
0: 'unrecog',
1: 'one',
2: 'two',
3: 'bixin',
4: 'yaogun',
5: 'zan',
6: 'fist',
7: 'ok',
8: 'tuoju',
9: 'd_bixin',
10: 'd_fist_left',
11: 'd_fist_right',
12: 'd_hand',
13: 'fashe',
14: 'five',
15: 'nohand'
}

img_size = [112, 112]

spatial_transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


@MODELS.register_module(Tasks.hand_static, module_name=Models.hand_static)
class HandStatic(TorchModel):

def __init__(self, model_dir, device_id=0, *args, **kwargs):

super().__init__(
model_dir=model_dir, device_id=device_id, *args, **kwargs)

self.model = StaticGestureNet()
if torch.cuda.is_available():
self.device = 'cuda'
else:
self.device = 'cpu'
self.params = torch.load(
'{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
map_location=self.device)

self.model.load_state_dict(self.params)
self.model.to(self.device)
self.model.eval()
self.device_id = device_id
if self.device_id >= 0 and self.device == 'cuda':
self.model.to('cuda:{}'.format(self.device_id))
logger.info('Use GPU: {}'.format(self.device_id))
else:
self.device_id = -1
logger.info('Use CPU for inference')

def forward(self, x):
pred_result = self.model(x)
return pred_result


def infer(img_path, model, device):

img = Image.open(img_path)
clip = spatial_transform(img)
clip = clip.unsqueeze(0).to(device).float()
outputs = model(clip)
predicted = int(outputs.max(1)[1])
pred_result = map_idx.get(predicted)
logger.info('pred result: {}'.format(pred_result))

return pred_result
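
Taken together, hand_model.py provides the registered model class plus a module-level infer() helper that the pipeline calls directly. A rough usage sketch outside the pipeline follows; the local model_dir is a placeholder and must contain the TORCH_MODEL_BIN_FILE checkpoint:

# Rough sketch, assuming a locally downloaded checkpoint directory.
from modelscope.models.cv.hand_static.hand_model import HandStatic, infer

model_dir = '/path/to/cv_mobileface_hand-static'  # placeholder path
model = HandStatic(model_dir=model_dir)

# infer() applies spatial_transform (resize to 112x112, normalize to [-1, 1]),
# runs the network and maps the argmax index through map_idx to a label.
label = infer('data/test/images/hand_static.jpg', model, model.device)
print(label)  # one of the map_idx values, e.g. 'bixin'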

modelscope/models/cv/hand_static/networks.py (+358, -0)

@@ -0,0 +1,358 @@
""" HandStatic
The implementation here is modified based on MobileFaceNet,
originally Apache 2.0 License and publicly available at https://github.com/xuexingyu24/MobileFaceNet_Tutorial_Pytorch
"""

import os
from collections import namedtuple

import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from torch.nn import (AdaptiveAvgPool2d, BatchNorm1d, BatchNorm2d, Conv2d,
Dropout, Linear, MaxPool2d, Module, PReLU, ReLU,
Sequential, Sigmoid)


class StaticGestureNet(torch.nn.Module):

def __init__(self, train=True):
super().__init__()

model = MobileFaceNet(512)
self.feature_extractor = model
self.fc_layer = torch.nn.Sequential(
nn.Linear(512, 128), nn.Softplus(), nn.Linear(128, 15))
self.sigmoid = nn.Sigmoid()

def forward(self, inputs):
out = self.feature_extractor(inputs)
out = self.fc_layer(out)
out = self.sigmoid(out)
return out


class Flatten(Module):

def forward(self, input):
return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output


class SEModule(Module):

def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(
channels,
channels // reduction,
kernel_size=1,
padding=0,
bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction,
channels,
kernel_size=1,
padding=0,
bias=False)
self.sigmoid = Sigmoid()

def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x


class BottleneckIR(Module):

def __init__(self, in_channel, depth, stride):
super(BottleneckIR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth))

def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut


class BottleneckIRSE(Module):

def __init__(self, in_channel, depth, stride):
super(BottleneckIRSE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth), SEModule(depth, 16))

def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut


# Block spec consumed by get_block/get_blocks and unpacked in Backbone below.
Bottleneck = namedtuple('Block', ['in_channel', 'depth', 'stride'])


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)
            ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3)
]
return blocks


class Backbone(Module):

def __init__(self, num_layers, drop_ratio, mode='ir'):
super(Backbone, self).__init__()
assert num_layers in [50, 100,
152], 'num_layers should be 50,100, or 152'
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
blocks = get_blocks(num_layers)
if mode == 'ir':
unit_module = BottleneckIR
elif mode == 'ir_se':
unit_module = BottleneckIRSE
self.input_layer = Sequential(
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
PReLU(64))
self.output_layer = Sequential(
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
Linear(512 * 7 * 7, 512), BatchNorm1d(512))
modules = []
for block in blocks:
for bottleneck in block:
modules.append(
unit_module(bottleneck.in_channel, bottleneck.depth,
bottleneck.stride))
self.body = Sequential(*modules)

def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return l2_norm(x)


class ConvBlock(Module):

def __init__(self,
in_c,
out_c,
kernel=(1, 1),
stride=(1, 1),
padding=(0, 0),
groups=1):
super(ConvBlock, self).__init__()
self.conv = Conv2d(
in_c,
out_channels=out_c,
kernel_size=kernel,
groups=groups,
stride=stride,
padding=padding,
bias=False)
self.bn = BatchNorm2d(out_c)
self.prelu = PReLU(out_c)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.prelu(x)
return x


class LinearBlock(Module):

def __init__(self,
in_c,
out_c,
kernel=(1, 1),
stride=(1, 1),
padding=(0, 0),
groups=1):
super(LinearBlock, self).__init__()
self.conv = Conv2d(
in_c,
out_channels=out_c,
kernel_size=kernel,
groups=groups,
stride=stride,
padding=padding,
bias=False)
self.bn = BatchNorm2d(out_c)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x


class DepthWise(Module):

def __init__(self,
in_c,
out_c,
residual=False,
kernel=(3, 3),
stride=(2, 2),
padding=(1, 1),
groups=1):
super(DepthWise, self).__init__()
self.conv = ConvBlock(
in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
self.conv_dw = ConvBlock(
groups,
groups,
groups=groups,
kernel=kernel,
padding=padding,
stride=stride)
self.project = LinearBlock(
groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
self.residual = residual

def forward(self, x):
if self.residual:
short_cut = x
x = self.conv(x)
x = self.conv_dw(x)
x = self.project(x)
if self.residual:
output = short_cut + x
else:
output = x
return output


class Residual(Module):

def __init__(self,
c,
num_block,
groups,
kernel=(3, 3),
stride=(1, 1),
padding=(1, 1)):
super(Residual, self).__init__()
modules = []
for _ in range(num_block):
modules.append(
DepthWise(
c,
c,
residual=True,
kernel=kernel,
padding=padding,
stride=stride,
groups=groups))
self.model = Sequential(*modules)

def forward(self, x):
return self.model(x)


class MobileFaceNet(Module):

def __init__(self, embedding_size):
super(MobileFaceNet, self).__init__()
self.conv1 = ConvBlock(
3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
self.conv2_dw = ConvBlock(
64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
self.conv_23 = DepthWise(
64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
self.conv_3 = Residual(
64,
num_block=4,
groups=128,
kernel=(3, 3),
stride=(1, 1),
padding=(1, 1))
self.conv_34 = DepthWise(
64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
self.conv_4 = Residual(
128,
num_block=6,
groups=256,
kernel=(3, 3),
stride=(1, 1),
padding=(1, 1))
self.conv_45 = DepthWise(
128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
self.conv_5 = Residual(
128,
num_block=2,
groups=256,
kernel=(3, 3),
stride=(1, 1),
padding=(1, 1))
self.conv_6_sep = ConvBlock(
128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
self.conv_6_dw = LinearBlock(
512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
self.conv_6_flatten = Flatten()
self.linear = Linear(512, embedding_size, bias=False)
self.bn = BatchNorm1d(embedding_size)

def forward(self, x):
out = self.conv1(x)
out = self.conv2_dw(out)
out = self.conv_23(out)
out = self.conv_3(out)
out = self.conv_34(out)
out = self.conv_4(out)
out = self.conv_45(out)
out = self.conv_5(out)
out = self.conv_6_sep(out)
out = self.conv_6_dw(out)
out = self.conv_6_flatten(out)
out = self.linear(out)
return l2_norm(out)
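
StaticGestureNet is MobileFaceNet used as a fixed-size feature extractor with a small classification head on top. A quick shape check, as a sketch with random weights and assuming the 112x112 input produced by spatial_transform in hand_model.py (conv_6_dw's 7x7 kernel means other input sizes would not collapse to 1x1):

# Shape sanity check, not part of the commit; random weights, eval mode.
import torch
from modelscope.models.cv.hand_static.networks import StaticGestureNet

net = StaticGestureNet().eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 112, 112)  # matches img_size in hand_model.py
    scores = net(dummy)                  # MobileFaceNet embedding -> fc head -> sigmoid
print(scores.shape)                      # torch.Size([1, 15])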

modelscope/outputs.py (+5, -1)

@@ -632,5 +632,9 @@ TASK_OUTPUTS = {
# {
# 'output': ['Done' / 'Decode_Error']
# }
Tasks.video_inpainting: [OutputKeys.OUTPUT]
Tasks.video_inpainting: [OutputKeys.OUTPUT],
# {
# 'output': ['bixin']
# }
Tasks.hand_static: [OutputKeys.OUTPUT]
}

modelscope/pipelines/builder.py (+2, -0)

@@ -178,6 +178,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_vitb16_segmentation_shop-seg'),
Tasks.video_inpainting: (Pipelines.video_inpainting,
'damo/cv_video-inpainting'),
Tasks.hand_static: (Pipelines.hand_static,
'damo/cv_mobileface_hand-static'),
}




modelscope/pipelines/cv/__init__.py (+3, -1)

@@ -52,7 +52,8 @@ if TYPE_CHECKING:
from .ulfd_face_detection_pipeline import UlfdFaceDetectionPipeline
from .retina_face_detection_pipeline import RetinaFaceDetectionPipeline
from .facial_expression_recognition_pipeline import FacialExpressionRecognitionPipeline
from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline
from .mtcnn_face_detection_pipeline import MtcnnFaceDetectionPipeline
from .hand_static_pipeline import HandStaticPipeline

else:
_import_structure = {
@@ -119,6 +120,7 @@ else:
'facial_expression_recognition_pipeline':
['FacialExpressionRecognitionPipeline'],
'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
'hand_static_pipeline': ['HandStaticPipeline'],
}

import sys


modelscope/pipelines/cv/hand_static_pipeline.py (+37, -0)

@@ -0,0 +1,37 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.hand_static import hand_model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.hand_static, module_name=Pipelines.hand_static)
class HandStaticPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
Use `model` to create a hand-static pipeline for prediction.
Args:
model: model id on modelscope hub.
"""

super().__init__(model=model, **kwargs)
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
return input

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
result = hand_model.infer(input['img_path'], self.model, self.device)
return {OutputKeys.OUTPUT: result}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
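
End to end, the new task is used like any other ModelScope pipeline; the call below mirrors the test added in tests/pipelines/test_hand_static.py, and the returned dict follows the new Tasks.hand_static entry in outputs.py:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

hand_static = pipeline(Tasks.hand_static, model='damo/cv_mobileface_hand-static')
result = hand_static({'img_path': 'data/test/images/hand_static.jpg'})
print(result)  # {'output': <gesture label, e.g. 'bixin'>}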

modelscope/utils/constant.py (+1, -0)

@@ -42,6 +42,7 @@ class CVTasks(object):
portrait_matting = 'portrait-matting'
text_driven_segmentation = 'text-driven-segmentation'
shop_segmentation = 'shop-segmentation'
hand_static = 'hand-static'

# image editing
skin_retouching = 'skin-retouching'


tests/pipelines/test_hand_static.py (+32, -0)

@@ -0,0 +1,32 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class HandStaticTest(unittest.TestCase):

def setUp(self) -> None:
self.model = 'damo/cv_mobileface_hand-static'
self.input = {'img_path': 'data/test/images/hand_static.jpg'}

def pipeline_inference(self, pipeline: Pipeline, input: str):
result = pipeline(input)
print(result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
hand_static = pipeline(Tasks.hand_static, model=self.model)
self.pipeline_inference(hand_static, self.input)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
hand_static = pipeline(Tasks.hand_static)
self.pipeline_inference(hand_static, self.input)


if __name__ == '__main__':
unittest.main()
