add human pose estimation to maas-lib v3 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491970 (target branch: master)
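For reviewers, a minimal usage sketch of the new task (it mirrors the unit test added in this change; the model ids and output keys below are the ones registered by this PR):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# A human detection pipeline is required; its boxes are cropped and fed to the keypoint model.
human_detector = pipeline(
    Tasks.human_detection, model='damo/cv_resnet18_human-detection')
body_2d_keypoints = pipeline(
    Tasks.body_2d_keypoints,
    human_detector=human_detector,
    model='damo/cv_hrnetv2w32_body-2d-keypoints_image')

result = body_2d_keypoints('data/test/images/keypoints_detect/000000438862.jpg')
# result[OutputKeys.POSES], result[OutputKeys.SCORES] and result[OutputKeys.BOXES]
# hold per-person keypoints, keypoint scores and person boxes respectively.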
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:64ab6a5556b022cbd398d98cd5bb243a4ee6e4ea6e3285f433eb78b76b53fd4e | |||
size 269177 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:3689831ed23f734ebab9405f48ffbfbbefb778e9de3101a9d56e421ea45288cf | |||
size 248595 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:663545f71af556370c7cba7fd8010a665d00c0b477075562a3d7669c6d853ad3 | |||
size 107685 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:e5c2df473a26427ae57950acec86d1e4d3a49cdf1a18d427cd1a354465408f00 | |||
size 102909 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:44b225eaff012bd016fcfe8a3dbeace93fd418164f40e4b5f5b9f0d76f39097b | |||
size 308635 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:510da487b16303646cf4b500cae0a4168cba2feb3dd706c007a3f5c64400501c | |||
size 148413 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:dbaa52b9ecc59b899500db9200ce65b17aa8b87172c8c70de585fa27c80e7ad1 | |||
size 238442 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:72fcff7fd4da5ede2d3c1a31449769b0595685f7250597f05cd176c4c80ced03 | |||
size 37753 |
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d | |||
size 1702339 |
@@ -18,6 +18,7 @@ class Models(object): | |||
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' | |||
gpen = 'gpen' | |||
product_retrieval_embedding = 'product-retrieval-embedding' | |||
body_2d_keypoints = 'body-2d-keypoints' | |||
# nlp models | |||
bert = 'bert' | |||
@@ -77,6 +78,7 @@ class Pipelines(object): | |||
action_recognition = 'TAdaConv_action-recognition' | |||
animal_recognation = 'resnet101-animal_recog' | |||
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | |||
body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' | |||
human_detection = 'resnet18-human-detection' | |||
object_detection = 'vit-object-detection' | |||
image_classification = 'image-classification' | |||
@@ -1,8 +1,8 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from . import (action_recognition, animal_recognition, cartoon, | |||
cmdssl_video_embedding, face_detection, face_generation, | |||
image_classification, image_color_enhance, image_colorization, | |||
image_denoise, image_instance_segmentation, | |||
from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
cartoon, cmdssl_video_embedding, face_detection, | |||
face_generation, image_classification, image_color_enhance, | |||
image_colorization, image_denoise, image_instance_segmentation, | |||
image_portrait_enhancement, image_to_image_generation, | |||
image_to_image_translation, object_detection, | |||
product_retrieval_embedding, super_resolution, virual_tryon) |
@@ -0,0 +1,23 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .hrnet_v2 import PoseHighResolutionNetV2 | |||
else: | |||
_import_structure = { | |||
'hrnet_v2': ['PoseHighResolutionNetV2'],  # key must match the module file hrnet_v2.py imported above | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,397 @@ | |||
# The implementation is based on HRNET, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation. | |||
import torch | |||
import torch.nn as nn | |||
BN_MOMENTUM = 0.1 | |||
def conv3x3(in_planes, out_planes, stride=1): | |||
"""3x3 convolution with padding""" | |||
return nn.Conv2d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=1, | |||
bias=False) | |||
class BasicBlock(nn.Module): | |||
expansion = 1 | |||
def __init__(self, inplanes, planes, stride=1, downsample=None): | |||
super(BasicBlock, self).__init__() | |||
self.conv1 = conv3x3(inplanes, planes, stride) | |||
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.conv2 = conv3x3(planes, planes) | |||
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||
self.downsample = downsample | |||
self.stride = stride | |||
def forward(self, x): | |||
residual = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
if self.downsample is not None: | |||
residual = self.downsample(x) | |||
out += residual | |||
out = self.relu(out) | |||
return out | |||
class Bottleneck(nn.Module): | |||
expansion = 4 | |||
def __init__(self, inplanes, planes, stride=1, downsample=None): | |||
super(Bottleneck, self).__init__() | |||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | |||
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||
self.conv2 = nn.Conv2d( | |||
planes, | |||
planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=1, | |||
bias=False) | |||
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||
self.conv3 = nn.Conv2d( | |||
planes, planes * self.expansion, kernel_size=1, bias=False) | |||
self.bn3 = nn.BatchNorm2d( | |||
planes * self.expansion, momentum=BN_MOMENTUM) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.downsample = downsample | |||
self.stride = stride | |||
def forward(self, x): | |||
residual = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
if self.downsample is not None: | |||
residual = self.downsample(x) | |||
out += residual | |||
out = self.relu(out) | |||
return out | |||
class HighResolutionModule(nn.Module): | |||
def __init__(self, | |||
num_branches, | |||
blocks, | |||
num_blocks, | |||
num_inchannels, | |||
num_channels, | |||
fuse_method, | |||
multi_scale_output=True): | |||
super(HighResolutionModule, self).__init__() | |||
self._check_branches(num_branches, blocks, num_blocks, num_inchannels, | |||
num_channels) | |||
self.num_inchannels = num_inchannels | |||
self.fuse_method = fuse_method | |||
self.num_branches = num_branches | |||
self.multi_scale_output = multi_scale_output | |||
self.branches = self._make_branches(num_branches, blocks, num_blocks, | |||
num_channels) | |||
self.fuse_layers = self._make_fuse_layers() | |||
self.relu = nn.ReLU(True) | |||
def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, | |||
num_channels): | |||
if num_branches != len(num_blocks): | |||
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( | |||
num_branches, len(num_blocks)) | |||
raise ValueError(error_msg) | |||
if num_branches != len(num_channels): | |||
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( | |||
num_branches, len(num_channels)) | |||
raise ValueError(error_msg) | |||
if num_branches != len(num_inchannels): | |||
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( | |||
num_branches, len(num_inchannels)) | |||
raise ValueError(error_msg) | |||
def _make_one_branch(self, | |||
branch_index, | |||
block, | |||
num_blocks, | |||
num_channels, | |||
stride=1): | |||
downsample = None | |||
if stride != 1 or \ | |||
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: | |||
downsample = nn.Sequential( | |||
nn.Conv2d( | |||
self.num_inchannels[branch_index], | |||
num_channels[branch_index] * block.expansion, | |||
kernel_size=1, | |||
stride=stride, | |||
bias=False), | |||
nn.BatchNorm2d( | |||
num_channels[branch_index] * block.expansion, | |||
momentum=BN_MOMENTUM), | |||
) | |||
layers = [] | |||
layers.append( | |||
block(self.num_inchannels[branch_index], | |||
num_channels[branch_index], stride, downsample)) | |||
self.num_inchannels[branch_index] = \ | |||
num_channels[branch_index] * block.expansion | |||
for i in range(1, num_blocks[branch_index]): | |||
layers.append( | |||
block(self.num_inchannels[branch_index], | |||
num_channels[branch_index])) | |||
return nn.Sequential(*layers) | |||
def _make_branches(self, num_branches, block, num_blocks, num_channels): | |||
branches = [] | |||
for i in range(num_branches): | |||
branches.append( | |||
self._make_one_branch(i, block, num_blocks, num_channels)) | |||
return nn.ModuleList(branches) | |||
def _make_fuse_layers(self): | |||
if self.num_branches == 1: | |||
return None | |||
num_branches = self.num_branches | |||
num_inchannels = self.num_inchannels | |||
fuse_layers = [] | |||
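# For each output branch i, fuse features from every input branch j:
#   j > i: 1x1 conv + BN, then nearest upsampling by 2**(j - i) to reach branch i's resolution
#   j == i: identity
#   j < i: a chain of (i - j) stride-2 3x3 convs for downsampling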
for i in range(num_branches if self.multi_scale_output else 1): | |||
fuse_layer = [] | |||
for j in range(num_branches): | |||
if j > i: | |||
fuse_layer.append( | |||
nn.Sequential( | |||
nn.Conv2d( | |||
num_inchannels[j], | |||
num_inchannels[i], | |||
1, | |||
1, | |||
0, | |||
bias=False), nn.BatchNorm2d(num_inchannels[i]), | |||
nn.Upsample( | |||
scale_factor=2**(j - i), mode='nearest'))) | |||
elif j == i: | |||
fuse_layer.append(None) | |||
else: | |||
conv3x3s = [] | |||
for k in range(i - j): | |||
if k == i - j - 1: | |||
num_outchannels_conv3x3 = num_inchannels[i] | |||
conv3x3s.append( | |||
nn.Sequential( | |||
nn.Conv2d( | |||
num_inchannels[j], | |||
num_outchannels_conv3x3, | |||
3, | |||
2, | |||
1, | |||
bias=False), | |||
nn.BatchNorm2d(num_outchannels_conv3x3))) | |||
else: | |||
num_outchannels_conv3x3 = num_inchannels[j] | |||
conv3x3s.append( | |||
nn.Sequential( | |||
nn.Conv2d( | |||
num_inchannels[j], | |||
num_outchannels_conv3x3, | |||
3, | |||
2, | |||
1, | |||
bias=False), | |||
nn.BatchNorm2d(num_outchannels_conv3x3), | |||
nn.ReLU(True))) | |||
fuse_layer.append(nn.Sequential(*conv3x3s)) | |||
fuse_layers.append(nn.ModuleList(fuse_layer)) | |||
return nn.ModuleList(fuse_layers) | |||
def get_num_inchannels(self): | |||
return self.num_inchannels | |||
def forward(self, x): | |||
if self.num_branches == 1: | |||
return [self.branches[0](x[0])] | |||
for i in range(self.num_branches): | |||
x[i] = self.branches[i](x[i]) | |||
x_fuse = [] | |||
for i in range(len(self.fuse_layers)): | |||
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) | |||
for j in range(1, self.num_branches): | |||
if i == j: | |||
y = y + x[j] | |||
else: | |||
y = y + self.fuse_layers[i][j](x[j]) | |||
x_fuse.append(self.relu(y)) | |||
return x_fuse | |||
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): | |||
result = nn.Sequential() | |||
result.add_module( | |||
'conv', | |||
nn.Conv2d( | |||
in_channels=in_channels, | |||
out_channels=out_channels, | |||
kernel_size=kernel_size, | |||
stride=stride, | |||
padding=padding, | |||
groups=groups, | |||
bias=False)) | |||
result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) | |||
return result | |||
def upsample(scale, oup): | |||
return nn.Sequential( | |||
nn.Upsample(scale_factor=scale, mode='bilinear'), | |||
nn.Conv2d( | |||
in_channels=oup, | |||
out_channels=oup, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1, | |||
groups=1, | |||
bias=False), nn.BatchNorm2d(oup), nn.PReLU()) | |||
class SE_Block(nn.Module): | |||
def __init__(self, c, r=16): | |||
super().__init__() | |||
self.squeeze = nn.AdaptiveAvgPool2d(1) | |||
self.excitation = nn.Sequential( | |||
nn.Linear(c, c // r, bias=False), nn.ReLU(inplace=True), | |||
nn.Linear(c // r, c, bias=False), nn.Sigmoid()) | |||
def forward(self, x): | |||
bs, c, _, _ = x.shape | |||
y = self.squeeze(x).view(bs, c) | |||
y = self.excitation(y).view(bs, c, 1, 1) | |||
return x * y.expand_as(x) | |||
class BasicBlockSE(nn.Module): | |||
expansion = 1 | |||
def __init__(self, inplanes, planes, stride=1, downsample=None, r=64): | |||
super(BasicBlockSE, self).__init__() | |||
self.conv1 = conv3x3(inplanes, planes, stride) | |||
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.conv2 = conv3x3(planes, planes) | |||
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||
self.downsample = downsample | |||
self.stride = stride | |||
self.se = SE_Block(planes, r) | |||
def forward(self, x): | |||
residual = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.se(out) | |||
if self.downsample is not None: | |||
residual = self.downsample(x) | |||
out += residual | |||
out = self.relu(out) | |||
return out | |||
class BottleneckSE(nn.Module): | |||
expansion = 4 | |||
def __init__(self, inplanes, planes, stride=1, downsample=None, r=64): | |||
super(BottleneckSE, self).__init__() | |||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) | |||
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||
self.conv2 = nn.Conv2d( | |||
planes, | |||
planes, | |||
kernel_size=3, | |||
stride=stride, | |||
padding=1, | |||
bias=False) | |||
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) | |||
self.conv3 = nn.Conv2d( | |||
planes, planes * self.expansion, kernel_size=1, bias=False) | |||
self.bn3 = nn.BatchNorm2d( | |||
planes * self.expansion, momentum=BN_MOMENTUM) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.downsample = downsample | |||
self.stride = stride | |||
self.se = SE_Block(planes * self.expansion, r) | |||
def forward(self, x): | |||
residual = x | |||
out = self.conv1(x) | |||
out = self.bn1(out) | |||
out = self.relu(out) | |||
out = self.conv2(out) | |||
out = self.bn2(out) | |||
out = self.relu(out) | |||
out = self.conv3(out) | |||
out = self.bn3(out) | |||
out = self.se(out) | |||
if self.downsample is not None: | |||
residual = self.downsample(x) | |||
out += residual | |||
out = self.relu(out) | |||
return out | |||
blocks_dict = { | |||
'BASIC': BasicBlock, | |||
'BOTTLENECK': Bottleneck, | |||
'BASICSE': BasicBlockSE, | |||
'BOTTLENECKSE': BottleneckSE, | |||
} |
@@ -0,0 +1,221 @@ | |||
import os | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base.base_torch_model import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.cv.body_2d_keypoints.hrnet_basic_modules import ( | |||
BN_MOMENTUM, BasicBlock, Bottleneck, HighResolutionModule, blocks_dict) | |||
from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15 | |||
from modelscope.utils.constant import Tasks | |||
@MODELS.register_module( | |||
Tasks.body_2d_keypoints, module_name=Models.body_2d_keypoints) | |||
class PoseHighResolutionNetV2(TorchModel): | |||
def __init__(self, cfg=None, **kwargs): | |||
if cfg is None: | |||
cfg = cfg_128x128_15 | |||
self.inplanes = 64 | |||
extra = cfg['MODEL']['EXTRA'] | |||
super(PoseHighResolutionNetV2, self).__init__(**kwargs) | |||
# stem net | |||
self.conv1 = nn.Conv2d( | |||
3, 64, kernel_size=3, stride=2, padding=1, bias=False) | |||
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) | |||
self.conv2 = nn.Conv2d( | |||
64, 64, kernel_size=3, stride=2, padding=1, bias=False) | |||
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.layer1 = self._make_layer(Bottleneck, 64, 4) | |||
self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] | |||
num_channels = self.stage2_cfg['NUM_CHANNELS'] | |||
block = blocks_dict[self.stage2_cfg['BLOCK']] | |||
num_channels = [ | |||
num_channels[i] * block.expansion | |||
for i in range(len(num_channels)) | |||
] | |||
self.transition1 = self._make_transition_layer([256], num_channels) | |||
self.stage2, pre_stage_channels = self._make_stage( | |||
self.stage2_cfg, num_channels) | |||
self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] | |||
num_channels = self.stage3_cfg['NUM_CHANNELS'] | |||
block = blocks_dict[self.stage3_cfg['BLOCK']] | |||
num_channels = [ | |||
num_channels[i] * block.expansion | |||
for i in range(len(num_channels)) | |||
] | |||
self.transition2 = self._make_transition_layer(pre_stage_channels, | |||
num_channels) | |||
self.stage3, pre_stage_channels = self._make_stage( | |||
self.stage3_cfg, num_channels) | |||
self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] | |||
num_channels = self.stage4_cfg['NUM_CHANNELS'] | |||
block = blocks_dict[self.stage4_cfg['BLOCK']] | |||
num_channels = [ | |||
num_channels[i] * block.expansion | |||
for i in range(len(num_channels)) | |||
] | |||
self.transition3 = self._make_transition_layer(pre_stage_channels, | |||
num_channels) | |||
self.stage4, pre_stage_channels = self._make_stage( | |||
self.stage4_cfg, num_channels, multi_scale_output=True) | |||
"""final four layers""" | |||
last_inp_channels = np.int(np.sum(pre_stage_channels)) | |||
self.final_layer = nn.Sequential( | |||
nn.Conv2d( | |||
in_channels=last_inp_channels, | |||
out_channels=last_inp_channels, | |||
kernel_size=1, | |||
stride=1, | |||
padding=0), | |||
nn.BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM), | |||
nn.ReLU(inplace=False), | |||
nn.Conv2d( | |||
in_channels=last_inp_channels, | |||
out_channels=cfg['MODEL']['NUM_JOINTS'], | |||
kernel_size=extra['FINAL_CONV_KERNEL'], | |||
stride=1, | |||
padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0)) | |||
self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS'] | |||
def _make_transition_layer(self, num_channels_pre_layer, | |||
num_channels_cur_layer): | |||
num_branches_cur = len(num_channels_cur_layer) | |||
num_branches_pre = len(num_channels_pre_layer) | |||
transition_layers = [] | |||
for i in range(num_branches_cur): | |||
if i < num_branches_pre: | |||
if num_channels_cur_layer[i] != num_channels_pre_layer[i]: | |||
transition_layers.append( | |||
nn.Sequential( | |||
nn.Conv2d( | |||
num_channels_pre_layer[i], | |||
num_channels_cur_layer[i], | |||
3, | |||
1, | |||
1, | |||
bias=False), | |||
nn.BatchNorm2d(num_channels_cur_layer[i]), | |||
nn.ReLU(inplace=True))) | |||
else: | |||
transition_layers.append(None) | |||
else: | |||
conv3x3s = [] | |||
for j in range(i + 1 - num_branches_pre): | |||
inchannels = num_channels_pre_layer[-1] | |||
outchannels = num_channels_cur_layer[ | |||
i] if j == i - num_branches_pre else inchannels | |||
conv3x3s.append( | |||
nn.Sequential( | |||
nn.Conv2d( | |||
inchannels, outchannels, 3, 2, 1, bias=False), | |||
nn.BatchNorm2d(outchannels), | |||
nn.ReLU(inplace=True))) | |||
transition_layers.append(nn.Sequential(*conv3x3s)) | |||
return nn.ModuleList(transition_layers) | |||
def _make_layer(self, block, planes, blocks, stride=1): | |||
downsample = None | |||
if stride != 1 or self.inplanes != planes * block.expansion: | |||
downsample = nn.Sequential( | |||
nn.Conv2d( | |||
self.inplanes, | |||
planes * block.expansion, | |||
kernel_size=1, | |||
stride=stride, | |||
bias=False), | |||
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), | |||
) | |||
layers = [] | |||
layers.append(block(self.inplanes, planes, stride, downsample)) | |||
self.inplanes = planes * block.expansion | |||
for i in range(1, blocks): | |||
layers.append(block(self.inplanes, planes)) | |||
return nn.Sequential(*layers) | |||
def _make_stage(self, | |||
layer_config, | |||
num_inchannels, | |||
multi_scale_output=True): | |||
num_modules = layer_config['NUM_MODULES'] | |||
num_branches = layer_config['NUM_BRANCHES'] | |||
num_blocks = layer_config['NUM_BLOCKS'] | |||
num_channels = layer_config['NUM_CHANNELS'] | |||
block = blocks_dict[layer_config['BLOCK']] | |||
fuse_method = layer_config['FUSE_METHOD'] | |||
modules = [] | |||
for i in range(num_modules): | |||
if not multi_scale_output and i == num_modules - 1: | |||
reset_multi_scale_output = False | |||
else: | |||
reset_multi_scale_output = True | |||
modules.append( | |||
HighResolutionModule(num_branches, block, num_blocks, | |||
num_inchannels, num_channels, fuse_method, | |||
reset_multi_scale_output)) | |||
num_inchannels = modules[-1].get_num_inchannels() | |||
return nn.Sequential(*modules), num_inchannels | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
x = self.conv2(x) | |||
x = self.bn2(x) | |||
x = self.relu(x) | |||
x = self.layer1(x) | |||
x_list = [] | |||
for i in range(self.stage2_cfg['NUM_BRANCHES']): | |||
if self.transition1[i] is not None: | |||
x_list.append(self.transition1[i](x)) | |||
else: | |||
x_list.append(x) | |||
y_list = self.stage2(x_list) | |||
x_list = [] | |||
for i in range(self.stage3_cfg['NUM_BRANCHES']): | |||
if self.transition2[i] is not None: | |||
x_list.append(self.transition2[i](y_list[-1])) | |||
else: | |||
x_list.append(y_list[i]) | |||
y_list = self.stage3(x_list) | |||
x_list = [] | |||
for i in range(self.stage4_cfg['NUM_BRANCHES']): | |||
if self.transition3[i] is not None: | |||
x_list.append(self.transition3[i](y_list[-1])) | |||
else: | |||
x_list.append(y_list[i]) | |||
y_list = self.stage4(x_list) | |||
y0_h, y0_w = y_list[0].size(2), y_list[0].size(3) | |||
y1 = F.interpolate(y_list[1], size=(y0_h, y0_w), mode='bilinear', align_corners=False)  # F.upsample is deprecated | |||
y2 = F.interpolate(y_list[2], size=(y0_h, y0_w), mode='bilinear', align_corners=False) | |||
y3 = F.interpolate(y_list[3], size=(y0_h, y0_w), mode='bilinear', align_corners=False) | |||
y = torch.cat([y_list[0], y1, y2, y3], 1) | |||
output = self.final_layer(y) | |||
return output |
@@ -0,0 +1,51 @@ | |||
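# Config for the 15-keypoint body model: 128x128 input, 32x32 heatmaps (stride 4), Gaussian targets.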
cfg_128x128_15 = { | |||
'DATASET': { | |||
'TYPE': 'DAMO', | |||
'PARENT_IDS': [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1], | |||
'LEFT_IDS': [2, 3, 4, 8, 9, 10], | |||
'RIGHT_IDS': [5, 6, 7, 11, 12, 13], | |||
'SPINE_IDS': [0, 1, 14] | |||
}, | |||
'MODEL': { | |||
'INIT_WEIGHTS': True, | |||
'NAME': 'pose_hrnet', | |||
'NUM_JOINTS': 15, | |||
'PRETRAINED': '', | |||
'TARGET_TYPE': 'gaussian', | |||
'IMAGE_SIZE': [128, 128], | |||
'HEATMAP_SIZE': [32, 32], | |||
'SIGMA': 2.0, | |||
'EXTRA': { | |||
'PRETRAINED_LAYERS': [ | |||
'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1', | |||
'stage2', 'transition2', 'stage3', 'transition3', 'stage4' | |||
], | |||
'FINAL_CONV_KERNEL': | |||
1, | |||
'STAGE2': { | |||
'NUM_MODULES': 1, | |||
'NUM_BRANCHES': 2, | |||
'BLOCK': 'BASIC', | |||
'NUM_BLOCKS': [4, 4], | |||
'NUM_CHANNELS': [48, 96], | |||
'FUSE_METHOD': 'SUM' | |||
}, | |||
'STAGE3': { | |||
'NUM_MODULES': 4, | |||
'NUM_BRANCHES': 3, | |||
'BLOCK': 'BASIC', | |||
'NUM_BLOCKS': [4, 4, 4], | |||
'NUM_CHANNELS': [48, 96, 192], | |||
'FUSE_METHOD': 'SUM' | |||
}, | |||
'STAGE4': { | |||
'NUM_MODULES': 3, | |||
'NUM_BRANCHES': 4, | |||
'BLOCK': 'BASIC', | |||
'NUM_BLOCKS': [4, 4, 4, 4], | |||
'NUM_CHANNELS': [48, 96, 192, 384], | |||
'FUSE_METHOD': 'SUM' | |||
}, | |||
} | |||
} | |||
} |
@@ -158,6 +158,27 @@ TASK_OUTPUTS = { | |||
# } | |||
Tasks.action_recognition: [OutputKeys.LABELS], | |||
# human body keypoints detection result for single sample | |||
# { | |||
# "poses": [ | |||
# [x, y], | |||
# [x, y], | |||
# [x, y] | |||
# ] | |||
# "scores": [ | |||
# [score], | |||
# [score], | |||
# [score], | |||
# ] | |||
# "boxes": [ | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# [x1, y1, x2, y2], | |||
# ] | |||
# } | |||
Tasks.body_2d_keypoints: | |||
[OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES], | |||
# live category recognition result for single video | |||
# { | |||
# "scores": [0.885272, 0.014790631, 0.014558001], | |||
@@ -87,6 +87,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
Tasks.text_to_image_synthesis: | |||
(Pipelines.text_to_image_synthesis, | |||
'damo/cv_diffusion_text-to-image-synthesis_tiny'), | |||
Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints, | |||
'damo/cv_hrnetv2w32_body-2d-keypoints_image'), | |||
Tasks.face_detection: (Pipelines.face_detection, | |||
'damo/cv_resnet_facedetection_scrfd10gkps'), | |||
Tasks.face_recognition: (Pipelines.face_recognition, | |||
@@ -238,6 +240,7 @@ def pipeline(task: str = None, | |||
cfg = ConfigDict(type=pipeline_name, model=model) | |||
cfg.device = device | |||
if kwargs: | |||
cfg.update(kwargs) | |||
@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .action_recognition_pipeline import ActionRecognitionPipeline | |||
from .animal_recognition_pipeline import AnimalRecognitionPipeline | |||
from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline | |||
from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline | |||
from .image_detection_pipeline import ImageDetectionPipeline | |||
from .face_detection_pipeline import FaceDetectionPipeline | |||
@@ -34,6 +35,7 @@ else: | |||
_import_structure = { | |||
'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], | |||
'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'], | |||
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], | |||
'image_detection_pipeline': ['ImageDetectionPipeline'], | |||
'face_detection_pipeline': ['FaceDetectionPipeline'], | |||
@@ -0,0 +1,261 @@ | |||
import os.path as osp | |||
from typing import Any, Dict, List, Union | |||
import cv2 | |||
import json | |||
import numpy as np | |||
import torch | |||
from PIL import Image | |||
from torchvision import transforms | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.body_2d_keypoints.hrnet_v2 import \ | |||
PoseHighResolutionNetV2 | |||
from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15 | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import load_image | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints) | |||
class Body2DKeypointsPipeline(Pipeline): | |||
def __init__(self, model: str, human_detector: Pipeline, **kwargs): | |||
super().__init__(model=model, **kwargs) | |||
self.keypoint_model = KeypointsDetection(model) | |||
self.keypoint_model.eval() | |||
self.human_detector = human_detector | |||
def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: | |||
output = self.human_detector(input) | |||
if isinstance(input, str): | |||
image = cv2.imread(input, -1)[:, :, 0:3] | |||
elif isinstance(input, np.ndarray): | |||
if len(input.shape) == 2: | |||
image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
else: | |||
image = input | |||
image = image[:, :, 0:3] | |||
return {'image': image, 'output': output} | |||
def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]: | |||
input_image = input['image'] | |||
output = input['output'] | |||
bboxes = [] | |||
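# Convert the detector output from [x1, y1, x2, y2] corner boxes into
# [x, y, w, h, score, class] entries expected by the keypoint preprocessing.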
scores = np.array(output[OutputKeys.SCORES].cpu(), dtype=np.float32) | |||
boxes = np.array(output[OutputKeys.BOXES].cpu(), dtype=np.float32) | |||
for id, box in enumerate(boxes): | |||
box_tmp = [ | |||
box[0], box[1], box[2] - box[0], box[3] - box[1], scores[id], 0 | |||
] | |||
bboxes.append(box_tmp) | |||
if len(bboxes) == 0: | |||
logger.error('cannot detect human in the image') | |||
return [None, None] | |||
human_images, metas = self.keypoint_model.preprocess( | |||
[bboxes, input_image]) | |||
outputs = self.keypoint_model.forward(human_images) | |||
return [outputs, metas] | |||
def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]], | |||
**kwargs) -> str: | |||
if input[0] is None or input[1] is None: | |||
return { | |||
OutputKeys.BOXES: [], | |||
OutputKeys.POSES: [], | |||
OutputKeys.SCORES: [] | |||
} | |||
poses, scores, boxes = self.keypoint_model.postprocess(input) | |||
return { | |||
OutputKeys.BOXES: boxes, | |||
OutputKeys.POSES: poses, | |||
OutputKeys.SCORES: scores | |||
} | |||
class KeypointsDetection(): | |||
def __init__(self, model: str, **kwargs): | |||
self.model = model | |||
cfg = cfg_128x128_15 | |||
self.key_points_model = PoseHighResolutionNetV2(cfg) | |||
pretrained_state_dict = torch.load( | |||
osp.join(self.model, ModelFile.TORCH_MODEL_FILE)) | |||
self.key_points_model.load_state_dict( | |||
pretrained_state_dict, strict=False) | |||
self.input_size = cfg['MODEL']['IMAGE_SIZE'] | |||
self.lst_parent_ids = cfg['DATASET']['PARENT_IDS'] | |||
self.lst_left_ids = cfg['DATASET']['LEFT_IDS'] | |||
self.lst_right_ids = cfg['DATASET']['RIGHT_IDS'] | |||
self.box_enlarge_ratio = 0.05 | |||
def train(self): | |||
return self.key_points_model.train() | |||
def eval(self): | |||
return self.key_points_model.eval() | |||
def forward(self, input: Tensor) -> Tensor: | |||
with torch.no_grad(): | |||
return self.key_points_model.forward(input) | |||
def get_pts(self, heatmaps): | |||
[pts_num, height, width] = heatmaps.shape | |||
pts = [] | |||
scores = [] | |||
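# Decode each heatmap by taking the argmax location, then nudge the peak by a
# quarter pixel towards the higher-valued neighbour for sub-pixel accuracy.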
for i in range(pts_num): | |||
heatmap = heatmaps[i, :, :] | |||
pt = np.where(heatmap == np.max(heatmap)) | |||
scores.append(np.max(heatmap)) | |||
x = pt[1][0] | |||
y = pt[0][0] | |||
[h, w] = heatmap.shape | |||
if x >= 1 and x <= w - 2 and y >= 1 and y <= h - 2: | |||
x_diff = heatmap[y, x + 1] - heatmap[y, x - 1] | |||
y_diff = heatmap[y + 1, x] - heatmap[y - 1, x] | |||
x_sign = 0 | |||
y_sign = 0 | |||
if x_diff < 0: | |||
x_sign = -1 | |||
if x_diff > 0: | |||
x_sign = 1 | |||
if y_diff < 0: | |||
y_sign = -1 | |||
if y_diff > 0: | |||
y_sign = 1 | |||
x = x + x_sign * 0.25 | |||
y = y + y_sign * 0.25 | |||
pts.append([x, y]) | |||
return pts, scores | |||
def pts_transform(self, meta, pts, lt_x, lt_y): | |||
pts_new = [] | |||
s = meta['s'] | |||
o = meta['o'] | |||
size = len(pts) | |||
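# Heatmaps are 1/4 of the 128x128 network input (see HEATMAP_SIZE in the config),
# so scale peak coordinates by 4, then undo the letterbox resize: subtract the
# crop offset `o` and divide by the resize scale `s` recorded in `meta`.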
for i in range(size): | |||
ratio = 4 | |||
x = (int(pts[i][0] * ratio) - o[0]) / s[0] | |||
y = (int(pts[i][1] * ratio) - o[1]) / s[1] | |||
pt = [x, y] | |||
pts_new.append(pt) | |||
return pts_new | |||
def postprocess(self, inputs: Dict[Tensor, Dict[str, np.ndarray]], | |||
**kwargs): | |||
output_poses = [] | |||
output_scores = [] | |||
output_boxes = [] | |||
for i in range(inputs[0].shape[0]): | |||
outputs, scores = self.get_pts( | |||
(inputs[0][i]).detach().cpu().numpy()) | |||
outputs = self.pts_transform(inputs[1][i], outputs, 0, 0) | |||
box = np.array(inputs[1][i]['human_box'][0:4]).reshape(2, 2) | |||
outputs = np.array(outputs) + box[0] | |||
output_poses.append(outputs.tolist()) | |||
output_scores.append(scores) | |||
output_boxes.append(box.tolist()) | |||
return output_poses, output_scores, output_boxes | |||
def image_crop_resize(self, input, margin=[0, 0]): | |||
pad_img = np.zeros((self.input_size[1], self.input_size[0], 3), | |||
dtype=np.uint8) | |||
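# Letterbox the crop: resize while keeping the aspect ratio, centre it on a
# fixed-size canvas, and return the canvas centre, the scale and the top-left
# offset so that keypoints can later be mapped back to the original crop.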
h, w, ch = input.shape | |||
h_new = self.input_size[1] - margin[1] * 2 | |||
w_new = self.input_size[0] - margin[0] * 2 | |||
s0 = float(h_new) / h | |||
s1 = float(w_new) / w | |||
s = min(s0, s1) | |||
w_new = int(s * w) | |||
h_new = int(s * h) | |||
img_new = cv2.resize(input, (w_new, h_new), cv2.INTER_LINEAR) | |||
cx = self.input_size[0] // 2 | |||
cy = self.input_size[1] // 2 | |||
pad_img[cy - h_new // 2:cy - h_new // 2 + h_new, | |||
cx - w_new // 2:cx - w_new // 2 + w_new, :] = img_new | |||
return pad_img, np.array([cx, cy]), np.array([s, s]), np.array( | |||
[cx - w_new // 2, cy - h_new // 2]) | |||
def image_transform(self, input: Input) -> Dict[Tensor, Any]: | |||
if isinstance(input, str): | |||
image = cv2.imread(input, -1)[:, :, 0:3] | |||
elif isinstance(input, np.ndarray): | |||
if len(input.shape) == 2: | |||
image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
else: | |||
image = input | |||
image = image[:, :, 0:3] | |||
elif isinstance(input, torch.Tensor): | |||
image = input.cpu().numpy()[:, :, 0:3] | |||
h, w, _ = image.shape  # numpy image shape is (height, width, channels) | |||
w_new = self.input_size[0] | |||
h_new = self.input_size[1] | |||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | |||
img_resize, c, s, o = self.image_crop_resize(image) | |||
img_resize = np.float32(img_resize) / 255. | |||
mean = [0.485, 0.456, 0.406] | |||
std = [0.229, 0.224, 0.225] | |||
img_resize = (img_resize - mean) / std | |||
input_data = np.zeros([1, 3, h_new, w_new], dtype=np.float32) | |||
img_resize = img_resize.transpose((2, 0, 1)) | |||
input_data[0, :] = img_resize | |||
meta = {'c': c, 's': s, 'o': o} | |||
return [torch.from_numpy(input_data), meta] | |||
def crop_image(self, image, box): | |||
height, width, _ = image.shape | |||
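# Enlarge the detected box by box_enlarge_ratio on each side, then clamp it to
# the image bounds before cropping the person region.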
w, h = box[1] - box[0] | |||
box[0, :] -= (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio) | |||
box[1, :] += (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio) | |||
box[0, 0] = min(max(box[0, 0], 0.0), width) | |||
box[0, 1] = min(max(box[0, 1], 0.0), height) | |||
box[1, 0] = min(max(box[1, 0], 0.0), width) | |||
box[1, 1] = min(max(box[1, 1], 0.0), height) | |||
cropped_image = image[int(box[0][1]):int(box[1][1]), | |||
int(box[0][0]):int(box[1][0])] | |||
return cropped_image | |||
def preprocess(self, input: Dict[Tensor, Tensor]) -> Dict[Tensor, Any]: | |||
bboxes = input[0] | |||
image = input[1] | |||
lst_human_images = [] | |||
lst_meta = [] | |||
for i in range(len(bboxes)): | |||
box = np.array(bboxes[i][0:4]).reshape(2, 2) | |||
box[1] += box[0] | |||
human_image = self.crop_image(image.copy(), box)  # numpy array, so copy() rather than torch's clone() | |||
human_image, meta = self.image_transform(human_image) | |||
lst_human_images.append(human_image) | |||
meta['human_box'] = box | |||
lst_meta.append(meta) | |||
return [torch.cat(lst_human_images, dim=0), lst_meta] |
@@ -22,6 +22,7 @@ class CVTasks(object): | |||
human_detection = 'human-detection' | |||
human_object_interaction = 'human-object-interaction' | |||
face_image_generation = 'face-image-generation' | |||
body_2d_keypoints = 'body-2d-keypoints' | |||
image_classification = 'image-classification' | |||
image_multilabel_classification = 'image-multilabel-classification' | |||
@@ -0,0 +1,100 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import unittest | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
lst_parent_ids_17 = [0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14] | |||
lst_left_ids_17 = [1, 3, 5, 7, 9, 11, 13, 15] | |||
lst_right_ids_17 = [2, 4, 6, 8, 10, 12, 14, 16] | |||
lst_spine_ids_17 = [0] | |||
lst_parent_ids_15 = [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1] | |||
lst_left_ids_15 = [2, 3, 4, 8, 9, 10] | |||
lst_right_ids_15 = [5, 6, 7, 11, 12, 13] | |||
lst_spine_ids_15 = [0, 1, 14] | |||
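# Parent ids define the skeleton bones used for drawing; the left/right id
# lists control the bone colour in draw_joints.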
def draw_joints(image, np_kps, score, threshold=0.2): | |||
if np_kps.shape[0] == 17: | |||
lst_parent_ids = lst_parent_ids_17 | |||
lst_left_ids = lst_left_ids_17 | |||
lst_right_ids = lst_right_ids_17 | |||
elif np_kps.shape[0] == 15: | |||
lst_parent_ids = lst_parent_ids_15 | |||
lst_left_ids = lst_left_ids_15 | |||
lst_right_ids = lst_right_ids_15 | |||
for i in range(len(lst_parent_ids)): | |||
pid = lst_parent_ids[i] | |||
if i == pid: | |||
continue | |||
if (score[i] < threshold or score[pid] < threshold): | |||
continue | |||
if i in lst_left_ids and pid in lst_left_ids: | |||
color = (0, 255, 0) | |||
elif i in lst_right_ids and pid in lst_right_ids: | |||
color = (255, 0, 0) | |||
else: | |||
color = (0, 255, 255) | |||
cv2.line(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), | |||
(int(np_kps[pid][0]), int(np_kps[pid, 1])), color, 3) | |||
for i in range(np_kps.shape[0]): | |||
if score[i] < threshold: | |||
continue | |||
cv2.circle(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), 5, | |||
(0, 0, 255), -1) | |||
def draw_box(image, box): | |||
cv2.rectangle(image, (int(box[0][0]), int(box[0][1])), | |||
(int(box[1][0]), int(box[1][1])), (0, 0, 255), 2) | |||
class Body2DKeypointsTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image' | |||
self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' | |||
self.human_detect_model_id = 'damo/cv_resnet18_human-detection' | |||
def pipeline_inference(self, pipeline: Pipeline): | |||
output = pipeline(self.test_image) | |||
poses = np.array(output[OutputKeys.POSES]) | |||
scores = np.array(output[OutputKeys.SCORES]) | |||
boxes = np.array(output[OutputKeys.BOXES]) | |||
assert len(poses) == len(scores) and len(poses) == len(boxes) | |||
image = cv2.imread(self.test_image, -1) | |||
for i in range(len(poses)): | |||
draw_box(image, np.array(boxes[i])) | |||
draw_joints(image, np.array(poses[i]), np.array(scores[i])) | |||
cv2.imwrite('pose_keypoint.jpg', image) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
human_detector = pipeline( | |||
Tasks.human_detection, model=self.human_detect_model_id) | |||
body_2d_keypoints = pipeline( | |||
Tasks.body_2d_keypoints, | |||
human_detector=human_detector, | |||
model=self.model_id) | |||
self.pipeline_inference(body_2d_keypoints) | |||
if __name__ == '__main__': | |||
unittest.main() |