
[to #43259593] cv: add human pose estimation to MaaS-lib

add human pose estimation to MaaS-lib v3
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491970
master
shouzhou.bx  yingda.chen  3 years ago
commit cb1aa66a49
21 changed files with 1113 additions and 4 deletions
 1. +3    -0  data/test/images/keypoints_detect/000000438304.jpg
 2. +3    -0  data/test/images/keypoints_detect/000000438862.jpg
 3. +3    -0  data/test/images/keypoints_detect/000000439522.jpg
 4. +3    -0  data/test/images/keypoints_detect/000000440336.jpg
 5. +3    -0  data/test/images/keypoints_detect/000000442836.jpg
 6. +3    -0  data/test/images/keypoints_detect/000000447088.jpg
 7. +3    -0  data/test/images/keypoints_detect/000000447917.jpg
 8. +3    -0  data/test/images/keypoints_detect/000000448263.jpg
 9. +3    -0  data/test/images/keypoints_detect/body_keypoints_detection.jpg
10. +2    -0  modelscope/metainfo.py
11. +4    -4  modelscope/models/cv/__init__.py
12. +23   -0  modelscope/models/cv/body_2d_keypoints/__init__.py
13. +397  -0  modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py
14. +221  -0  modelscope/models/cv/body_2d_keypoints/hrnet_v2.py
15. +51   -0  modelscope/models/cv/body_2d_keypoints/w48.py
16. +21   -0  modelscope/outputs.py
17. +3    -0  modelscope/pipelines/builder.py
18. +2    -0  modelscope/pipelines/cv/__init__.py
19. +261  -0  modelscope/pipelines/cv/body_2d_keypoints_pipeline.py
20. +1    -0  modelscope/utils/constant.py
21. +100  -0  tests/pipelines/test_body_2d_keypoints.py

+3 -0  data/test/images/keypoints_detect/000000438304.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64ab6a5556b022cbd398d98cd5bb243a4ee6e4ea6e3285f433eb78b76b53fd4e
size 269177

+3 -0  data/test/images/keypoints_detect/000000438862.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3689831ed23f734ebab9405f48ffbfbbefb778e9de3101a9d56e421ea45288cf
size 248595

+3 -0  data/test/images/keypoints_detect/000000439522.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:663545f71af556370c7cba7fd8010a665d00c0b477075562a3d7669c6d853ad3
size 107685

+3 -0  data/test/images/keypoints_detect/000000440336.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e5c2df473a26427ae57950acec86d1e4d3a49cdf1a18d427cd1a354465408f00
size 102909

+3 -0  data/test/images/keypoints_detect/000000442836.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44b225eaff012bd016fcfe8a3dbeace93fd418164f40e4b5f5b9f0d76f39097b
size 308635

+3 -0  data/test/images/keypoints_detect/000000447088.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:510da487b16303646cf4b500cae0a4168cba2feb3dd706c007a3f5c64400501c
size 148413

+3 -0  data/test/images/keypoints_detect/000000447917.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dbaa52b9ecc59b899500db9200ce65b17aa8b87172c8c70de585fa27c80e7ad1
size 238442

+3 -0  data/test/images/keypoints_detect/000000448263.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:72fcff7fd4da5ede2d3c1a31449769b0595685f7250597f05cd176c4c80ced03
size 37753

+3 -0  data/test/images/keypoints_detect/body_keypoints_detection.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d
size 1702339

+2 -0  modelscope/metainfo.py

@@ -18,6 +18,7 @@ class Models(object):
    cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
    gpen = 'gpen'
    product_retrieval_embedding = 'product-retrieval-embedding'
    body_2d_keypoints = 'body-2d-keypoints'

    # nlp models
    bert = 'bert'
@@ -77,6 +78,7 @@ class Pipelines(object):
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognation = 'resnet101-animal_recog'
    cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
    body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
    human_detection = 'resnet18-human-detection'
    object_detection = 'vit-object-detection'
    image_classification = 'image-classification'

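The two constants added above are the glue between the registries: Models.body_2d_keypoints names the model implementation registered in hrnet_v2.py, and Pipelines.body_2d_keypoints names the pipeline registered in body_2d_keypoints_pipeline.py. A minimal sketch of the relationship (illustrative, not part of the commit):

# The register_module decorators later in this commit key off these names.
from modelscope.metainfo import Models, Pipelines

assert Models.body_2d_keypoints == 'body-2d-keypoints'
assert Pipelines.body_2d_keypoints == 'hrnetv2w32_body-2d-keypoints_image'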

+4 -4  modelscope/models/cv/__init__.py

@@ -1,8 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from . import (action_recognition, animal_recognition, cartoon,
-               cmdssl_video_embedding, face_detection, face_generation,
-               image_classification, image_color_enhance, image_colorization,
-               image_denoise, image_instance_segmentation,
+from . import (action_recognition, animal_recognition, body_2d_keypoints,
+               cartoon, cmdssl_video_embedding, face_detection,
+               face_generation, image_classification, image_color_enhance,
+               image_colorization, image_denoise, image_instance_segmentation,
                image_portrait_enhancement, image_to_image_generation,
                image_to_image_translation, object_detection,
                product_retrieval_embedding, super_resolution, virual_tryon)

+23 -0  modelscope/models/cv/body_2d_keypoints/__init__.py

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

    from .hrnet_v2 import PoseHighResolutionNetV2

else:
    _import_structure = {
        'hrnet_v2': ['PoseHighResolutionNetV2'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

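LazyImportModule defers the heavy torch import until the symbol is actually touched; the key in _import_structure must match the submodule file name (hrnet_v2, added in this commit) for the deferred lookup to resolve. A sketch of the two equivalent access paths (illustrative, not part of the commit):

# Eager: imports hrnet_v2 (and torch) immediately.
from modelscope.models.cv.body_2d_keypoints.hrnet_v2 import PoseHighResolutionNetV2

# Lazy: the submodule is only imported when the attribute is first accessed.
from modelscope.models.cv import body_2d_keypoints

model_cls = body_2d_keypoints.PoseHighResolutionNetV2
assert model_cls is PoseHighResolutionNetV2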
+397 -0  modelscope/models/cv/body_2d_keypoints/hrnet_basic_modules.py

@@ -0,0 +1,397 @@
# The implementation is based on HRNet, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation.

import torch
import torch.nn as nn

BN_MOMENTUM = 0.1


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(
            planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class HighResolutionModule(nn.Module):

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 num_inchannels,
                 num_channels,
                 fuse_method,
                 multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels,
                             num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels,
                        num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(
            block(self.num_inchannels[branch_index],
                  num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.num_inchannels[branch_index],
                      num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1,
                                1,
                                0,
                                bias=False), nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False),
                                    nn.BatchNorm2d(num_outchannels_conv3x3)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                    nn.ReLU(True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []

        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module(
        'conv',
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


def upsample(scale, oup):
    return nn.Sequential(
        nn.Upsample(scale_factor=scale, mode='bilinear'),
        nn.Conv2d(
            in_channels=oup,
            out_channels=oup,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            bias=False), nn.BatchNorm2d(oup), nn.PReLU())


class SE_Block(nn.Module):

    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False), nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False), nn.Sigmoid())

    def forward(self, x):
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        return x * y.expand_as(x)


class BasicBlockSE(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, r=64):
        super(BasicBlockSE, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride
        self.se = SE_Block(planes, r)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BottleneckSE(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, r=64):
        super(BottleneckSE, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(
            planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        self.se = SE_Block(planes * self.expansion, r)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck,
    'BASICSE': BasicBlockSE,
    'BOTTLENECKSE': BottleneckSE,
}

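A quick way to sanity-check these building blocks is to push a dummy tensor through them. The sketch below (not part of the commit) shows the channel behaviour the HRNet stages rely on: BasicBlock preserves width (expansion 1), Bottleneck multiplies it by 4 and therefore needs a 1x1 projection on the residual path, and SE_Block only reweights channels.

import torch
import torch.nn as nn

from modelscope.models.cv.body_2d_keypoints.hrnet_basic_modules import (
    BasicBlock, Bottleneck, SE_Block, blocks_dict)

x = torch.randn(1, 48, 32, 32)
print(BasicBlock(48, 48)(x).shape)  # torch.Size([1, 48, 32, 32])
print(SE_Block(48)(x).shape)        # torch.Size([1, 48, 32, 32])

# Bottleneck outputs 48 * 4 channels, so the skip connection needs a projection.
down = nn.Sequential(nn.Conv2d(48, 192, 1, bias=False), nn.BatchNorm2d(192))
print(Bottleneck(48, 48, downsample=down)(x).shape)  # torch.Size([1, 192, 32, 32])

assert blocks_dict['BASIC'] is BasicBlock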
+221 -0  modelscope/models/cv/body_2d_keypoints/hrnet_v2.py

@@ -0,0 +1,221 @@
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.body_2d_keypoints.hrnet_basic_modules import (
    BN_MOMENTUM, BasicBlock, Bottleneck, HighResolutionModule, blocks_dict)
from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15
from modelscope.utils.constant import Tasks


@MODELS.register_module(
    Tasks.body_2d_keypoints, module_name=Models.body_2d_keypoints)
class PoseHighResolutionNetV2(TorchModel):

    def __init__(self, cfg=None, **kwargs):
        if cfg is None:
            cfg = cfg_128x128_15
        self.inplanes = 64
        extra = cfg['MODEL']['EXTRA']
        super(PoseHighResolutionNetV2, self).__init__(**kwargs)

        # stem net
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)
"""final four layers"""
last_inp_channels = np.int(np.sum(pre_stage_channels))
        self.final_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=last_inp_channels,
                kernel_size=1,
                stride=1,
                padding=0),
            nn.BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=cfg['MODEL']['NUM_JOINTS'],
                kernel_size=extra['FINAL_CONV_KERNEL'],
                stride=1,
                padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0))

        self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS']

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[
                        i] if j == i - num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self,
                    layer_config,
                    num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(num_branches, block, num_blocks,
                                     num_inchannels, num_channels, fuse_method,
                                     reset_multi_scale_output))
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)

        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])

        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])

        y_list = self.stage4(x_list)

        y0_h, y0_w = y_list[0].size(2), y_list[0].size(3)
        y1 = F.interpolate(
            y_list[1], size=(y0_h, y0_w), mode='bilinear', align_corners=False)
        y2 = F.interpolate(
            y_list[2], size=(y0_h, y0_w), mode='bilinear', align_corners=False)
        y3 = F.interpolate(
            y_list[3], size=(y0_h, y0_w), mode='bilinear', align_corners=False)

        y = torch.cat([y_list[0], y1, y2, y3], 1)
        output = self.final_layer(y)

        return output

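A smoke test of the backbone's shape contract (illustrative; assumes TorchModel can be constructed without a model_dir): with the default cfg_128x128_15, the two stride-2 stem convolutions reduce a 128x128 crop to 32x32, and the final layer emits one heatmap per joint.

import torch

from modelscope.models.cv.body_2d_keypoints.hrnet_v2 import PoseHighResolutionNetV2

model = PoseHighResolutionNetV2()  # falls back to cfg_128x128_15
model.eval()
with torch.no_grad():
    heatmaps = model(torch.randn(1, 3, 128, 128))
print(heatmaps.shape)  # expected: torch.Size([1, 15, 32, 32])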
+51 -0  modelscope/models/cv/body_2d_keypoints/w48.py

@@ -0,0 +1,51 @@
cfg_128x128_15 = {
    'DATASET': {
        'TYPE': 'DAMO',
        'PARENT_IDS': [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1],
        'LEFT_IDS': [2, 3, 4, 8, 9, 10],
        'RIGHT_IDS': [5, 6, 7, 11, 12, 13],
        'SPINE_IDS': [0, 1, 14]
    },
    'MODEL': {
        'INIT_WEIGHTS': True,
        'NAME': 'pose_hrnet',
        'NUM_JOINTS': 15,
        'PRETRAINED': '',
        'TARGET_TYPE': 'gaussian',
        'IMAGE_SIZE': [128, 128],
        'HEATMAP_SIZE': [32, 32],
        'SIGMA': 2.0,
        'EXTRA': {
            'PRETRAINED_LAYERS': [
                'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1',
                'stage2', 'transition2', 'stage3', 'transition3', 'stage4'
            ],
            'FINAL_CONV_KERNEL': 1,
            'STAGE2': {
                'NUM_MODULES': 1,
                'NUM_BRANCHES': 2,
                'BLOCK': 'BASIC',
                'NUM_BLOCKS': [4, 4],
                'NUM_CHANNELS': [48, 96],
                'FUSE_METHOD': 'SUM'
            },
            'STAGE3': {
                'NUM_MODULES': 4,
                'NUM_BRANCHES': 3,
                'BLOCK': 'BASIC',
                'NUM_BLOCKS': [4, 4, 4],
                'NUM_CHANNELS': [48, 96, 192],
                'FUSE_METHOD': 'SUM'
            },
            'STAGE4': {
                'NUM_MODULES': 3,
                'NUM_BRANCHES': 4,
                'BLOCK': 'BASIC',
                'NUM_BLOCKS': [4, 4, 4, 4],
                'NUM_CHANNELS': [48, 96, 192, 384],
                'FUSE_METHOD': 'SUM'
            },
        }
    }
}

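This config is the HRNetV2-W48 layout (branch widths 48/96/192/384) specialised to 15 joints at a 128x128 input. Two internal consistencies worth noting, checked in the sketch below (illustrative): HEATMAP_SIZE is IMAGE_SIZE / 4, matching the hard-coded ratio = 4 in pts_transform later in this commit, and PARENT_IDS has exactly one entry per joint.

from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15

model_cfg = cfg_128x128_15['MODEL']
assert [s // 4 for s in model_cfg['IMAGE_SIZE']] == model_cfg['HEATMAP_SIZE']
assert model_cfg['NUM_JOINTS'] == len(cfg_128x128_15['DATASET']['PARENT_IDS'])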
+21 -0  modelscope/outputs.py

@@ -158,6 +158,27 @@ TASK_OUTPUTS = {
    # }
    Tasks.action_recognition: [OutputKeys.LABELS],

    # human body keypoints detection result for single sample
    # {
    #     "poses": [
    #         [x, y],
    #         [x, y],
    #         [x, y]
    #     ],
    #     "scores": [
    #         [score],
    #         [score],
    #         [score]
    #     ],
    #     "boxes": [
    #         [x1, y1, x2, y2],
    #         [x1, y1, x2, y2],
    #         [x1, y1, x2, y2]
    #     ]
    # }
    Tasks.body_2d_keypoints:
    [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES],

    # live category recognition result for single video
    # {
    #     "scores": [0.885272, 0.014790631, 0.014558001],

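One caveat: the comment documents each box as a flat [x1, y1, x2, y2], but the pipeline added in this commit actually appends each box as a 2x2 [[x1, y1], [x2, y2]] list (box.tolist() of a reshaped array). A sketch of consuming the documented format, with made-up values:

from modelscope.outputs import OutputKeys

result = {
    OutputKeys.POSES: [[[100.0, 50.0], [102.0, 60.0]]],  # one person, two joints
    OutputKeys.SCORES: [[0.92, 0.88]],
    OutputKeys.BOXES: [[90.0, 40.0, 140.0, 200.0]],
}
for joints, confs in zip(result[OutputKeys.POSES], result[OutputKeys.SCORES]):
    for (x, y), s in zip(joints, confs):
        print(f'joint ({x:.0f}, {y:.0f}) conf {s:.2f}')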

+3 -0  modelscope/pipelines/builder.py

@@ -87,6 +87,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.text_to_image_synthesis:
    (Pipelines.text_to_image_synthesis,
     'damo/cv_diffusion_text-to-image-synthesis_tiny'),
    Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
                              'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
    Tasks.face_detection: (Pipelines.face_detection,
                           'damo/cv_resnet_facedetection_scrfd10gkps'),
    Tasks.face_recognition: (Pipelines.face_recognition,
@@ -238,6 +240,7 @@ def pipeline(task: str = None,

    cfg = ConfigDict(type=pipeline_name, model=model)
    cfg.device = device

    if kwargs:
        cfg.update(kwargs)


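With this default mapping in place, pipeline(Tasks.body_2d_keypoints) resolves to the hrnetv2w32 model; note, however, that Body2DKeypointsPipeline also requires a human_detector argument, so in practice construction looks like the sketch below (model ids taken from this commit):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

human_detector = pipeline(
    Tasks.human_detection, model='damo/cv_resnet18_human-detection')
body_2d_keypoints = pipeline(
    Tasks.body_2d_keypoints,
    model='damo/cv_hrnetv2w32_body-2d-keypoints_image',
    human_detector=human_detector)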
+2 -0  modelscope/pipelines/cv/__init__.py

@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recognition_pipeline import AnimalRecognitionPipeline
    from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
    from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
    from .image_detection_pipeline import ImageDetectionPipeline
    from .face_detection_pipeline import FaceDetectionPipeline
@@ -34,6 +35,7 @@ else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
        'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
        'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
        'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
        'image_detection_pipeline': ['ImageDetectionPipeline'],
        'face_detection_pipeline': ['FaceDetectionPipeline'],


+261 -0  modelscope/pipelines/cv/body_2d_keypoints_pipeline.py

@@ -0,0 +1,261 @@
import os.path as osp
from typing import Any, Dict, List, Union

import cv2
import json
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.body_2d_keypoints.hrnet_v2 import \
    PoseHighResolutionNetV2
from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints)
class Body2DKeypointsPipeline(Pipeline):

    def __init__(self, model: str, human_detector: Pipeline, **kwargs):
        super().__init__(model=model, **kwargs)
        self.keypoint_model = KeypointsDetection(model)
        self.keypoint_model.eval()
        self.human_detector = human_detector

    def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]:
        output = self.human_detector(input)

        if isinstance(input, str):
            image = cv2.imread(input, -1)[:, :, 0:3]
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            else:
                image = input
            image = image[:, :, 0:3]

        return {'image': image, 'output': output}

    def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]:
        input_image = input['image']
        output = input['output']

        bboxes = []
        scores = np.array(output[OutputKeys.SCORES].cpu(), dtype=np.float32)
        boxes = np.array(output[OutputKeys.BOXES].cpu(), dtype=np.float32)

        for id, box in enumerate(boxes):
            box_tmp = [
                box[0], box[1], box[2] - box[0], box[3] - box[1], scores[id], 0
            ]
            bboxes.append(box_tmp)
        if len(bboxes) == 0:
            logger.error('cannot detect human in the image')
            return [None, None]
        human_images, metas = self.keypoint_model.preprocess(
            [bboxes, input_image])
        outputs = self.keypoint_model.forward(human_images)
        return [outputs, metas]

    def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]],
                    **kwargs) -> str:
        if input[0] is None or input[1] is None:
            return {
                OutputKeys.BOXES: [],
                OutputKeys.POSES: [],
                OutputKeys.SCORES: []
            }

        poses, scores, boxes = self.keypoint_model.postprocess(input)
        return {
            OutputKeys.BOXES: boxes,
            OutputKeys.POSES: poses,
            OutputKeys.SCORES: scores
        }


class KeypointsDetection():

    def __init__(self, model: str, **kwargs):
        self.model = model
        cfg = cfg_128x128_15
        self.key_points_model = PoseHighResolutionNetV2(cfg)
        pretrained_state_dict = torch.load(
            osp.join(self.model, ModelFile.TORCH_MODEL_FILE))
        self.key_points_model.load_state_dict(
            pretrained_state_dict, strict=False)

        self.input_size = cfg['MODEL']['IMAGE_SIZE']
        self.lst_parent_ids = cfg['DATASET']['PARENT_IDS']
        self.lst_left_ids = cfg['DATASET']['LEFT_IDS']
        self.lst_right_ids = cfg['DATASET']['RIGHT_IDS']
        self.box_enlarge_ratio = 0.05

    def train(self):
        return self.key_points_model.train()

    def eval(self):
        return self.key_points_model.eval()

    def forward(self, input: Tensor) -> Tensor:
        with torch.no_grad():
            return self.key_points_model.forward(input)

    def get_pts(self, heatmaps):
        [pts_num, height, width] = heatmaps.shape
        pts = []
        scores = []
        for i in range(pts_num):
            heatmap = heatmaps[i, :, :]
            pt = np.where(heatmap == np.max(heatmap))
            scores.append(np.max(heatmap))
            x = pt[1][0]
            y = pt[0][0]

            [h, w] = heatmap.shape
            if x >= 1 and x <= w - 2 and y >= 1 and y <= h - 2:
                x_diff = heatmap[y, x + 1] - heatmap[y, x - 1]
                y_diff = heatmap[y + 1, x] - heatmap[y - 1, x]
                x_sign = 0
                y_sign = 0
                if x_diff < 0:
                    x_sign = -1
                if x_diff > 0:
                    x_sign = 1
                if y_diff < 0:
                    y_sign = -1
                if y_diff > 0:
                    y_sign = 1
                x = x + x_sign * 0.25
                y = y + y_sign * 0.25

            pts.append([x, y])
        return pts, scores

    def pts_transform(self, meta, pts, lt_x, lt_y):
        pts_new = []
        s = meta['s']
        o = meta['o']
        size = len(pts)
        for i in range(size):
            ratio = 4
            x = (int(pts[i][0] * ratio) - o[0]) / s[0]
            y = (int(pts[i][1] * ratio) - o[1]) / s[1]

            pt = [x, y]
            pts_new.append(pt)

        return pts_new

    def postprocess(self, inputs: Dict[Tensor, Dict[str, np.ndarray]],
                    **kwargs):
        output_poses = []
        output_scores = []
        output_boxes = []
        for i in range(inputs[0].shape[0]):
            outputs, scores = self.get_pts(
                (inputs[0][i]).detach().cpu().numpy())
            outputs = self.pts_transform(inputs[1][i], outputs, 0, 0)
            box = np.array(inputs[1][i]['human_box'][0:4]).reshape(2, 2)
            outputs = np.array(outputs) + box[0]
            output_poses.append(outputs.tolist())
            output_scores.append(scores)
            output_boxes.append(box.tolist())
        return output_poses, output_scores, output_boxes

    def image_crop_resize(self, input, margin=[0, 0]):
        pad_img = np.zeros((self.input_size[1], self.input_size[0], 3),
                           dtype=np.uint8)

        h, w, ch = input.shape

        h_new = self.input_size[1] - margin[1] * 2
        w_new = self.input_size[0] - margin[0] * 2
        s0 = float(h_new) / h
        s1 = float(w_new) / w
        s = min(s0, s1)
        w_new = int(s * w)
        h_new = int(s * h)

        img_new = cv2.resize(input, (w_new, h_new), cv2.INTER_LINEAR)

        cx = self.input_size[0] // 2
        cy = self.input_size[1] // 2

        pad_img[cy - h_new // 2:cy - h_new // 2 + h_new,
                cx - w_new // 2:cx - w_new // 2 + w_new, :] = img_new

        return pad_img, np.array([cx, cy]), np.array([s, s]), np.array(
            [cx - w_new // 2, cy - h_new // 2])

    def image_transform(self, input: Input) -> Dict[Tensor, Any]:
        if isinstance(input, str):
            image = cv2.imread(input, -1)[:, :, 0:3]
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            else:
                image = input
            image = image[:, :, 0:3]
        elif isinstance(input, torch.Tensor):
            image = input.cpu().numpy()[:, :, 0:3]

        w, h, _ = image.shape
        w_new = self.input_size[0]
        h_new = self.input_size[1]

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img_resize, c, s, o = self.image_crop_resize(image)

        img_resize = np.float32(img_resize) / 255.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        img_resize = (img_resize - mean) / std

        input_data = np.zeros([1, 3, h_new, w_new], dtype=np.float32)

        img_resize = img_resize.transpose((2, 0, 1))
        input_data[0, :] = img_resize
        meta = {'c': c, 's': s, 'o': o}
        return [torch.from_numpy(input_data), meta]

    def crop_image(self, image, box):
        height, width, _ = image.shape
        w, h = box[1] - box[0]
        box[0, :] -= (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio)
        box[1, :] += (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio)

        box[0, 0] = min(max(box[0, 0], 0.0), width)
        box[0, 1] = min(max(box[0, 1], 0.0), height)
        box[1, 0] = min(max(box[1, 0], 0.0), width)
        box[1, 1] = min(max(box[1, 1], 0.0), height)

        cropped_image = image[int(box[0][1]):int(box[1][1]),
                              int(box[0][0]):int(box[1][0])]
        return cropped_image

    def preprocess(self, input: Dict[Tensor, Tensor]) -> Dict[Tensor, Any]:
        bboxes = input[0]
        image = input[1]

        lst_human_images = []
        lst_meta = []
        for i in range(len(bboxes)):
            box = np.array(bboxes[i][0:4]).reshape(2, 2)
            box[1] += box[0]
            human_image = self.crop_image(image.clone(), box)
            human_image, meta = self.image_transform(human_image)
            lst_human_images.append(human_image)
            meta['human_box'] = box
            lst_meta.append(meta)

        return [torch.cat(lst_human_images, dim=0), lst_meta]

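The decoding in get_pts is the usual heatmap argmax plus a quarter-pixel shift toward the larger neighbouring value; pts_transform then maps heatmap coordinates back into the crop (ratio = 4 matches the 1/4-resolution heatmaps) using the scale s and offset o stored by image_crop_resize, and postprocess finally adds the crop's top-left corner. A worked example of just the refinement step, with a made-up heatmap (detector stands in for a KeypointsDetection instance):

import numpy as np

hm = np.zeros((1, 5, 5), dtype=np.float32)
hm[0, 2, 2] = 1.0  # peak at x=2, y=2
hm[0, 2, 3] = 0.6  # right neighbour > left  -> x shifts by +0.25
hm[0, 1, 2] = 0.5  # upper neighbour > lower -> y shifts by -0.25

pts, scores = detector.get_pts(hm)
print(pts)     # [[2.25, 1.75]]
print(scores)  # [1.0]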
+1 -0  modelscope/utils/constant.py

@@ -22,6 +22,7 @@ class CVTasks(object):
    human_detection = 'human-detection'
    human_object_interaction = 'human-object-interaction'
    face_image_generation = 'face-image-generation'
    body_2d_keypoints = 'body-2d-keypoints'

    image_classification = 'image-classification'
    image_multilabel_classification = 'image-multilabel-classification'


+100 -0  tests/pipelines/test_body_2d_keypoints.py

@@ -0,0 +1,100 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2
import numpy as np
import torch

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

lst_parent_ids_17 = [0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14]
lst_left_ids_17 = [1, 3, 5, 7, 9, 11, 13, 15]
lst_right_ids_17 = [2, 4, 6, 8, 10, 12, 14, 16]
lst_spine_ids_17 = [0]

lst_parent_ids_15 = [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1]
lst_left_ids_15 = [2, 3, 4, 8, 9, 10]
lst_right_ids_15 = [5, 6, 7, 11, 12, 13]
lst_spine_ids_15 = [0, 1, 14]


def draw_joints(image, np_kps, score, threshold=0.2):
    if np_kps.shape[0] == 17:
        lst_parent_ids = lst_parent_ids_17
        lst_left_ids = lst_left_ids_17
        lst_right_ids = lst_right_ids_17

    elif np_kps.shape[0] == 15:
        lst_parent_ids = lst_parent_ids_15
        lst_left_ids = lst_left_ids_15
        lst_right_ids = lst_right_ids_15

    for i in range(len(lst_parent_ids)):
        pid = lst_parent_ids[i]
        if i == pid:
            continue

        if score[i] < threshold or score[pid] < threshold:
            continue

        if i in lst_left_ids and pid in lst_left_ids:
            color = (0, 255, 0)
        elif i in lst_right_ids and pid in lst_right_ids:
            color = (255, 0, 0)
        else:
            color = (0, 255, 255)

        cv2.line(image, (int(np_kps[i, 0]), int(np_kps[i, 1])),
                 (int(np_kps[pid][0]), int(np_kps[pid, 1])), color, 3)

    for i in range(np_kps.shape[0]):
        if score[i] < threshold:
            continue
        cv2.circle(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), 5,
                   (0, 0, 255), -1)


def draw_box(image, box):
    cv2.rectangle(image, (int(box[0][0]), int(box[0][1])),
                  (int(box[1][0]), int(box[1][1])), (0, 0, 255), 2)


class Body2DKeypointsTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image'
        self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg'
        self.human_detect_model_id = 'damo/cv_resnet18_human-detection'

    def pipeline_inference(self, pipeline: Pipeline):
        output = pipeline(self.test_image)
        poses = np.array(output[OutputKeys.POSES])
        scores = np.array(output[OutputKeys.SCORES])
        boxes = np.array(output[OutputKeys.BOXES])
        assert len(poses) == len(scores) and len(poses) == len(boxes)
        image = cv2.imread(self.test_image, -1)
        for i in range(len(poses)):
            draw_box(image, np.array(boxes[i]))
            draw_joints(image, np.array(poses[i]), np.array(scores[i]))
        cv2.imwrite('pose_keypoint.jpg', image)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        human_detector = pipeline(
            Tasks.human_detection, model=self.human_detect_model_id)
        body_2d_keypoints = pipeline(
            Tasks.body_2d_keypoints,
            human_detector=human_detector,
            model=self.model_id)
        self.pipeline_inference(body_2d_keypoints)


if __name__ == '__main__':
    unittest.main()
