Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10217723 * add image_body_reshaping codemaster
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d | |||
size 1127557 |
@@ -43,6 +43,7 @@ class Models(object): | |||
face_human_hand_detection = 'face-human-hand-detection' | |||
face_emotion = 'face-emotion' | |||
product_segmentation = 'product-segmentation' | |||
image_body_reshaping = 'image-body-reshaping' | |||
# EasyCV models | |||
yolox = 'YOLOX' | |||
@@ -187,6 +188,7 @@ class Pipelines(object): | |||
face_human_hand_detection = 'face-human-hand-detection' | |||
face_emotion = 'face-emotion' | |||
product_segmentation = 'product-segmentation' | |||
image_body_reshaping = 'flow-based-body-reshaping' | |||
# nlp tasks | |||
automatic_post_editing = 'automatic-post-editing' | |||
@@ -0,0 +1,20 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .image_body_reshaping import ImageBodyReshaping | |||
else: | |||
_import_structure = {'image_body_reshaping': ['ImageBodyReshaping']} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,128 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
from typing import Any, Dict | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import Tensor, TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from .model import FlowGenerator | |||
from .person_info import PersonInfo | |||
from .pose_estimator.body import Body | |||
from .slim_utils import image_warp_grid1, resize_on_long_side | |||
logger = get_logger() | |||
__all__ = ['ImageBodyReshaping'] | |||
@MODELS.register_module( | |||
Tasks.image_body_reshaping, module_name=Models.image_body_reshaping) | |||
class ImageBodyReshaping(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the image body reshaping model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
if torch.cuda.is_available(): | |||
self.device = torch.device('cuda') | |||
else: | |||
self.device = torch.device('cpu') | |||
self.degree = 1.0 | |||
self.reshape_model = FlowGenerator(n_channels=16).to(self.device) | |||
model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
checkpoints = torch.load(model_path, map_location=torch.device('cpu')) | |||
self.reshape_model.load_state_dict( | |||
checkpoints['state_dict'], strict=True) | |||
self.reshape_model.eval() | |||
logger.info('load body reshaping model done') | |||
pose_model_ckpt = os.path.join(model_dir, 'body_pose_model.pth') | |||
self.pose_esti = Body(pose_model_ckpt, self.device) | |||
logger.info('load pose model done') | |||
def pred_joints(self, img): | |||
if img is None: | |||
return None | |||
small_src, resize_scale = resize_on_long_side(img, 300) | |||
body_joints = self.pose_esti(small_src) | |||
if body_joints.shape[0] >= 1: | |||
body_joints[:, :, :2] = body_joints[:, :, :2] / resize_scale | |||
return body_joints | |||
def pred_flow(self, img): | |||
body_joints = self.pred_joints(img) | |||
small_size = 1200 | |||
if img.shape[0] > small_size or img.shape[1] > small_size: | |||
_img, _scale = resize_on_long_side(img, small_size) | |||
body_joints[:, :, :2] = body_joints[:, :, :2] * _scale | |||
else: | |||
_img = img | |||
# We only reshape one person | |||
if body_joints.shape[0] < 1 or body_joints.shape[0] > 1: | |||
return None | |||
person = PersonInfo(body_joints[0]) | |||
with torch.no_grad(): | |||
person_pred = person.pred_flow(_img, self.reshape_model, | |||
self.device) | |||
flow = np.dstack((person_pred['rDx'], person_pred['rDy'])) | |||
scale = img.shape[0] * 1.0 / flow.shape[0] | |||
flow = cv2.resize(flow, (img.shape[1], img.shape[0])) | |||
flow *= scale | |||
return flow | |||
def warp(self, src_img, flow): | |||
X_flow = flow[..., 0] | |||
Y_flow = flow[..., 1] | |||
X_flow = np.ascontiguousarray(X_flow) | |||
Y_flow = np.ascontiguousarray(Y_flow) | |||
pred = image_warp_grid1(X_flow, Y_flow, src_img, 1.0, 0, 0) | |||
return pred | |||
def inference(self, img): | |||
img = img.cpu().numpy() | |||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |||
flow = self.pred_flow(img) | |||
if flow is None: | |||
return img | |||
assert flow.shape[:2] == img.shape[:2] | |||
mag, ang = cv2.cartToPolar(flow[..., 0] + 1e-8, flow[..., 1] + 1e-8) | |||
mag -= 3 | |||
mag[mag <= 0] = 0 | |||
x, y = cv2.polarToCart(mag, ang, angleInDegrees=False) | |||
flow = np.dstack((x, y)) | |||
flow *= self.degree | |||
pred = self.warp(img, flow) | |||
out_img = np.clip(pred, 0, 255) | |||
logger.info('model inference done') | |||
return out_img.astype(np.uint8) |
@@ -0,0 +1,189 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
class ConvLayer(nn.Module): | |||
def __init__(self, in_ch, out_ch): | |||
super(ConvLayer, self).__init__() | |||
self.conv = nn.Sequential( | |||
nn.ReflectionPad2d(1), | |||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0), | |||
nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)) | |||
def forward(self, x): | |||
x = self.conv(x) | |||
return x | |||
class SASA(nn.Module): | |||
def __init__(self, in_dim): | |||
super(SASA, self).__init__() | |||
self.chanel_in = in_dim | |||
self.query_conv = nn.Conv2d( | |||
in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) | |||
self.key_conv = nn.Conv2d( | |||
in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) | |||
self.value_conv = nn.Conv2d( | |||
in_channels=in_dim, out_channels=in_dim, kernel_size=1) | |||
self.mag_conv = nn.Conv2d( | |||
in_channels=5, out_channels=in_dim // 32, kernel_size=1) | |||
self.gamma = nn.Parameter(torch.zeros(1)) | |||
self.softmax = nn.Softmax(dim=-1) # | |||
self.sigmoid = nn.Sigmoid() | |||
def structure_encoder(self, paf_mag, target_height, target_width): | |||
torso_mask = torch.sum(paf_mag[:, 1:3, :, :], dim=1, keepdim=True) | |||
torso_mask = torch.clamp(torso_mask, 0, 1) | |||
arms_mask = torch.sum(paf_mag[:, 4:8, :, :], dim=1, keepdim=True) | |||
arms_mask = torch.clamp(arms_mask, 0, 1) | |||
legs_mask = torch.sum(paf_mag[:, 8:12, :, :], dim=1, keepdim=True) | |||
legs_mask = torch.clamp(legs_mask, 0, 1) | |||
fg_mask = paf_mag[:, 12, :, :].unsqueeze(1) | |||
bg_mask = 1 - fg_mask | |||
Y = torch.cat((arms_mask, torso_mask, legs_mask, fg_mask, bg_mask), | |||
dim=1) | |||
Y = F.interpolate(Y, size=(target_height, target_width), mode='area') | |||
return Y | |||
def forward(self, X, PAF_mag): | |||
"""extract self-attention features. | |||
Args: | |||
X : input feature maps( B x C x H x W) | |||
PAF_mag : ( B x C x H x W), 1 denotes connectivity, 0 denotes non-connectivity | |||
Returns: | |||
out : self attention value + input feature | |||
Y: B X N X N (N is Width*Height) | |||
""" | |||
m_batchsize, C, height, width = X.size() | |||
Y = self.structure_encoder(PAF_mag, height, width) | |||
connectivity_mask_vec = self.mag_conv(Y).view(m_batchsize, -1, | |||
width * height) | |||
affinity = torch.bmm( | |||
connectivity_mask_vec.permute(0, 2, 1), connectivity_mask_vec) | |||
affinity_centered = affinity - torch.mean(affinity) | |||
affinity_sigmoid = self.sigmoid(affinity_centered) | |||
proj_query = self.query_conv(X).view(m_batchsize, -1, | |||
width * height).permute(0, 2, 1) | |||
proj_key = self.key_conv(X).view(m_batchsize, -1, width * height) | |||
selfatten_map = torch.bmm(proj_query, proj_key) | |||
selfatten_centered = selfatten_map - torch.mean( | |||
selfatten_map) # centering | |||
selfatten_sigmoid = self.sigmoid(selfatten_centered) | |||
SASA_map = selfatten_sigmoid * affinity_sigmoid | |||
proj_value = self.value_conv(X).view(m_batchsize, -1, width * height) | |||
out = torch.bmm(proj_value, SASA_map.permute(0, 2, 1)) | |||
out = out.view(m_batchsize, C, height, width) | |||
out = self.gamma * out + X | |||
return out, Y | |||
class FlowGenerator(nn.Module): | |||
def __init__(self, n_channels, deep_supervision=False): | |||
super(FlowGenerator, self).__init__() | |||
self.deep_supervision = deep_supervision | |||
self.Encoder = nn.Sequential( | |||
ConvLayer(n_channels, 64), | |||
ConvLayer(64, 64), | |||
nn.MaxPool2d(2), | |||
ConvLayer(64, 128), | |||
ConvLayer(128, 128), | |||
nn.MaxPool2d(2), | |||
ConvLayer(128, 256), | |||
ConvLayer(256, 256), | |||
nn.MaxPool2d(2), | |||
ConvLayer(256, 512), | |||
ConvLayer(512, 512), | |||
nn.MaxPool2d(2), | |||
ConvLayer(512, 1024), | |||
ConvLayer(1024, 1024), | |||
ConvLayer(1024, 1024), | |||
ConvLayer(1024, 1024), | |||
ConvLayer(1024, 1024), | |||
) | |||
self.SASA = SASA(in_dim=1024) | |||
self.Decoder = nn.Sequential( | |||
ConvLayer(1024, 1024), | |||
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), | |||
ConvLayer(1024, 512), | |||
ConvLayer(512, 512), | |||
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), | |||
ConvLayer(512, 256), | |||
ConvLayer(256, 256), | |||
ConvLayer(256, 128), | |||
ConvLayer(128, 64), | |||
ConvLayer(64, 32), | |||
nn.Conv2d(32, 2, kernel_size=1, padding=0), | |||
nn.Tanh(), | |||
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True), | |||
) | |||
dilation_ksize = 17 | |||
self.dilation = torch.nn.MaxPool2d( | |||
kernel_size=dilation_ksize, | |||
stride=1, | |||
padding=int((dilation_ksize - 1) / 2)) | |||
def warp(self, x, flow, mode='bilinear', padding_mode='zeros', coff=0.2): | |||
n, c, h, w = x.size() | |||
yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)]) | |||
xv = xv.float() / (w - 1) * 2.0 - 1 | |||
yv = yv.float() / (h - 1) * 2.0 - 1 | |||
grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), -1).unsqueeze(0) | |||
grid = grid.to(flow.device) | |||
grid_x = grid + 2 * flow * coff | |||
warp_x = F.grid_sample(x, grid_x, mode=mode, padding_mode=padding_mode) | |||
return warp_x | |||
def forward(self, img, skeleton_map, coef=0.2): | |||
"""extract self-attention features. | |||
Args: | |||
img : input numpy image | |||
skeleton_map : skeleton map of input image | |||
coef: warp degree | |||
Returns: | |||
warp_x : warped image | |||
flow: predicted flow | |||
""" | |||
img_concat = torch.cat((img, skeleton_map), dim=1) | |||
X = self.Encoder(img_concat) | |||
_, _, height, width = X.size() | |||
# directly get PAF magnitude from skeleton maps via dilation | |||
PAF_mag = self.dilation((skeleton_map + 1.0) * 0.5) | |||
out, Y = self.SASA(X, PAF_mag) | |||
flow = self.Decoder(out) | |||
flow = flow.permute(0, 2, 3, 1) # [n, 2, h, w] ==> [n, h, w, 2] | |||
warp_x = self.warp(img, flow, coff=coef) | |||
warp_x = torch.clamp(warp_x, min=-1.0, max=1.0) | |||
return warp_x, flow |
@@ -0,0 +1,339 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import copy | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
from .slim_utils import (enlarge_box_tblr, gen_skeleton_map, | |||
get_map_fusion_map_cuda, get_mask_bbox, | |||
resize_on_long_side) | |||
class PersonInfo(object): | |||
def __init__(self, joints): | |||
self.joints = joints | |||
self.flow = None | |||
self.pad_boder = False | |||
self.height_expand = 0 | |||
self.width_expand = 0 | |||
self.coeff = 0.2 | |||
self.network_input_W = 256 | |||
self.network_input_H = 256 | |||
self.divider = 20 | |||
self.flow_scales = ['upper_2'] | |||
def update_attribute(self, pad_boder, height_expand, width_expand): | |||
self.pad_boder = pad_boder | |||
self.height_expand = height_expand | |||
self.width_expand = width_expand | |||
if pad_boder: | |||
self.joints[:, 0] += width_expand | |||
self.joints[:, 1] += height_expand | |||
def pred_flow(self, img, flow_net, device): | |||
with torch.no_grad(): | |||
if img is None: | |||
print('image is none') | |||
self.flow = None | |||
if len(img.shape) == 2: | |||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
if self.pad_boder: | |||
height_expand = self.height_expand | |||
width_expand = self.width_expand | |||
pad_img = cv2.copyMakeBorder( | |||
img, | |||
height_expand, | |||
height_expand, | |||
width_expand, | |||
width_expand, | |||
cv2.BORDER_CONSTANT, | |||
value=(127, 127, 127)) | |||
else: | |||
height_expand = 0 | |||
width_expand = 0 | |||
pad_img = img.copy() | |||
canvas = np.zeros( | |||
shape=(pad_img.shape[0], pad_img.shape[1]), dtype=np.float32) | |||
self.human_joint_box = self.__joint_to_body_box() | |||
self.human_box = enlarge_box_tblr( | |||
self.human_joint_box, pad_img, ratio=0.25) | |||
human_box_height = self.human_box[1] - self.human_box[0] | |||
human_box_width = self.human_box[3] - self.human_box[2] | |||
self.leg_joint_box = self.__joint_to_leg_box() | |||
self.leg_box = enlarge_box_tblr( | |||
self.leg_joint_box, pad_img, ratio=0.25) | |||
self.arm_joint_box = self.__joint_to_arm_box() | |||
self.arm_box = enlarge_box_tblr( | |||
self.arm_joint_box, pad_img, ratio=0.1) | |||
x_flows = [] | |||
y_flows = [] | |||
multi_bbox = [] | |||
for scale in self.flow_scales: # better for metric | |||
scale_value = float(scale.split('_')[-1]) | |||
arm_box = copy.deepcopy(self.arm_box) | |||
if arm_box[0] is None: | |||
arm_box = self.human_box | |||
arm_box_height = arm_box[1] - arm_box[0] | |||
arm_box_width = arm_box[3] - arm_box[2] | |||
roi_bbox = None | |||
if arm_box_width < human_box_width * 0.1 or arm_box_height < human_box_height * 0.1: | |||
roi_bbox = self.human_box | |||
else: | |||
arm_box = enlarge_box_tblr( | |||
arm_box, pad_img, ratio=scale_value) | |||
if scale == 'upper_0.2': | |||
arm_box[0] = min(arm_box[0], int(self.joints[0][1])) | |||
if scale.startswith('upper'): | |||
roi_bbox = [ | |||
max(self.human_box[0], arm_box[0]), | |||
min(self.human_box[1], arm_box[1]), | |||
max(self.human_box[2], arm_box[2]), | |||
min(self.human_box[3], arm_box[3]) | |||
] | |||
if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[ | |||
3] - roi_bbox[2] < 1: | |||
continue | |||
elif scale.startswith('lower'): | |||
roi_bbox = [ | |||
max(self.human_box[0], self.leg_box[0]), | |||
min(self.human_box[1], self.leg_box[1]), | |||
max(self.human_box[2], self.leg_box[2]), | |||
min(self.human_box[3], self.leg_box[3]) | |||
] | |||
if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[ | |||
3] - roi_bbox[2] < 1: | |||
continue | |||
skel_map, roi_bbox = gen_skeleton_map( | |||
self.joints, 'depth', input_roi_box=roi_bbox) | |||
if roi_bbox is None: | |||
continue | |||
if skel_map.dtype != np.float32: | |||
skel_map = skel_map.astype(np.float32) | |||
skel_map -= 1.0 # [0,2] ->[-1,1] | |||
multi_bbox.append(roi_bbox) | |||
roi_bbox_height = roi_bbox[1] - roi_bbox[0] | |||
roi_bbox_width = roi_bbox[3] - roi_bbox[2] | |||
assert skel_map.shape[0] == roi_bbox_height | |||
assert skel_map.shape[1] == roi_bbox_width | |||
roi_height_pad = roi_bbox_height // self.divider | |||
roi_width_pad = roi_bbox_width // self.divider | |||
paded_roi_h = roi_bbox_height + 2 * roi_height_pad | |||
paded_roi_w = roi_bbox_width + 2 * roi_width_pad | |||
roi_height_pad_joint = skel_map.shape[0] // self.divider | |||
roi_width_pad_joint = skel_map.shape[1] // self.divider | |||
skel_map = np.pad( | |||
skel_map, | |||
((roi_height_pad_joint, roi_height_pad_joint), | |||
(roi_width_pad_joint, roi_width_pad_joint), (0, 0)), | |||
'constant', | |||
constant_values=-1) | |||
skel_map_resized = cv2.resize( | |||
skel_map, (self.network_input_W, self.network_input_H)) | |||
skel_map_resized[skel_map_resized < 0] = -1.0 | |||
skel_map_resized[skel_map_resized > -0.5] = 1.0 | |||
skel_map_transformed = torch.from_numpy( | |||
skel_map_resized.transpose((2, 0, 1))) | |||
roi_npy = pad_img[roi_bbox[0]:roi_bbox[1], | |||
roi_bbox[2]:roi_bbox[3], :].copy() | |||
if roi_npy.dtype != np.float32: | |||
roi_npy = roi_npy.astype(np.float32) | |||
roi_npy = np.pad(roi_npy, | |||
((roi_height_pad, roi_height_pad), | |||
(roi_width_pad, roi_width_pad), (0, 0)), | |||
'edge') | |||
roi_npy = roi_npy[:, :, ::-1] | |||
roi_npy = cv2.resize( | |||
roi_npy, (self.network_input_W, self.network_input_H)) | |||
roi_npy *= 1.0 / 255 | |||
roi_npy -= 0.5 | |||
roi_npy *= 2 | |||
rgb_tensor = torch.from_numpy(roi_npy.transpose((2, 0, 1))) | |||
rgb_tensor = rgb_tensor.unsqueeze(0).to(device) | |||
skel_map_tensor = skel_map_transformed.unsqueeze(0).to(device) | |||
warped_img_val, flow_field_val = flow_net( | |||
rgb_tensor, skel_map_tensor | |||
) # inference, connectivity_mask [1,12,16,16] | |||
flow_field_val = flow_field_val.detach().squeeze().cpu().numpy( | |||
) | |||
flow_field_val = cv2.resize( | |||
flow_field_val, (paded_roi_w, paded_roi_h), | |||
interpolation=cv2.INTER_LINEAR) | |||
flow_field_val[..., 0] = flow_field_val[ | |||
..., 0] * paded_roi_w * 0.5 * 2 * self.coeff | |||
flow_field_val[..., 1] = flow_field_val[ | |||
..., 1] * paded_roi_h * 0.5 * 2 * self.coeff | |||
# remove pad areas | |||
flow_field_val = flow_field_val[ | |||
roi_height_pad:flow_field_val.shape[0] - roi_height_pad, | |||
roi_width_pad:flow_field_val.shape[1] - roi_width_pad, :] | |||
diffuse_width = max(roi_bbox_width // 3, 1) | |||
diffuse_height = max(roi_bbox_height // 3, 1) | |||
assert roi_bbox_width == flow_field_val.shape[1] | |||
assert roi_bbox_height == flow_field_val.shape[0] | |||
origin_flow = np.zeros( | |||
(pad_img.shape[0] + 2 * diffuse_height, | |||
pad_img.shape[1] + 2 * diffuse_width, 2), | |||
dtype=np.float32) | |||
flow_field_val = np.pad(flow_field_val, | |||
((diffuse_height, diffuse_height), | |||
(diffuse_width, diffuse_width), | |||
(0, 0)), 'linear_ramp') | |||
origin_flow[roi_bbox[0]:roi_bbox[1] + 2 * diffuse_height, | |||
roi_bbox[2]:roi_bbox[3] | |||
+ 2 * diffuse_width] = flow_field_val | |||
origin_flow = origin_flow[diffuse_height:-diffuse_height, | |||
diffuse_width:-diffuse_width, :] | |||
x_flows.append(origin_flow[..., 0]) | |||
y_flows.append(origin_flow[..., 1]) | |||
if len(x_flows) == 0: | |||
return { | |||
'rDx': np.zeros(canvas.shape[:2], dtype=np.float32), | |||
'rDy': np.zeros(canvas.shape[:2], dtype=np.float32), | |||
'multi_bbox': multi_bbox, | |||
'x_fusion_map': | |||
np.ones(canvas.shape[:2], dtype=np.float32), | |||
'y_fusion_map': | |||
np.ones(canvas.shape[:2], dtype=np.float32) | |||
} | |||
else: | |||
origin_rDx, origin_rDy, x_fusion_map, y_fusion_map = self.blend_multiscale_flow( | |||
x_flows, y_flows, device=device) | |||
return { | |||
'rDx': origin_rDx, | |||
'rDy': origin_rDy, | |||
'multi_bbox': multi_bbox, | |||
'x_fusion_map': x_fusion_map, | |||
'y_fusion_map': y_fusion_map | |||
} | |||
@staticmethod | |||
def blend_multiscale_flow(x_flows, y_flows, device=None): | |||
scale_num = len(x_flows) | |||
if scale_num == 1: | |||
return x_flows[0], y_flows[0], np.ones_like( | |||
x_flows[0]), np.ones_like(x_flows[0]) | |||
origin_rDx = np.zeros((x_flows[0].shape[0], x_flows[0].shape[1]), | |||
dtype=np.float32) | |||
origin_rDy = np.zeros((y_flows[0].shape[0], y_flows[0].shape[1]), | |||
dtype=np.float32) | |||
x_fusion_map, x_acc_map = get_map_fusion_map_cuda( | |||
x_flows, 1, device=device) | |||
y_fusion_map, y_acc_map = get_map_fusion_map_cuda( | |||
y_flows, 1, device=device) | |||
x_flow_map = 1.0 / x_fusion_map | |||
y_flow_map = 1.0 / y_fusion_map | |||
all_acc_map = x_acc_map + y_acc_map | |||
all_acc_map = all_acc_map.astype(np.uint8) | |||
roi_box = get_mask_bbox(all_acc_map, threshold=1) | |||
if roi_box[0] is None or roi_box[1] - roi_box[0] <= 0 or roi_box[ | |||
3] - roi_box[2] <= 0: | |||
roi_box = [0, x_flow_map.shape[0], 0, x_flow_map.shape[1]] | |||
roi_x_flow_map = x_flow_map[roi_box[0]:roi_box[1], | |||
roi_box[2]:roi_box[3]] | |||
roi_y_flow_map = y_flow_map[roi_box[0]:roi_box[1], | |||
roi_box[2]:roi_box[3]] | |||
roi_width = roi_x_flow_map.shape[1] | |||
roi_height = roi_x_flow_map.shape[0] | |||
roi_x_flow_map, scale = resize_on_long_side(roi_x_flow_map, 320) | |||
roi_y_flow_map, scale = resize_on_long_side(roi_y_flow_map, 320) | |||
roi_x_flow_map = cv2.blur(roi_x_flow_map, (55, 55)) | |||
roi_y_flow_map = cv2.blur(roi_y_flow_map, (55, 55)) | |||
roi_x_flow_map = cv2.resize(roi_x_flow_map, (roi_width, roi_height)) | |||
roi_y_flow_map = cv2.resize(roi_y_flow_map, (roi_width, roi_height)) | |||
x_flow_map[roi_box[0]:roi_box[1], | |||
roi_box[2]:roi_box[3]] = roi_x_flow_map | |||
y_flow_map[roi_box[0]:roi_box[1], | |||
roi_box[2]:roi_box[3]] = roi_y_flow_map | |||
for i in range(scale_num): | |||
origin_rDx += x_flows[i] | |||
origin_rDy += y_flows[i] | |||
origin_rDx *= x_flow_map | |||
origin_rDy *= y_flow_map | |||
return origin_rDx, origin_rDy, x_flow_map, y_flow_map | |||
def __joint_to_body_box(self): | |||
joint_left = int(np.min(self.joints, axis=0)[0]) | |||
joint_right = int(np.max(self.joints, axis=0)[0]) | |||
joint_top = int(np.min(self.joints, axis=0)[1]) | |||
joint_bottom = int(np.max(self.joints, axis=0)[1]) | |||
return [joint_top, joint_bottom, joint_left, joint_right] | |||
def __joint_to_leg_box(self): | |||
leg_joints = self.joints[8:, :] | |||
if np.max(leg_joints, axis=0)[2] < 0.05: | |||
return [0, 0, 0, 0] | |||
joint_left = int(np.min(leg_joints, axis=0)[0]) | |||
joint_right = int(np.max(leg_joints, axis=0)[0]) | |||
joint_top = int(np.min(leg_joints, axis=0)[1]) | |||
joint_bottom = int(np.max(leg_joints, axis=0)[1]) | |||
return [joint_top, joint_bottom, joint_left, joint_right] | |||
def __joint_to_arm_box(self): | |||
arm_joints = self.joints[2:8, :] | |||
if np.max(arm_joints, axis=0)[2] < 0.05: | |||
return [0, 0, 0, 0] | |||
joint_left = int(np.min(arm_joints, axis=0)[0]) | |||
joint_right = int(np.max(arm_joints, axis=0)[0]) | |||
joint_top = int(np.min(arm_joints, axis=0)[1]) | |||
joint_bottom = int(np.max(arm_joints, axis=0)[1]) | |||
return [joint_top, joint_bottom, joint_left, joint_right] |
@@ -0,0 +1,272 @@ | |||
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. | |||
import math | |||
import cv2 | |||
import numpy as np | |||
import torch | |||
from scipy.ndimage.filters import gaussian_filter | |||
from .model import BodyposeModel | |||
from .util import pad_rightdown_corner, transfer | |||
class Body(object): | |||
def __init__(self, model_path, device): | |||
self.model = BodyposeModel().to(device) | |||
model_dict = transfer(self.model, torch.load(model_path)) | |||
self.model.load_state_dict(model_dict) | |||
self.model.eval() | |||
def __call__(self, oriImg): | |||
scale_search = [0.5] | |||
boxsize = 368 | |||
stride = 8 | |||
padValue = 128 | |||
thre1 = 0.1 | |||
thre2 = 0.05 | |||
bodyparts = 18 | |||
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] | |||
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) | |||
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) | |||
for m in range(len(multiplier)): | |||
scale = multiplier[m] | |||
imageToTest = cv2.resize( | |||
oriImg, (0, 0), | |||
fx=scale, | |||
fy=scale, | |||
interpolation=cv2.INTER_CUBIC) | |||
imageToTest_padded, pad = pad_rightdown_corner( | |||
imageToTest, stride, padValue) | |||
im = np.transpose( | |||
np.float32(imageToTest_padded[:, :, :, np.newaxis]), | |||
(3, 2, 0, 1)) / 256 - 0.5 | |||
im = np.ascontiguousarray(im) | |||
data = torch.from_numpy(im).float() | |||
if torch.cuda.is_available(): | |||
data = data.cuda() | |||
with torch.no_grad(): | |||
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) | |||
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() | |||
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() | |||
# extract outputs, resize, and remove padding | |||
heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), | |||
(1, 2, 0)) # output 1 is heatmaps | |||
heatmap = cv2.resize( | |||
heatmap, (0, 0), | |||
fx=stride, | |||
fy=stride, | |||
interpolation=cv2.INTER_CUBIC) | |||
heatmap = heatmap[:imageToTest_padded.shape[0] | |||
- pad[2], :imageToTest_padded.shape[1] | |||
- pad[3], :] | |||
heatmap = cv2.resize( | |||
heatmap, (oriImg.shape[1], oriImg.shape[0]), | |||
interpolation=cv2.INTER_CUBIC) | |||
paf = np.transpose(np.squeeze(Mconv7_stage6_L1), | |||
(1, 2, 0)) # output 0 is PAFs | |||
paf = cv2.resize( | |||
paf, (0, 0), | |||
fx=stride, | |||
fy=stride, | |||
interpolation=cv2.INTER_CUBIC) | |||
paf = paf[:imageToTest_padded.shape[0] | |||
- pad[2], :imageToTest_padded.shape[1] - pad[3], :] | |||
paf = cv2.resize( | |||
paf, (oriImg.shape[1], oriImg.shape[0]), | |||
interpolation=cv2.INTER_CUBIC) | |||
heatmap_avg += heatmap_avg + heatmap / len(multiplier) | |||
paf_avg += +paf / len(multiplier) | |||
all_peaks = [] | |||
peak_counter = 0 | |||
for part in range(bodyparts): | |||
map_ori = heatmap_avg[:, :, part] | |||
one_heatmap = gaussian_filter(map_ori, sigma=3) | |||
map_left = np.zeros(one_heatmap.shape) | |||
map_left[1:, :] = one_heatmap[:-1, :] | |||
map_right = np.zeros(one_heatmap.shape) | |||
map_right[:-1, :] = one_heatmap[1:, :] | |||
map_up = np.zeros(one_heatmap.shape) | |||
map_up[:, 1:] = one_heatmap[:, :-1] | |||
map_down = np.zeros(one_heatmap.shape) | |||
map_down[:, :-1] = one_heatmap[:, 1:] | |||
peaks_binary = np.logical_and.reduce( | |||
(one_heatmap >= map_left, one_heatmap >= map_right, | |||
one_heatmap >= map_up, one_heatmap >= map_down, | |||
one_heatmap > thre1)) | |||
peaks = list( | |||
zip(np.nonzero(peaks_binary)[1], | |||
np.nonzero(peaks_binary)[0])) # note reverse | |||
peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks] | |||
peak_id = range(peak_counter, peak_counter + len(peaks)) | |||
peaks_with_score_and_id = [ | |||
peaks_with_score[i] + (peak_id[i], ) | |||
for i in range(len(peak_id)) | |||
] | |||
all_peaks.append(peaks_with_score_and_id) | |||
peak_counter += len(peaks) | |||
# find connection in the specified sequence, center 29 is in the position 15 | |||
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], | |||
[9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], | |||
[1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]] | |||
# the middle joints heatmap correpondence | |||
mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], | |||
[19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30], | |||
[47, 48], [49, 50], [53, 54], [51, 52], [55, 56], [37, 38], | |||
[45, 46]] | |||
connection_all = [] | |||
special_k = [] | |||
mid_num = 10 | |||
for k in range(len(mapIdx)): | |||
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] | |||
candA = all_peaks[limbSeq[k][0] - 1] | |||
candB = all_peaks[limbSeq[k][1] - 1] | |||
nA = len(candA) | |||
nB = len(candB) | |||
if (nA != 0 and nB != 0): | |||
connection_candidate = [] | |||
for i in range(nA): | |||
for j in range(nB): | |||
vec = np.subtract(candB[j][:2], candA[i][:2]) | |||
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) | |||
norm = max(0.001, norm) | |||
vec = np.divide(vec, norm) | |||
startend = list( | |||
zip( | |||
np.linspace( | |||
candA[i][0], candB[j][0], num=mid_num), | |||
np.linspace( | |||
candA[i][1], candB[j][1], num=mid_num))) | |||
vec_x = np.array([ | |||
score_mid[int(round(startend[item][1])), | |||
int(round(startend[item][0])), 0] | |||
for item in range(len(startend)) | |||
]) | |||
vec_y = np.array([ | |||
score_mid[int(round(startend[item][1])), | |||
int(round(startend[item][0])), 1] | |||
for item in range(len(startend)) | |||
]) | |||
score_midpts = np.multiply( | |||
vec_x, vec[0]) + np.multiply(vec_y, vec[1]) | |||
temp1 = sum(score_midpts) / len(score_midpts) | |||
temp2 = min(0.5 * oriImg.shape[0] / norm - 1, 0) | |||
score_with_dist_prior = temp1 + temp2 | |||
criterion1 = len(np.nonzero( | |||
score_midpts > thre2)[0]) > 0.8 * len(score_midpts) | |||
criterion2 = score_with_dist_prior > 0 | |||
if criterion1 and criterion2: | |||
connection_candidate.append([ | |||
i, j, score_with_dist_prior, | |||
score_with_dist_prior + candA[i][2] | |||
+ candB[j][2] | |||
]) | |||
connection_candidate = sorted( | |||
connection_candidate, key=lambda x: x[2], reverse=True) | |||
connection = np.zeros((0, 5)) | |||
for c in range(len(connection_candidate)): | |||
i, j, s = connection_candidate[c][0:3] | |||
if (i not in connection[:, 3] | |||
and j not in connection[:, 4]): | |||
connection = np.vstack( | |||
[connection, [candA[i][3], candB[j][3], s, i, j]]) | |||
if (len(connection) >= min(nA, nB)): | |||
break | |||
connection_all.append(connection) | |||
else: | |||
special_k.append(k) | |||
connection_all.append([]) | |||
# last number in each row is the total parts number of that person | |||
# the second last number in each row is the score of the overall configuration | |||
subset = -1 * np.ones((0, 20)) | |||
candidate = np.array( | |||
[item for sublist in all_peaks for item in sublist]) | |||
for k in range(len(mapIdx)): | |||
if k not in special_k: | |||
partAs = connection_all[k][:, 0] | |||
partBs = connection_all[k][:, 1] | |||
indexA, indexB = np.array(limbSeq[k]) - 1 | |||
for i in range(len(connection_all[k])): # = 1:size(temp,1) | |||
found = 0 | |||
subset_idx = [-1, -1] | |||
for j in range(len(subset)): # 1:size(subset,1): | |||
if subset[j][indexA] == partAs[i] or subset[j][ | |||
indexB] == partBs[i]: | |||
subset_idx[found] = j | |||
found += 1 | |||
if found == 1: | |||
j = subset_idx[0] | |||
if subset[j][indexB] != partBs[i]: | |||
subset[j][indexB] = partBs[i] | |||
subset[j][-1] += 1 | |||
subset[j][-2] += candidate[ | |||
partBs[i].astype(int), | |||
2] + connection_all[k][i][2] | |||
elif found == 2: # if found 2 and disjoint, merge them | |||
j1, j2 = subset_idx | |||
tmp1 = (subset[j1] >= 0).astype(int) | |||
tmp2 = (subset[j2] >= 0).astype(int) | |||
membership = (tmp1 + tmp2)[:-2] | |||
if len(np.nonzero(membership == 2)[0]) == 0: # merge | |||
subset[j1][:-2] += (subset[j2][:-2] + 1) | |||
subset[j1][-2:] += subset[j2][-2:] | |||
subset[j1][-2] += connection_all[k][i][2] | |||
subset = np.delete(subset, j2, 0) | |||
else: # as like found == 1 | |||
subset[j1][indexB] = partBs[i] | |||
subset[j1][-1] += 1 | |||
subset[j1][-2] += candidate[ | |||
partBs[i].astype(int), | |||
2] + connection_all[k][i][2] | |||
# if find no partA in the subset, create a new subset | |||
elif not found and k < 17: | |||
row = -1 * np.ones(20) | |||
row[indexA] = partAs[i] | |||
row[indexB] = partBs[i] | |||
row[-1] = 2 | |||
row[-2] = sum( | |||
candidate[connection_all[k][i, :2].astype(int), | |||
2]) + connection_all[k][i][2] | |||
subset = np.vstack([subset, row]) | |||
# delete some rows of subset which has few parts occur | |||
deleteIdx = [] | |||
for i in range(len(subset)): | |||
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: | |||
deleteIdx.append(i) | |||
subset = np.delete(subset, deleteIdx, axis=0) | |||
# subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts | |||
# candidate: x, y, score, id | |||
count = subset.shape[0] | |||
joints = np.zeros(shape=(count, bodyparts, 3)) | |||
for i in range(count): | |||
for j in range(bodyparts): | |||
joints[i, j, :3] = candidate[int(subset[i, j]), :3] | |||
confidence = 1.0 if subset[i, j] >= 0 else 0.0 | |||
joints[i, j, 2] *= confidence | |||
return joints |
@@ -0,0 +1,141 @@ | |||
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. | |||
from collections import OrderedDict | |||
import torch | |||
import torch.nn as nn | |||
def make_layers(block, no_relu_layers): | |||
layers = [] | |||
for layer_name, v in block.items(): | |||
if 'pool' in layer_name: | |||
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) | |||
layers.append((layer_name, layer)) | |||
else: | |||
conv2d = nn.Conv2d( | |||
in_channels=v[0], | |||
out_channels=v[1], | |||
kernel_size=v[2], | |||
stride=v[3], | |||
padding=v[4]) | |||
layers.append((layer_name, conv2d)) | |||
if layer_name not in no_relu_layers: | |||
layers.append(('relu_' + layer_name, nn.ReLU(inplace=True))) | |||
return nn.Sequential(OrderedDict(layers)) | |||
class BodyposeModel(nn.Module): | |||
def __init__(self): | |||
super(BodyposeModel, self).__init__() | |||
# these layers have no relu layer | |||
no_relu_layers = [ | |||
'conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1', | |||
'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2', | |||
'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1', | |||
'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1' | |||
] | |||
blocks = {} | |||
block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), | |||
('conv1_2', [64, 64, 3, 1, 1]), | |||
('pool1_stage1', [2, 2, 0]), | |||
('conv2_1', [64, 128, 3, 1, 1]), | |||
('conv2_2', [128, 128, 3, 1, 1]), | |||
('pool2_stage1', [2, 2, 0]), | |||
('conv3_1', [128, 256, 3, 1, 1]), | |||
('conv3_2', [256, 256, 3, 1, 1]), | |||
('conv3_3', [256, 256, 3, 1, 1]), | |||
('conv3_4', [256, 256, 3, 1, 1]), | |||
('pool3_stage1', [2, 2, 0]), | |||
('conv4_1', [256, 512, 3, 1, 1]), | |||
('conv4_2', [512, 512, 3, 1, 1]), | |||
('conv4_3_CPM', [512, 256, 3, 1, 1]), | |||
('conv4_4_CPM', [256, 128, 3, 1, 1])]) | |||
# Stage 1 | |||
block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), | |||
('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), | |||
('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), | |||
('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), | |||
('conv5_5_CPM_L1', [512, 38, 1, 1, 0])]) | |||
block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), | |||
('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), | |||
('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), | |||
('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), | |||
('conv5_5_CPM_L2', [512, 19, 1, 1, 0])]) | |||
blocks['block1_1'] = block1_1 | |||
blocks['block1_2'] = block1_2 | |||
self.model0 = make_layers(block0, no_relu_layers) | |||
# Stages 2 - 6 | |||
for i in range(2, 7): | |||
blocks['block%d_1' % i] = OrderedDict([ | |||
('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), | |||
('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), | |||
('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), | |||
('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), | |||
('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), | |||
('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), | |||
('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) | |||
]) | |||
blocks['block%d_2' % i] = OrderedDict([ | |||
('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), | |||
('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), | |||
('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), | |||
('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), | |||
('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), | |||
('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), | |||
('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) | |||
]) | |||
for k in blocks.keys(): | |||
blocks[k] = make_layers(blocks[k], no_relu_layers) | |||
self.model1_1 = blocks['block1_1'] | |||
self.model2_1 = blocks['block2_1'] | |||
self.model3_1 = blocks['block3_1'] | |||
self.model4_1 = blocks['block4_1'] | |||
self.model5_1 = blocks['block5_1'] | |||
self.model6_1 = blocks['block6_1'] | |||
self.model1_2 = blocks['block1_2'] | |||
self.model2_2 = blocks['block2_2'] | |||
self.model3_2 = blocks['block3_2'] | |||
self.model4_2 = blocks['block4_2'] | |||
self.model5_2 = blocks['block5_2'] | |||
self.model6_2 = blocks['block6_2'] | |||
def forward(self, x): | |||
out1 = self.model0(x) | |||
out1_1 = self.model1_1(out1) | |||
out1_2 = self.model1_2(out1) | |||
out2 = torch.cat([out1_1, out1_2, out1], 1) | |||
out2_1 = self.model2_1(out2) | |||
out2_2 = self.model2_2(out2) | |||
out3 = torch.cat([out2_1, out2_2, out1], 1) | |||
out3_1 = self.model3_1(out3) | |||
out3_2 = self.model3_2(out3) | |||
out4 = torch.cat([out3_1, out3_2, out1], 1) | |||
out4_1 = self.model4_1(out4) | |||
out4_2 = self.model4_2(out4) | |||
out5 = torch.cat([out4_1, out4_2, out1], 1) | |||
out5_1 = self.model5_1(out5) | |||
out5_2 = self.model5_2(out5) | |||
out6 = torch.cat([out5_1, out5_2, out1], 1) | |||
out6_1 = self.model6_1(out6) | |||
out6_2 = self.model6_2(out6) | |||
return out6_1, out6_2 |
@@ -0,0 +1,33 @@ | |||
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. | |||
import numpy as np | |||
def pad_rightdown_corner(img, stride, padValue): | |||
h = img.shape[0] | |||
w = img.shape[1] | |||
pad = 4 * [None] | |||
pad[0] = 0 # up | |||
pad[1] = 0 # left | |||
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down | |||
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right | |||
img_padded = img | |||
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) | |||
img_padded = np.concatenate((pad_up, img_padded), axis=0) | |||
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) | |||
img_padded = np.concatenate((pad_left, img_padded), axis=1) | |||
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) | |||
img_padded = np.concatenate((img_padded, pad_down), axis=0) | |||
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) | |||
img_padded = np.concatenate((img_padded, pad_right), axis=1) | |||
return img_padded, pad | |||
def transfer(model, model_weights): | |||
transfered_model_weights = {} | |||
for weights_name in model.state_dict().keys(): | |||
transfered_model_weights[weights_name] = model_weights['.'.join( | |||
weights_name.split('.')[1:])] | |||
return transfered_model_weights |
@@ -0,0 +1,507 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import math | |||
import os | |||
import random | |||
import cv2 | |||
import numba | |||
import numpy as np | |||
import torch | |||
def resize_on_long_side(img, long_side=800): | |||
src_height = img.shape[0] | |||
src_width = img.shape[1] | |||
if src_height > src_width: | |||
scale = long_side * 1.0 / src_height | |||
_img = cv2.resize( | |||
img, (int(src_width * scale), long_side), | |||
interpolation=cv2.INTER_LINEAR) | |||
else: | |||
scale = long_side * 1.0 / src_width | |||
_img = cv2.resize( | |||
img, (long_side, int(src_height * scale)), | |||
interpolation=cv2.INTER_LINEAR) | |||
return _img, scale | |||
def point_in_box(pt, box): | |||
pt_x = pt[0] | |||
pt_y = pt[1] | |||
if pt_x >= box[0] and pt_x <= box[0] + box[2] and pt_y >= box[ | |||
1] and pt_y <= box[1] + box[3]: | |||
return True | |||
else: | |||
return False | |||
def enlarge_box_tblr(roi_bbox, mask, ratio=0.4, use_long_side=True): | |||
if roi_bbox is None or None in roi_bbox: | |||
return [None, None, None, None] | |||
top = roi_bbox[0] | |||
bottom = roi_bbox[1] | |||
left = roi_bbox[2] | |||
right = roi_bbox[3] | |||
roi_width = roi_bbox[3] - roi_bbox[2] | |||
roi_height = roi_bbox[1] - roi_bbox[0] | |||
right = left + roi_width | |||
bottom = top + roi_height | |||
long_side = roi_width if roi_width > roi_height else roi_height | |||
if use_long_side: | |||
new_left = left - int(long_side * ratio) | |||
else: | |||
new_left = left - int(roi_width * ratio) | |||
new_left = 1 if new_left < 0 else new_left | |||
if use_long_side: | |||
new_top = top - int(long_side * ratio) | |||
else: | |||
new_top = top - int(roi_height * ratio) | |||
new_top = 1 if new_top < 0 else new_top | |||
if use_long_side: | |||
new_right = right + int(long_side * ratio) | |||
else: | |||
new_right = right + int(roi_width * ratio) | |||
new_right = mask.shape[1] - 2 if new_right > mask.shape[1] else new_right | |||
if use_long_side: | |||
new_bottom = bottom + int(long_side * ratio) | |||
else: | |||
new_bottom = bottom + int(roi_height * ratio) | |||
new_bottom = mask.shape[0] - 2 if new_bottom > mask.shape[0] else new_bottom | |||
bbox = [new_top, new_bottom, new_left, new_right] | |||
return bbox | |||
def gen_PAF(image, joints): | |||
assert joints.shape[0] == 18 | |||
assert joints.shape[1] == 3 | |||
org_h = image.shape[0] | |||
org_w = image.shape[1] | |||
small_image, resize_scale = resize_on_long_side(image, 120) | |||
joints[:, :2] = joints[:, :2] * resize_scale | |||
joint_left = int(np.min(joints, axis=0)[0]) | |||
joint_right = int(np.max(joints, axis=0)[0]) | |||
joint_top = int(np.min(joints, axis=0)[1]) | |||
joint_bottom = int(np.max(joints, axis=0)[1]) | |||
limb_width = min( | |||
abs(joint_right - joint_left), abs(joint_bottom - joint_top)) // 6 | |||
if limb_width % 2 == 0: | |||
limb_width += 1 | |||
kernel_size = limb_width | |||
part_orders = [(5, 11), (2, 8), (5, 6), (6, 7), (2, 3), (3, 4), (11, 12), | |||
(12, 13), (8, 9), (9, 10)] | |||
map_list = [] | |||
mask_list = [] | |||
PAF_all = np.zeros( | |||
shape=(small_image.shape[0], small_image.shape[1], 2), | |||
dtype=np.float32) | |||
for c, pair in enumerate(part_orders): | |||
idx_a_name = pair[0] | |||
idx_b_name = pair[1] | |||
jointa = joints[idx_a_name] | |||
jointb = joints[idx_b_name] | |||
confidence_threshold = 0.05 | |||
if jointa[2] > confidence_threshold and jointb[ | |||
2] > confidence_threshold: | |||
canvas = np.zeros( | |||
shape=(small_image.shape[0], small_image.shape[1]), | |||
dtype=np.uint8) | |||
canvas = cv2.line(canvas, (int(jointa[0]), int(jointa[1])), | |||
(int(jointb[0]), int(jointb[1])), | |||
(255, 255, 255), 5) | |||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, | |||
(kernel_size, kernel_size)) | |||
canvas = cv2.dilate(canvas, kernel, 1) | |||
canvas = cv2.GaussianBlur(canvas, (kernel_size, kernel_size), 0) | |||
canvas = canvas.astype(np.float32) / 255 | |||
PAF = np.zeros( | |||
shape=(small_image.shape[0], small_image.shape[1], 2), | |||
dtype=np.float32) | |||
PAF[..., 0] = jointb[0] - jointa[0] | |||
PAF[..., 1] = jointb[1] - jointa[1] | |||
mag, ang = cv2.cartToPolar(PAF[..., 0], PAF[..., 1]) | |||
PAF /= (np.dstack((mag, mag)) + 1e-5) | |||
single_PAF = PAF * np.dstack((canvas, canvas)) | |||
map_list.append( | |||
cv2.GaussianBlur(single_PAF, | |||
(kernel_size * 3, kernel_size * 3), 0)) | |||
mask_list.append( | |||
cv2.GaussianBlur(canvas.copy(), | |||
(kernel_size * 3, kernel_size * 3), 0)) | |||
PAF_all = PAF_all * (1.0 - np.dstack( | |||
(canvas, canvas))) + single_PAF | |||
PAF_all = cv2.GaussianBlur(PAF_all, (kernel_size * 3, kernel_size * 3), 0) | |||
PAF_all = cv2.resize( | |||
PAF_all, (org_w, org_h), interpolation=cv2.INTER_LINEAR) | |||
map_list.append(PAF_all) | |||
return PAF_all, map_list, mask_list | |||
def gen_skeleton_map(joints, stack_mode='column', input_roi_box=None): | |||
if type(joints) == list: | |||
joints = np.array(joints) | |||
assert stack_mode == 'column' or stack_mode == 'depth' | |||
part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), | |||
(3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] | |||
def link(img, a, b, color, line_width, scale=1.0, x_offset=0, y_offset=0): | |||
jointa = joints[a] | |||
jointb = joints[b] | |||
temp1 = int((jointa[0] - x_offset) * scale) | |||
temp2 = int((jointa[1] - y_offset) * scale) | |||
temp3 = int((jointb[0] - x_offset) * scale) | |||
temp4 = int((jointb[1] - y_offset) * scale) | |||
cv2.line(img, (temp1, temp2), (temp3, temp4), color, line_width) | |||
roi_box = input_roi_box | |||
roi_box_width = roi_box[3] - roi_box[2] | |||
roi_box_height = roi_box[1] - roi_box[0] | |||
short_side_length = min(roi_box_width, roi_box_height) | |||
line_width = short_side_length // 30 | |||
line_width = max(line_width, 2) | |||
map_cube = np.zeros( | |||
shape=(roi_box_height, roi_box_width, len(part_orders) + 1), | |||
dtype=np.float32) | |||
use_line_width = min(5, line_width) | |||
fx = use_line_width * 1.0 / line_width # fx 最大值为1 | |||
if fx < 0.99: | |||
map_cube = cv2.resize(map_cube, (0, 0), fx=fx, fy=fx) | |||
for c, pair in enumerate(part_orders): | |||
tmp = map_cube[..., c].copy() | |||
link( | |||
tmp, | |||
pair[0], | |||
pair[1], (2.0, 2.0, 2.0), | |||
use_line_width, | |||
scale=fx, | |||
x_offset=roi_box[2], | |||
y_offset=roi_box[0]) | |||
map_cube[..., c] = tmp | |||
tmp = map_cube[..., -1].copy() | |||
link( | |||
tmp, | |||
pair[0], | |||
pair[1], (2.0, 2.0, 2.0), | |||
use_line_width, | |||
scale=fx, | |||
x_offset=roi_box[2], | |||
y_offset=roi_box[0]) | |||
map_cube[..., -1] = tmp | |||
map_cube = cv2.resize(map_cube, (roi_box_width, roi_box_height)) | |||
if stack_mode == 'depth': | |||
return map_cube, roi_box | |||
elif stack_mode == 'column': | |||
joint_maps = [] | |||
for c in range(len(part_orders) + 1): | |||
joint_maps.append(map_cube[..., c]) | |||
joint_map = np.column_stack(joint_maps) | |||
return joint_map, roi_box | |||
def plot_one_box(x, img, color=None, label=None, line_thickness=None): | |||
tl = line_thickness or round( | |||
0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness | |||
color = color or [random.randint(0, 255) for _ in range(3)] | |||
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) | |||
cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) | |||
if label: | |||
tf = max(tl - 1, 1) # font thickness | |||
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] | |||
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 | |||
cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled | |||
cv2.putText( | |||
img, | |||
label, (c1[0], c1[1] - 2), | |||
0, | |||
tl / 3, [225, 255, 255], | |||
thickness=tf, | |||
lineType=cv2.LINE_AA) | |||
def draw_line(im, points, color, stroke_size=2, closed=False): | |||
points = points.astype(np.int32) | |||
for i in range(len(points) - 1): | |||
cv2.line(im, tuple(points[i]), tuple(points[i + 1]), color, | |||
stroke_size) | |||
if closed: | |||
cv2.line(im, tuple(points[0]), tuple(points[-1]), color, stroke_size) | |||
def enlarged_bbox(bbox, img_width, img_height, enlarge_ratio=0.2): | |||
left = bbox[0] | |||
top = bbox[1] | |||
right = bbox[2] | |||
bottom = bbox[3] | |||
roi_width = right - left | |||
roi_height = bottom - top | |||
new_left = left - int(roi_width * enlarge_ratio) | |||
new_left = 0 if new_left < 0 else new_left | |||
new_top = top - int(roi_height * enlarge_ratio) | |||
new_top = 0 if new_top < 0 else new_top | |||
new_right = right + int(roi_width * enlarge_ratio) | |||
new_right = img_width if new_right > img_width else new_right | |||
new_bottom = bottom + int(roi_height * enlarge_ratio) | |||
new_bottom = img_height if new_bottom > img_height else new_bottom | |||
bbox = [new_left, new_top, new_right, new_bottom] | |||
bbox = [int(x) for x in bbox] | |||
return bbox | |||
def get_map_fusion_map_cuda(map_list, threshold=1, device=torch.device('cpu')): | |||
map_list_cuda = [torch.from_numpy(x).to(device) for x in map_list] | |||
map_concat = torch.stack(tuple(map_list_cuda), dim=-1) | |||
map_concat = torch.abs(map_concat) | |||
map_concat[map_concat < threshold] = 0 | |||
map_concat[map_concat > 1e-5] = 1.0 | |||
sum_map = torch.sum(map_concat, dim=2) | |||
a = torch.ones_like(sum_map) | |||
acc_map = torch.where(sum_map > 0, a * 2.0, torch.zeros_like(sum_map)) | |||
fusion_map = torch.where(sum_map < 0.5, a * 1.5, sum_map) | |||
fusion_map = fusion_map.float() | |||
acc_map = acc_map.float() | |||
fusion_map = fusion_map.cpu().numpy().astype(np.float32) | |||
acc_map = acc_map.cpu().numpy().astype(np.float32) | |||
return fusion_map, acc_map | |||
def gen_border_shade(height, width, height_band, width_band): | |||
height_ratio = height_band * 1.0 / height | |||
width_ratio = width_band * 1.0 / width | |||
_height_band = int(256 * height_ratio) | |||
_width_band = int(256 * width_ratio) | |||
canvas = np.zeros((256, 256), dtype=np.float32) | |||
canvas[_height_band // 2:-_height_band // 2, | |||
_width_band // 2:-_width_band // 2] = 1.0 | |||
canvas = cv2.blur(canvas, (_height_band, _width_band)) | |||
canvas = cv2.resize(canvas, (width, height)) | |||
return canvas | |||
def get_mask_bbox(mask, threshold=127): | |||
ret, mask = cv2.threshold(mask, threshold, 1, 0) | |||
if cv2.countNonZero(mask) == 0: | |||
return [None, None, None, None] | |||
col_acc = np.sum(mask, 0) | |||
row_acc = np.sum(mask, 1) | |||
col_acc = col_acc.tolist() | |||
row_acc = row_acc.tolist() | |||
for x in range(len(col_acc)): | |||
if col_acc[x] > 0: | |||
left = x | |||
break | |||
for x in range(1, len(col_acc)): | |||
if col_acc[-x] > 0: | |||
right = len(col_acc) - x | |||
break | |||
for x in range(len(row_acc)): | |||
if row_acc[x] > 0: | |||
top = x | |||
break | |||
for x in range(1, len(row_acc)): | |||
if row_acc[-x] > 0: | |||
bottom = len(row_acc[::-1]) - x | |||
break | |||
return [top, bottom, left, right] | |||
def visualize_flow(flow): | |||
h, w = flow.shape[:2] | |||
hsv = np.zeros((h, w, 3), np.uint8) | |||
mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) | |||
hsv[..., 0] = ang * 180 / np.pi / 2 | |||
hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) | |||
hsv[..., 2] = 255 | |||
bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) | |||
bgr = bgr * 1.0 / 255 | |||
return bgr.astype(np.float32) | |||
def vis_joints(image, joints, color, show_text=True, confidence_threshold=0.1): | |||
part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), | |||
(3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] | |||
abandon_idxs = [0, 1, 14, 15, 16, 17] | |||
# draw joints | |||
for i, joint in enumerate(joints): | |||
if i in abandon_idxs: | |||
continue | |||
if joint[-1] > confidence_threshold: | |||
cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2) | |||
if show_text: | |||
cv2.putText(image, | |||
str(i) + '[{:.2f}]'.format(joint[-1]), | |||
(int(joint[0]), int(joint[1])), | |||
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) | |||
# draw link | |||
for pair in part_orders: | |||
if joints[pair[0]][-1] > confidence_threshold and joints[ | |||
pair[1]][-1] > confidence_threshold: | |||
cv2.line(image, (int(joints[pair[0]][0]), int(joints[pair[0]][1])), | |||
(int(joints[pair[1]][0]), int(joints[pair[1]][1])), color, | |||
2) | |||
return image | |||
def get_heatmap_cv(img, magn, max_flow_mag): | |||
min_flow_mag = .5 | |||
cv_magn = np.clip( | |||
255 * (magn - min_flow_mag) / (max_flow_mag - min_flow_mag + 1e-7), | |||
a_min=0, | |||
a_max=255).astype(np.uint8) | |||
if img.dtype != np.uint8: | |||
img = (255 * img).astype(np.uint8) | |||
heatmap_img = cv2.applyColorMap(cv_magn, cv2.COLORMAP_JET) | |||
heatmap_img = heatmap_img[..., ::-1] | |||
h, w = magn.shape | |||
img_alpha = np.ones((h, w), dtype=np.double)[:, :, None] | |||
heatmap_alpha = np.clip( | |||
magn / (max_flow_mag + 1e-7), a_min=1e-7, a_max=1)[:, :, None]**.7 | |||
heatmap_alpha[heatmap_alpha < .2]**.5 | |||
pm_hm = heatmap_img * heatmap_alpha | |||
pm_img = img * img_alpha | |||
cv_out = pm_hm + pm_img * (1 - heatmap_alpha) | |||
cv_out = np.clip(cv_out, a_min=0, a_max=255).astype(np.uint8) | |||
return cv_out | |||
def save_heatmap_cv(img, flow, supression=2): | |||
flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2) | |||
flow_magn -= supression | |||
flow_magn[flow_magn <= 0] = 0 | |||
cv_out = get_heatmap_cv(img, flow_magn, np.max(flow_magn) * 1.3) | |||
return cv_out | |||
@numba.jit(nopython=True, parallel=False) | |||
def bilinear_interp(x, y, v11, v12, v21, v22): | |||
temp1 = (v11 * (1 - y) + v12 * y) * (1 - x) | |||
temp2 = (v21 * (1 - y) + v22 * y) * x | |||
result = temp1 + temp2 | |||
return result | |||
@numba.jit(nopython=True, parallel=False) | |||
def image_warp_grid1(rDx, rDy, oriImg, transRatio, width_expand, | |||
height_expand): | |||
srcW = oriImg.shape[1] | |||
srcH = oriImg.shape[0] | |||
newImg = oriImg.copy() | |||
for i in range(srcH): | |||
for j in range(srcW): | |||
_i = i | |||
_j = j | |||
deltaX = rDx[_i, _j] | |||
deltaY = rDy[_i, _j] | |||
nx = _j + deltaX * transRatio | |||
ny = _i + deltaY * transRatio | |||
if nx >= srcW - width_expand - 1: | |||
if nx > srcW - 1: | |||
nx = srcW - 1 | |||
if ny >= srcH - height_expand - 1: | |||
if ny > srcH - 1: | |||
ny = srcH - 1 | |||
if nx < width_expand: | |||
if nx < 0: | |||
nx = 0 | |||
if ny < height_expand: | |||
if ny < 0: | |||
ny = 0 | |||
nxi = int(math.floor(nx)) | |||
nyi = int(math.floor(ny)) | |||
nxi1 = int(math.ceil(nx)) | |||
nyi1 = int(math.ceil(ny)) | |||
for ll in range(3): | |||
newImg[_i, _j, | |||
ll] = bilinear_interp(ny - nyi, nx - nxi, | |||
oriImg[nyi, nxi, | |||
ll], oriImg[nyi, nxi1, ll], | |||
oriImg[nyi1, nxi, | |||
ll], oriImg[nyi1, nxi1, | |||
ll]) | |||
return newImg |
@@ -184,6 +184,7 @@ TASK_OUTPUTS = { | |||
Tasks.image_to_image_translation: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_style_transfer: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_portrait_stylization: [OutputKeys.OUTPUT_IMG], | |||
Tasks.image_body_reshaping: [OutputKeys.OUTPUT_IMG], | |||
# live category recognition result for single video | |||
# { | |||
@@ -75,6 +75,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
'damo/nlp_bart_text-error-correction_chinese'), | |||
Tasks.image_captioning: (Pipelines.image_captioning, | |||
'damo/ofa_image-caption_coco_large_en'), | |||
Tasks.image_body_reshaping: (Pipelines.image_body_reshaping, | |||
'damo/cv_flow-based-body-reshaping_damo'), | |||
Tasks.image_portrait_stylization: | |||
(Pipelines.person_image_cartoon, | |||
'damo/cv_unet_person-image-cartoon_compound-models'), | |||
@@ -0,0 +1,40 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import LoadImage | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.image_body_reshaping, module_name=Pipelines.image_body_reshaping) | |||
class ImageBodyReshapingPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
use `model` to create a image body reshaping pipeline for prediction | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
logger.info('body reshaping model init done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
img = LoadImage.convert_to_ndarray(input) | |||
result = {'img': img} | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
output = self.model.inference(input['img']) | |||
result = {'outputs': output} | |||
return result | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
output_img = inputs['outputs'] | |||
return {OutputKeys.OUTPUT_IMG: output_img} |
@@ -60,7 +60,7 @@ class CVTasks(object): | |||
image_to_image_generation = 'image-to-image-generation' | |||
image_style_transfer = 'image-style-transfer' | |||
image_portrait_stylization = 'image-portrait-stylization' | |||
image_body_reshaping = 'image-body-reshaping' | |||
image_embedding = 'image-embedding' | |||
product_retrieval_embedding = 'product-retrieval-embedding' | |||
@@ -13,6 +13,7 @@ ml_collections | |||
mmcls>=0.21.0 | |||
mmdet>=2.25.0 | |||
networkx>=2.5 | |||
numba | |||
onnxruntime>=1.10 | |||
pai-easycv>=0.6.3.6 | |||
pandas | |||
@@ -0,0 +1,58 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os.path as osp | |||
import unittest | |||
import cv2 | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines import pipeline | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
from modelscope.utils.test_utils import test_level | |||
class ImageBodyReshapingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def setUp(self) -> None: | |||
self.task = Tasks.image_body_reshaping | |||
self.model_id = 'damo/cv_flow-based-body-reshaping_damo' | |||
self.test_image = 'data/test/images/image_body_reshaping.jpg' | |||
def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||
result = pipeline(input_location) | |||
if result is not None: | |||
cv2.imwrite('result_bodyreshaping.png', | |||
result[OutputKeys.OUTPUT_IMG]) | |||
print( | |||
f'Output written to {osp.abspath("result_body_reshaping.png")}' | |||
) | |||
else: | |||
raise Exception('Testing failed: invalid output') | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_by_direct_model_download(self): | |||
model_dir = snapshot_download(self.model_id) | |||
image_body_reshaping = pipeline( | |||
Tasks.image_body_reshaping, model=model_dir) | |||
self.pipeline_inference(image_body_reshaping, self.test_image) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
image_body_reshaping = pipeline( | |||
Tasks.image_body_reshaping, model=self.model_id) | |||
self.pipeline_inference(image_body_reshaping, self.test_image) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_modelhub_default_model(self): | |||
image_body_reshaping = pipeline(Tasks.image_body_reshaping) | |||
self.pipeline_inference(image_body_reshaping, self.test_image) | |||
@unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
def test_demo_compatibility(self): | |||
self.compatibility_check() | |||
if __name__ == '__main__': | |||
unittest.main() |