add human pose estimation to maas-lib v3
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491970 master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64ab6a5556b022cbd398d98cd5bb243a4ee6e4ea6e3285f433eb78b76b53fd4e
size 269177
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3689831ed23f734ebab9405f48ffbfbbefb778e9de3101a9d56e421ea45288cf
size 248595
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:663545f71af556370c7cba7fd8010a665d00c0b477075562a3d7669c6d853ad3
size 107685
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e5c2df473a26427ae57950acec86d1e4d3a49cdf1a18d427cd1a354465408f00
size 102909
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44b225eaff012bd016fcfe8a3dbeace93fd418164f40e4b5f5b9f0d76f39097b
size 308635
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:510da487b16303646cf4b500cae0a4168cba2feb3dd706c007a3f5c64400501c
size 148413
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dbaa52b9ecc59b899500db9200ce65b17aa8b87172c8c70de585fa27c80e7ad1
size 238442
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:72fcff7fd4da5ede2d3c1a31449769b0595685f7250597f05cd176c4c80ced03
size 37753
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:379e11d7fc3734d3ec95afd0d86460b4653fbf4bb1f57f993610d6a6fd30fd3d
size 1702339
@@ -18,6 +18,7 @@ class Models(object):
    cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
    gpen = 'gpen'
    product_retrieval_embedding = 'product-retrieval-embedding'
    body_2d_keypoints = 'body-2d-keypoints'

    # nlp models
    bert = 'bert'
@@ -77,6 +78,7 @@ class Pipelines(object):
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognation = 'resnet101-animal_recog'
    cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
    body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
    human_detection = 'resnet18-human-detection'
    object_detection = 'vit-object-detection'
    image_classification = 'image-classification'
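These enum values are only registry keys; the strings are what the pipeline builder and the model hub configuration refer to. A quick sanity check of the new constants (a sketch, not part of the diff):

from modelscope.metainfo import Models, Pipelines

assert Models.body_2d_keypoints == 'body-2d-keypoints'
assert Pipelines.body_2d_keypoints == 'hrnetv2w32_body-2d-keypoints_image'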
@@ -1,8 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from . import (action_recognition, animal_recognition, cartoon,
-               cmdssl_video_embedding, face_detection, face_generation,
-               image_classification, image_color_enhance, image_colorization,
-               image_denoise, image_instance_segmentation,
+from . import (action_recognition, animal_recognition, body_2d_keypoints,
+               cartoon, cmdssl_video_embedding, face_detection,
+               face_generation, image_classification, image_color_enhance,
+               image_colorization, image_denoise, image_instance_segmentation,
                image_portrait_enhancement, image_to_image_generation,
                image_to_image_translation, object_detection,
                product_retrieval_embedding, super_resolution, virual_tryon)
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .hrnet_v2 import PoseHighResolutionNetV2
else:
    _import_structure = {
        'hrnet_v2': ['PoseHighResolutionNetV2'],
    }
    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
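The lazy module defers loading hrnet_v2 (and therefore torch) until the symbol is first touched. A minimal sketch of how a caller picks it up, assuming the package path added in this PR:

from modelscope.models.cv.body_2d_keypoints import PoseHighResolutionNetV2

net = PoseHighResolutionNetV2()  # first attribute access triggers the real import; cfg defaults to cfg_128x128_15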
@@ -0,0 +1,397 @@
# The implementation is based on HRNet, available at https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation.
import torch
import torch.nn as nn

BN_MOMENTUM = 0.1


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(
            planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class HighResolutionModule(nn.Module):

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 num_inchannels,
                 num_channels,
                 fuse_method,
                 multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels,
                             num_channels)
        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches
        self.multi_scale_output = multi_scale_output
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            raise ValueError(error_msg)
        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            raise ValueError(error_msg)
        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(
            block(self.num_inchannels[branch_index],
                  num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.num_inchannels[branch_index],
                      num_channels[branch_index]))
        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None
        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1,
                                1,
                                0,
                                bias=False), nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False),
                                    nn.BatchNorm2d(num_outchannels_conv3x3)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                    nn.ReLU(True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))
        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]
        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module(
        'conv',
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


def upsample(scale, oup):
    return nn.Sequential(
        nn.Upsample(scale_factor=scale, mode='bilinear'),
        nn.Conv2d(
            in_channels=oup,
            out_channels=oup,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            bias=False), nn.BatchNorm2d(oup), nn.PReLU())


class SE_Block(nn.Module):

    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False), nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False), nn.Sigmoid())

    def forward(self, x):
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        return x * y.expand_as(x)


class BasicBlockSE(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, r=64):
        super(BasicBlockSE, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride
        self.se = SE_Block(planes, r)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class BottleneckSE(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, r=64):
        super(BottleneckSE, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(
            planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.se = SE_Block(planes * self.expansion, r)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


blocks_dict = {
    'BASIC': BasicBlock,
    'BOTTLENECK': Bottleneck,
    'BASICSE': BasicBlockSE,
    'BOTTLENECKSE': BottleneckSE,
}
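These are the standard HRNet residual blocks plus SE-augmented variants; blocks_dict is what the stage configs' 'BLOCK' strings resolve to. A quick shape sanity check (a sketch, assuming the module path introduced in this PR):

import torch

from modelscope.models.cv.body_2d_keypoints.hrnet_basic_modules import (
    BasicBlock, Bottleneck, blocks_dict)

x = torch.randn(1, 64, 32, 32)
print(BasicBlock(64, 64)(x).shape)   # torch.Size([1, 64, 32, 32])
print(Bottleneck(64, 16)(x).shape)   # torch.Size([1, 64, 32, 32]); expansion 4 restores 64 channels
print(blocks_dict['BASIC'] is BasicBlock)  # True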
@@ -0,0 +1,221 @@
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.body_2d_keypoints.hrnet_basic_modules import (
    BN_MOMENTUM, BasicBlock, Bottleneck, HighResolutionModule, blocks_dict)
from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15
from modelscope.utils.constant import Tasks


@MODELS.register_module(
    Tasks.body_2d_keypoints, module_name=Models.body_2d_keypoints)
class PoseHighResolutionNetV2(TorchModel):

    def __init__(self, cfg=None, **kwargs):
        if cfg is None:
            cfg = cfg_128x128_15
        self.inplanes = 64
        extra = cfg['MODEL']['EXTRA']
        super(PoseHighResolutionNetV2, self).__init__(**kwargs)

        # stem net
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        # final four layers
        last_inp_channels = int(np.sum(pre_stage_channels))  # built-in int; np.int was removed in NumPy >= 1.24
        self.final_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=last_inp_channels,
                kernel_size=1,
                stride=1,
                padding=0),
            nn.BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=cfg['MODEL']['NUM_JOINTS'],
                kernel_size=extra['FINAL_CONV_KERNEL'],
                stride=1,
                padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0))

        self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS']

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)
        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[
                        i] if j == i - num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))
        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def _make_stage(self,
                    layer_config,
                    num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']
        modules = []
        for i in range(num_modules):
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches, block, num_blocks,
                                     num_inchannels, num_channels,
                                     fuse_method, reset_multi_scale_output))
            num_inchannels = modules[-1].get_num_inchannels()
        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        y0_h, y0_w = y_list[0].size(2), y_list[0].size(3)
        y1 = F.interpolate(y_list[1], size=(y0_h, y0_w), mode='bilinear')
        y2 = F.interpolate(y_list[2], size=(y0_h, y0_w), mode='bilinear')
        y3 = F.interpolate(y_list[3], size=(y0_h, y0_w), mode='bilinear')
        y = torch.cat([y_list[0], y1, y2, y3], 1)
        output = self.final_layer(y)

        return output
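The forward pass upsamples the three coarser branches to the stride-4 resolution, concatenates all four, and predicts one heatmap per joint. A smoke test under the default cfg_128x128_15 (a sketch, not part of the PR's tests):

import torch

from modelscope.models.cv.body_2d_keypoints.hrnet_v2 import PoseHighResolutionNetV2

net = PoseHighResolutionNetV2().eval()  # cfg defaults to cfg_128x128_15
with torch.no_grad():
    heatmaps = net(torch.randn(1, 3, 128, 128))
print(heatmaps.shape)  # expected torch.Size([1, 15, 32, 32]): 15 joints at 1/4 input resolution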
@@ -0,0 +1,51 @@
cfg_128x128_15 = {
    'DATASET': {
        'TYPE': 'DAMO',
        'PARENT_IDS': [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1],
        'LEFT_IDS': [2, 3, 4, 8, 9, 10],
        'RIGHT_IDS': [5, 6, 7, 11, 12, 13],
        'SPINE_IDS': [0, 1, 14]
    },
    'MODEL': {
        'INIT_WEIGHTS': True,
        'NAME': 'pose_hrnet',
        'NUM_JOINTS': 15,
        'PRETRAINED': '',
        'TARGET_TYPE': 'gaussian',
        'IMAGE_SIZE': [128, 128],
        'HEATMAP_SIZE': [32, 32],
        'SIGMA': 2.0,
        'EXTRA': {
            'PRETRAINED_LAYERS': [
                'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1',
                'stage2', 'transition2', 'stage3', 'transition3', 'stage4'
            ],
            'FINAL_CONV_KERNEL': 1,
            'STAGE2': {
                'NUM_MODULES': 1,
                'NUM_BRANCHES': 2,
                'BLOCK': 'BASIC',
                'NUM_BLOCKS': [4, 4],
                'NUM_CHANNELS': [48, 96],
                'FUSE_METHOD': 'SUM'
            },
            'STAGE3': {
                'NUM_MODULES': 4,
                'NUM_BRANCHES': 3,
                'BLOCK': 'BASIC',
                'NUM_BLOCKS': [4, 4, 4],
                'NUM_CHANNELS': [48, 96, 192],
                'FUSE_METHOD': 'SUM'
            },
            'STAGE4': {
                'NUM_MODULES': 3,
                'NUM_BRANCHES': 4,
                'BLOCK': 'BASIC',
                'NUM_BLOCKS': [4, 4, 4, 4],
                'NUM_CHANNELS': [48, 96, 192, 384],
                'FUSE_METHOD': 'SUM'
            },
        }
    }
}
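The 128-to-32 ratio between IMAGE_SIZE and HEATMAP_SIZE is the stride-4 factor that pts_transform in the pipeline below hard-codes as ratio = 4, and NUM_JOINTS matches the 15-entry skeleton lists. A small consistency check (a sketch):

from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15

model_cfg = cfg_128x128_15['MODEL']
assert model_cfg['IMAGE_SIZE'][0] // model_cfg['HEATMAP_SIZE'][0] == 4
assert model_cfg['NUM_JOINTS'] == len(cfg_128x128_15['DATASET']['PARENT_IDS']) == 15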
@@ -158,6 +158,27 @@ TASK_OUTPUTS = {
    # }
    Tasks.action_recognition: [OutputKeys.LABELS],

    # human body keypoints detection result for single sample
    # {
    #   "poses": [
    #     [x, y],
    #     [x, y],
    #     [x, y]
    #   ],
    #   "scores": [
    #     [score],
    #     [score],
    #     [score]
    #   ],
    #   "boxes": [
    #     [x1, y1, x2, y2],
    #     [x1, y1, x2, y2],
    #     [x1, y1, x2, y2]
    #   ]
    # }
    Tasks.body_2d_keypoints:
    [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES],

    # live category recognition result for single video
    # {
    #   "scores": [0.885272, 0.014790631, 0.014558001],
@@ -87,6 +87,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.text_to_image_synthesis:
    (Pipelines.text_to_image_synthesis,
     'damo/cv_diffusion_text-to-image-synthesis_tiny'),
    Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
                              'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
    Tasks.face_detection: (Pipelines.face_detection,
                           'damo/cv_resnet_facedetection_scrfd10gkps'),
    Tasks.face_recognition: (Pipelines.face_recognition,
@@ -238,6 +240,7 @@ def pipeline(task: str = None,
    cfg = ConfigDict(type=pipeline_name, model=model)
    cfg.device = device

    if kwargs:
        cfg.update(kwargs)
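Because the builder folds extra keyword arguments into the pipeline config via cfg.update(kwargs), the human_detector dependency can be injected at pipeline() call time, which is how the new unit test wires it up. A minimal sketch using the model ids from this diff:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

human_detector = pipeline(
    Tasks.human_detection, model='damo/cv_resnet18_human-detection')
body_2d_keypoints = pipeline(
    Tasks.body_2d_keypoints,  # defaults to damo/cv_hrnetv2w32_body-2d-keypoints_image
    human_detector=human_detector)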
@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recognition_pipeline import AnimalRecognitionPipeline
    from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
    from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
    from .image_detection_pipeline import ImageDetectionPipeline
    from .face_detection_pipeline import FaceDetectionPipeline
@@ -34,6 +35,7 @@ else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
        'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
        'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
        'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
        'image_detection_pipeline': ['ImageDetectionPipeline'],
        'face_detection_pipeline': ['FaceDetectionPipeline'],
@@ -0,0 +1,261 @@
import os.path as osp
from typing import Any, Dict, List, Union

import cv2
import json
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.body_2d_keypoints.hrnet_v2 import \
    PoseHighResolutionNetV2
from modelscope.models.cv.body_2d_keypoints.w48 import cfg_128x128_15
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints)
class Body2DKeypointsPipeline(Pipeline):

    def __init__(self, model: str, human_detector: Pipeline, **kwargs):
        super().__init__(model=model, **kwargs)
        self.keypoint_model = KeypointsDetection(model)
        self.keypoint_model.eval()
        self.human_detector = human_detector
    def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]:
        output = self.human_detector(input)

        if isinstance(input, str):
            image = cv2.imread(input, -1)[:, :, 0:3]
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            else:
                image = input
            image = image[:, :, 0:3]

        return {'image': image, 'output': output}
    def forward(self, input: Tensor) -> Dict[Tensor, Dict[str, np.ndarray]]:
        input_image = input['image']
        output = input['output']

        bboxes = []
        scores = np.array(output[OutputKeys.SCORES].cpu(), dtype=np.float32)
        boxes = np.array(output[OutputKeys.BOXES].cpu(), dtype=np.float32)
        for id, box in enumerate(boxes):
            # convert the detector's [x1, y1, x2, y2] box into
            # [x, y, w, h, score, class] as expected by KeypointsDetection
            box_tmp = [
                box[0], box[1], box[2] - box[0], box[3] - box[1], scores[id], 0
            ]
            bboxes.append(box_tmp)
        if len(bboxes) == 0:
            logger.error('cannot detect human in the image')
            return [None, None]

        human_images, metas = self.keypoint_model.preprocess(
            [bboxes, input_image])
        outputs = self.keypoint_model.forward(human_images)
        return [outputs, metas]
    def postprocess(self, input: Dict[Tensor, Dict[str, np.ndarray]],
                    **kwargs) -> str:
        if input[0] is None or input[1] is None:
            return {
                OutputKeys.BOXES: [],
                OutputKeys.POSES: [],
                OutputKeys.SCORES: []
            }
        poses, scores, boxes = self.keypoint_model.postprocess(input)
        return {
            OutputKeys.BOXES: boxes,
            OutputKeys.POSES: poses,
            OutputKeys.SCORES: scores
        }
class KeypointsDetection():

    def __init__(self, model: str, **kwargs):
        self.model = model
        cfg = cfg_128x128_15
        self.key_points_model = PoseHighResolutionNetV2(cfg)
        pretrained_state_dict = torch.load(
            osp.join(self.model, ModelFile.TORCH_MODEL_FILE))
        self.key_points_model.load_state_dict(
            pretrained_state_dict, strict=False)

        self.input_size = cfg['MODEL']['IMAGE_SIZE']
        self.lst_parent_ids = cfg['DATASET']['PARENT_IDS']
        self.lst_left_ids = cfg['DATASET']['LEFT_IDS']
        self.lst_right_ids = cfg['DATASET']['RIGHT_IDS']
        self.box_enlarge_ratio = 0.05

    def train(self):
        return self.key_points_model.train()

    def eval(self):
        return self.key_points_model.eval()

    def forward(self, input: Tensor) -> Tensor:
        with torch.no_grad():
            return self.key_points_model.forward(input)
    def get_pts(self, heatmaps):
        [pts_num, height, width] = heatmaps.shape
        pts = []
        scores = []
        for i in range(pts_num):
            heatmap = heatmaps[i, :, :]
            # argmax of the heatmap gives the joint location, and the peak
            # value is used as the confidence score
            pt = np.where(heatmap == np.max(heatmap))
            scores.append(np.max(heatmap))
            x = pt[1][0]
            y = pt[0][0]
            [h, w] = heatmap.shape
            if x >= 1 and x <= w - 2 and y >= 1 and y <= h - 2:
                # quarter-pixel refinement: shift towards the higher neighbour
                x_diff = heatmap[y, x + 1] - heatmap[y, x - 1]
                y_diff = heatmap[y + 1, x] - heatmap[y - 1, x]
                x_sign = 0
                y_sign = 0
                if x_diff < 0:
                    x_sign = -1
                if x_diff > 0:
                    x_sign = 1
                if y_diff < 0:
                    y_sign = -1
                if y_diff > 0:
                    y_sign = 1
                x = x + x_sign * 0.25
                y = y + y_sign * 0.25
            pts.append([x, y])
        return pts, scores
    def pts_transform(self, meta, pts, lt_x, lt_y):
        pts_new = []
        s = meta['s']
        o = meta['o']
        size = len(pts)
        for i in range(size):
            # ratio=4 is the heatmap stride (IMAGE_SIZE 128 / HEATMAP_SIZE 32);
            # the letterbox offset o and scale s from image_crop_resize are undone
            ratio = 4
            x = (int(pts[i][0] * ratio) - o[0]) / s[0]
            y = (int(pts[i][1] * ratio) - o[1]) / s[1]
            pt = [x, y]
            pts_new.append(pt)
        return pts_new
    def postprocess(self, inputs: Dict[Tensor, Dict[str, np.ndarray]],
                    **kwargs):
        output_poses = []
        output_scores = []
        output_boxes = []
        for i in range(inputs[0].shape[0]):
            outputs, scores = self.get_pts(
                (inputs[0][i]).detach().cpu().numpy())
            outputs = self.pts_transform(inputs[1][i], outputs, 0, 0)
            box = np.array(inputs[1][i]['human_box'][0:4]).reshape(2, 2)
            outputs = np.array(outputs) + box[0]
            output_poses.append(outputs.tolist())
            output_scores.append(scores)
            output_boxes.append(box.tolist())
        return output_poses, output_scores, output_boxes
    def image_crop_resize(self, input, margin=[0, 0]):
        pad_img = np.zeros((self.input_size[1], self.input_size[0], 3),
                           dtype=np.uint8)
        h, w, ch = input.shape
        h_new = self.input_size[1] - margin[1] * 2
        w_new = self.input_size[0] - margin[0] * 2

        s0 = float(h_new) / h
        s1 = float(w_new) / w
        s = min(s0, s1)
        w_new = int(s * w)
        h_new = int(s * h)
        img_new = cv2.resize(
            input, (w_new, h_new), interpolation=cv2.INTER_LINEAR)

        cx = self.input_size[0] // 2
        cy = self.input_size[1] // 2
        pad_img[cy - h_new // 2:cy - h_new // 2 + h_new,
                cx - w_new // 2:cx - w_new // 2 + w_new, :] = img_new
        return pad_img, np.array([cx, cy]), np.array([s, s]), np.array(
            [cx - w_new // 2, cy - h_new // 2])
    def image_transform(self, input: Input) -> Dict[Tensor, Any]:
        if isinstance(input, str):
            image = cv2.imread(input, -1)[:, :, 0:3]
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            else:
                image = input
            image = image[:, :, 0:3]
        elif isinstance(input, torch.Tensor):
            image = input.cpu().numpy()[:, :, 0:3]

        w, h, _ = image.shape
        w_new = self.input_size[0]
        h_new = self.input_size[1]

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img_resize, c, s, o = self.image_crop_resize(image)
        img_resize = np.float32(img_resize) / 255.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        img_resize = (img_resize - mean) / std
        input_data = np.zeros([1, 3, h_new, w_new], dtype=np.float32)
        img_resize = img_resize.transpose((2, 0, 1))
        input_data[0, :] = img_resize
        meta = {'c': c, 's': s, 'o': o}
        return [torch.from_numpy(input_data), meta]
    def crop_image(self, image, box):
        height, width, _ = image.shape
        w, h = box[1] - box[0]
        box[0, :] -= (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio)
        box[1, :] += (w * self.box_enlarge_ratio, h * self.box_enlarge_ratio)
        box[0, 0] = min(max(box[0, 0], 0.0), width)
        box[0, 1] = min(max(box[0, 1], 0.0), height)
        box[1, 0] = min(max(box[1, 0], 0.0), width)
        box[1, 1] = min(max(box[1, 1], 0.0), height)
        cropped_image = image[int(box[0][1]):int(box[1][1]),
                              int(box[0][0]):int(box[1][0])]
        return cropped_image
    def preprocess(self, input: Dict[Tensor, Tensor]) -> Dict[Tensor, Any]:
        bboxes = input[0]
        image = input[1]
        lst_human_images = []
        lst_meta = []
        for i in range(len(bboxes)):
            box = np.array(bboxes[i][0:4]).reshape(2, 2)
            box[1] += box[0]
            human_image = self.crop_image(image.clone(), box)
            human_image, meta = self.image_transform(human_image)
            lst_human_images.append(human_image)
            meta['human_box'] = box
            lst_meta.append(meta)
        return [torch.cat(lst_human_images, dim=0), lst_meta]
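To make the coordinate bookkeeping concrete: a heatmap peak is scaled by the stride (x4) into the 128x128 network input, the letterbox offset o and scale s from image_crop_resize are undone, and the person box top-left is added back in postprocess. A worked sketch with made-up numbers (a 256x512 person crop, so s = 0.25 and o = (0, 32)):

import numpy as np

s = np.array([0.25, 0.25])        # scale used when letterboxing the crop to 128x128
o = np.array([0, 32])             # top-left offset of the resized crop inside the 128x128 input
box_tl = np.array([100.0, 40.0])  # person box top-left in the original image

peak = np.array([20, 9])          # argmax (x, y) on the 32x32 heatmap
in_input = peak * 4               # (80, 36) in the 128x128 network input
in_crop = (in_input - o) / s      # (320.0, 16.0) inside the person crop
in_image = in_crop + box_tl       # (420.0, 56.0) in the original image
print(in_image)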
@@ -22,6 +22,7 @@ class CVTasks(object):
    human_detection = 'human-detection'
    human_object_interaction = 'human-object-interaction'
    face_image_generation = 'face-image-generation'
    body_2d_keypoints = 'body-2d-keypoints'

    image_classification = 'image-classification'
    image_multilabel_classification = 'image-multilabel-classification'
@@ -0,0 +1,100 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2
import numpy as np
import torch

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

lst_parent_ids_17 = [0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14]
lst_left_ids_17 = [1, 3, 5, 7, 9, 11, 13, 15]
lst_right_ids_17 = [2, 4, 6, 8, 10, 12, 14, 16]
lst_spine_ids_17 = [0]

lst_parent_ids_15 = [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1]
lst_left_ids_15 = [2, 3, 4, 8, 9, 10]
lst_right_ids_15 = [5, 6, 7, 11, 12, 13]
lst_spine_ids_15 = [0, 1, 14]
def draw_joints(image, np_kps, score, threshold=0.2):
    if np_kps.shape[0] == 17:
        lst_parent_ids = lst_parent_ids_17
        lst_left_ids = lst_left_ids_17
        lst_right_ids = lst_right_ids_17
    elif np_kps.shape[0] == 15:
        lst_parent_ids = lst_parent_ids_15
        lst_left_ids = lst_left_ids_15
        lst_right_ids = lst_right_ids_15

    for i in range(len(lst_parent_ids)):
        pid = lst_parent_ids[i]
        if i == pid:
            continue
        # skip the limb if either endpoint is below the confidence threshold
        if (score[i] < threshold or score[pid] < threshold):
            continue
        if i in lst_left_ids and pid in lst_left_ids:
            color = (0, 255, 0)
        elif i in lst_right_ids and pid in lst_right_ids:
            color = (255, 0, 0)
        else:
            color = (0, 255, 255)
        cv2.line(image, (int(np_kps[i, 0]), int(np_kps[i, 1])),
                 (int(np_kps[pid, 0]), int(np_kps[pid, 1])), color, 3)

    for i in range(np_kps.shape[0]):
        if score[i] < threshold:
            continue
        cv2.circle(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), 5,
                   (0, 0, 255), -1)
def draw_box(image, box):
    cv2.rectangle(image, (int(box[0][0]), int(box[0][1])),
                  (int(box[1][0]), int(box[1][1])), (0, 0, 255), 2)
class Body2DKeypointsTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image'
        self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg'
        self.human_detect_model_id = 'damo/cv_resnet18_human-detection'

    def pipeline_inference(self, pipeline: Pipeline):
        output = pipeline(self.test_image)
        poses = np.array(output[OutputKeys.POSES])
        scores = np.array(output[OutputKeys.SCORES])
        boxes = np.array(output[OutputKeys.BOXES])
        assert len(poses) == len(scores) and len(poses) == len(boxes)

        image = cv2.imread(self.test_image, -1)
        for i in range(len(poses)):
            draw_box(image, np.array(boxes[i]))
            draw_joints(image, np.array(poses[i]), np.array(scores[i]))
        cv2.imwrite('pose_keypoint.jpg', image)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        human_detector = pipeline(
            Tasks.human_detection, model=self.human_detect_model_id)
        body_2d_keypoints = pipeline(
            Tasks.body_2d_keypoints,
            human_detector=human_detector,
            model=self.model_id)
        self.pipeline_inference(body_2d_keypoints)


if __name__ == '__main__':
    unittest.main()