
[to #42322933]add image_body_reshaping code

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10217723

    * add image_body_reshaping code
master
ryan.yy yingda.chen 3 years ago
parent commit ff69439c4f
17 changed files with 1737 additions and 1 deletion
  1. +3 -0 data/test/images/image_body_reshaping.jpg
  2. +2 -0 modelscope/metainfo.py
  3. +20 -0 modelscope/models/cv/image_body_reshaping/__init__.py
  4. +128 -0 modelscope/models/cv/image_body_reshaping/image_body_reshaping.py
  5. +189 -0 modelscope/models/cv/image_body_reshaping/model.py
  6. +339 -0 modelscope/models/cv/image_body_reshaping/person_info.py
  7. +0 -0 modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py
  8. +272 -0 modelscope/models/cv/image_body_reshaping/pose_estimator/body.py
  9. +141 -0 modelscope/models/cv/image_body_reshaping/pose_estimator/model.py
  10. +33 -0 modelscope/models/cv/image_body_reshaping/pose_estimator/util.py
  11. +507 -0 modelscope/models/cv/image_body_reshaping/slim_utils.py
  12. +1 -0 modelscope/outputs.py
  13. +2 -0 modelscope/pipelines/builder.py
  14. +40 -0 modelscope/pipelines/cv/image_body_reshaping_pipeline.py
  15. +1 -1 modelscope/utils/constant.py
  16. +1 -0 requirements/cv.txt
  17. +58 -0 tests/pipelines/test_image_body_reshaping.py

+3 -0 data/test/images/image_body_reshaping.jpg

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d
size 1127557

+2 -0 modelscope/metainfo.py

@@ -43,6 +43,7 @@ class Models(object):
face_human_hand_detection = 'face-human-hand-detection'
face_emotion = 'face-emotion'
product_segmentation = 'product-segmentation'
image_body_reshaping = 'image-body-reshaping'


# EasyCV models
yolox = 'YOLOX'
@@ -187,6 +188,7 @@ class Pipelines(object):
face_human_hand_detection = 'face-human-hand-detection'
face_emotion = 'face-emotion'
product_segmentation = 'product-segmentation'
image_body_reshaping = 'flow-based-body-reshaping'


# nlp tasks
automatic_post_editing = 'automatic-post-editing'


+20 -0 modelscope/models/cv/image_body_reshaping/__init__.py

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .image_body_reshaping import ImageBodyReshaping

else:
_import_structure = {'image_body_reshaping': ['ImageBodyReshaping']}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
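For reference, a minimal sketch of what the lazy registration buys (assumption: a working modelscope install; the class and its torch/cv2 dependencies are only imported on first access):

# The module installed above is a LazyImportModule, so this import
# resolves the image_body_reshaping submodule on demand.
from modelscope.models.cv.image_body_reshaping import ImageBodyReshaping
print(ImageBodyReshaping)  # the real class, loaded lazily at this point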

+128 -0 modelscope/models/cv/image_body_reshaping/image_body_reshaping.py

@@ -0,0 +1,128 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .model import FlowGenerator
from .person_info import PersonInfo
from .pose_estimator.body import Body
from .slim_utils import image_warp_grid1, resize_on_long_side

logger = get_logger()

__all__ = ['ImageBodyReshaping']


@MODELS.register_module(
Tasks.image_body_reshaping, module_name=Models.image_body_reshaping)
class ImageBodyReshaping(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the image body reshaping model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)

if torch.cuda.is_available():
self.device = torch.device('cuda')
else:
self.device = torch.device('cpu')

self.degree = 1.0
self.reshape_model = FlowGenerator(n_channels=16).to(self.device)
model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
checkpoints = torch.load(model_path, map_location=torch.device('cpu'))
self.reshape_model.load_state_dict(
checkpoints['state_dict'], strict=True)
self.reshape_model.eval()
logger.info('load body reshaping model done')

pose_model_ckpt = os.path.join(model_dir, 'body_pose_model.pth')
self.pose_esti = Body(pose_model_ckpt, self.device)
logger.info('load pose model done')

def pred_joints(self, img):
if img is None:
return None
small_src, resize_scale = resize_on_long_side(img, 300)
body_joints = self.pose_esti(small_src)

if body_joints.shape[0] >= 1:
body_joints[:, :, :2] = body_joints[:, :, :2] / resize_scale

return body_joints

def pred_flow(self, img):

body_joints = self.pred_joints(img)
small_size = 1200

if img.shape[0] > small_size or img.shape[1] > small_size:
_img, _scale = resize_on_long_side(img, small_size)
body_joints[:, :, :2] = body_joints[:, :, :2] * _scale
else:
_img = img

# We only reshape one person
if body_joints.shape[0] != 1:
return None

person = PersonInfo(body_joints[0])

with torch.no_grad():
person_pred = person.pred_flow(_img, self.reshape_model,
self.device)

flow = np.dstack((person_pred['rDx'], person_pred['rDy']))

scale = img.shape[0] * 1.0 / flow.shape[0]

flow = cv2.resize(flow, (img.shape[1], img.shape[0]))
flow *= scale

return flow

def warp(self, src_img, flow):

X_flow = flow[..., 0]
Y_flow = flow[..., 1]

X_flow = np.ascontiguousarray(X_flow)
Y_flow = np.ascontiguousarray(Y_flow)

pred = image_warp_grid1(X_flow, Y_flow, src_img, 1.0, 0, 0)
return pred

def inference(self, img):
img = img.cpu().numpy()
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
flow = self.pred_flow(img)

if flow is None:
return img

assert flow.shape[:2] == img.shape[:2]

mag, ang = cv2.cartToPolar(flow[..., 0] + 1e-8, flow[..., 1] + 1e-8)
mag -= 3
mag[mag <= 0] = 0

x, y = cv2.polarToCart(mag, ang, angleInDegrees=False)
flow = np.dstack((x, y))

flow *= self.degree
pred = self.warp(img, flow)
out_img = np.clip(pred, 0, 255)
logger.info('model inference done')

return out_img.astype(np.uint8)
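A minimal direct-invocation sketch (assumptions: `model_dir` already holds the downloaded checkpoints, and `inference` receives a torch tensor wrapping a BGR image, as implied by the `img.cpu().numpy()` call and BGR-to-RGB conversion above):

import cv2
import torch
from modelscope.models.cv.image_body_reshaping import ImageBodyReshaping

model = ImageBodyReshaping('/path/to/model_dir')  # illustrative path
bgr = cv2.imread('data/test/images/image_body_reshaping.jpg')
out = model.inference(torch.from_numpy(bgr))  # uint8 array, RGB per the conversion above
cv2.imwrite('reshaped.png', cv2.cvtColor(out, cv2.COLOR_RGB2BGR))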

+189 -0 modelscope/models/cv/image_body_reshaping/model.py

@@ -0,0 +1,189 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLayer(nn.Module):

def __init__(self, in_ch, out_ch):
super(ConvLayer, self).__init__()

self.conv = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0),
nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

def forward(self, x):
x = self.conv(x)
return x


class SASA(nn.Module):

def __init__(self, in_dim):
super(SASA, self).__init__()
self.chanel_in = in_dim

self.query_conv = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.key_conv = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.value_conv = nn.Conv2d(
in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.mag_conv = nn.Conv2d(
in_channels=5, out_channels=in_dim // 32, kernel_size=1)

self.gamma = nn.Parameter(torch.zeros(1))

self.softmax = nn.Softmax(dim=-1)
self.sigmoid = nn.Sigmoid()

def structure_encoder(self, paf_mag, target_height, target_width):
torso_mask = torch.sum(paf_mag[:, 1:3, :, :], dim=1, keepdim=True)
torso_mask = torch.clamp(torso_mask, 0, 1)

arms_mask = torch.sum(paf_mag[:, 4:8, :, :], dim=1, keepdim=True)
arms_mask = torch.clamp(arms_mask, 0, 1)

legs_mask = torch.sum(paf_mag[:, 8:12, :, :], dim=1, keepdim=True)
legs_mask = torch.clamp(legs_mask, 0, 1)

fg_mask = paf_mag[:, 12, :, :].unsqueeze(1)
bg_mask = 1 - fg_mask
Y = torch.cat((arms_mask, torso_mask, legs_mask, fg_mask, bg_mask),
dim=1)
Y = F.interpolate(Y, size=(target_height, target_width), mode='area')
return Y

def forward(self, X, PAF_mag):
"""extract self-attention features.
Args:
X : input feature maps (B x C x H x W)
PAF_mag : PAF magnitude maps (B x C x H x W); 1 denotes connectivity, 0 denotes non-connectivity

Returns:
out : self-attention value + input feature
Y : structure encoding maps (B x 5 x H x W)
"""

m_batchsize, C, height, width = X.size()

Y = self.structure_encoder(PAF_mag, height, width)

connectivity_mask_vec = self.mag_conv(Y).view(m_batchsize, -1,
width * height)
affinity = torch.bmm(
connectivity_mask_vec.permute(0, 2, 1), connectivity_mask_vec)
affinity_centered = affinity - torch.mean(affinity)
affinity_sigmoid = self.sigmoid(affinity_centered)

proj_query = self.query_conv(X).view(m_batchsize, -1,
width * height).permute(0, 2, 1)
proj_key = self.key_conv(X).view(m_batchsize, -1, width * height)
selfatten_map = torch.bmm(proj_query, proj_key)
selfatten_centered = selfatten_map - torch.mean(
selfatten_map) # centering
selfatten_sigmoid = self.sigmoid(selfatten_centered)

SASA_map = selfatten_sigmoid * affinity_sigmoid

proj_value = self.value_conv(X).view(m_batchsize, -1, width * height)

out = torch.bmm(proj_value, SASA_map.permute(0, 2, 1))
out = out.view(m_batchsize, C, height, width)

out = self.gamma * out + X
return out, Y


class FlowGenerator(nn.Module):

def __init__(self, n_channels, deep_supervision=False):
super(FlowGenerator, self).__init__()
self.deep_supervision = deep_supervision

self.Encoder = nn.Sequential(
ConvLayer(n_channels, 64),
ConvLayer(64, 64),
nn.MaxPool2d(2),
ConvLayer(64, 128),
ConvLayer(128, 128),
nn.MaxPool2d(2),
ConvLayer(128, 256),
ConvLayer(256, 256),
nn.MaxPool2d(2),
ConvLayer(256, 512),
ConvLayer(512, 512),
nn.MaxPool2d(2),
ConvLayer(512, 1024),
ConvLayer(1024, 1024),
ConvLayer(1024, 1024),
ConvLayer(1024, 1024),
ConvLayer(1024, 1024),
)

self.SASA = SASA(in_dim=1024)

self.Decoder = nn.Sequential(
ConvLayer(1024, 1024),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
ConvLayer(1024, 512),
ConvLayer(512, 512),
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
ConvLayer(512, 256),
ConvLayer(256, 256),
ConvLayer(256, 128),
ConvLayer(128, 64),
ConvLayer(64, 32),
nn.Conv2d(32, 2, kernel_size=1, padding=0),
nn.Tanh(),
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
)

dilation_ksize = 17
self.dilation = torch.nn.MaxPool2d(
kernel_size=dilation_ksize,
stride=1,
padding=int((dilation_ksize - 1) / 2))

def warp(self, x, flow, mode='bilinear', padding_mode='zeros', coff=0.2):
n, c, h, w = x.size()
yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)])
xv = xv.float() / (w - 1) * 2.0 - 1
yv = yv.float() / (h - 1) * 2.0 - 1
grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), -1).unsqueeze(0)
grid = grid.to(flow.device)
grid_x = grid + 2 * flow * coff
warp_x = F.grid_sample(x, grid_x, mode=mode, padding_mode=padding_mode)
return warp_x

def forward(self, img, skeleton_map, coef=0.2):
"""extract self-attention features.
Args:
img : input numpy image
skeleton_map : skeleton map of input image
coef: warp degree

Returns:
warp_x : warped image
flow: predicted flow
"""

img_concat = torch.cat((img, skeleton_map), dim=1)
X = self.Encoder(img_concat)

_, _, height, width = X.size()

# directly get PAF magnitude from skeleton maps via dilation
PAF_mag = self.dilation((skeleton_map + 1.0) * 0.5)

out, Y = self.SASA(X, PAF_mag)
flow = self.Decoder(out)

flow = flow.permute(0, 2, 3, 1) # [n, 2, h, w] ==> [n, h, w, 2]

warp_x = self.warp(img, flow, coff=coef)
warp_x = torch.clamp(warp_x, min=-1.0, max=1.0)

return warp_x, flow
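A shape sanity check for the generator (assumptions: the 16 input channels are 3 image channels plus the 13-channel skeleton map produced by gen_skeleton_map, and 256x256 matches PersonInfo's network input size):

import torch
from modelscope.models.cv.image_body_reshaping.model import FlowGenerator

net = FlowGenerator(n_channels=16).eval()
img = torch.randn(1, 3, 256, 256)    # normalized image crop in [-1, 1]
skel = torch.randn(1, 13, 256, 256)  # skeleton map channels
with torch.no_grad():
    warp_x, flow = net(img, skel)
print(warp_x.shape, flow.shape)  # (1, 3, 256, 256) and (1, 256, 256, 2)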

+339 -0 modelscope/models/cv/image_body_reshaping/person_info.py

@@ -0,0 +1,339 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy

import cv2
import numpy as np
import torch

from .slim_utils import (enlarge_box_tblr, gen_skeleton_map,
get_map_fusion_map_cuda, get_mask_bbox,
resize_on_long_side)


class PersonInfo(object):

def __init__(self, joints):
self.joints = joints
self.flow = None
self.pad_boder = False
self.height_expand = 0
self.width_expand = 0
self.coeff = 0.2
self.network_input_W = 256
self.network_input_H = 256
self.divider = 20
self.flow_scales = ['upper_2']

def update_attribute(self, pad_boder, height_expand, width_expand):
self.pad_boder = pad_boder
self.height_expand = height_expand
self.width_expand = width_expand
if pad_boder:
self.joints[:, 0] += width_expand
self.joints[:, 1] += height_expand

def pred_flow(self, img, flow_net, device):
with torch.no_grad():
if img is None:
print('image is none')
self.flow = None
return None

if len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

if self.pad_boder:
height_expand = self.height_expand
width_expand = self.width_expand
pad_img = cv2.copyMakeBorder(
img,
height_expand,
height_expand,
width_expand,
width_expand,
cv2.BORDER_CONSTANT,
value=(127, 127, 127))

else:
height_expand = 0
width_expand = 0
pad_img = img.copy()

canvas = np.zeros(
shape=(pad_img.shape[0], pad_img.shape[1]), dtype=np.float32)

self.human_joint_box = self.__joint_to_body_box()

self.human_box = enlarge_box_tblr(
self.human_joint_box, pad_img, ratio=0.25)
human_box_height = self.human_box[1] - self.human_box[0]
human_box_width = self.human_box[3] - self.human_box[2]

self.leg_joint_box = self.__joint_to_leg_box()
self.leg_box = enlarge_box_tblr(
self.leg_joint_box, pad_img, ratio=0.25)

self.arm_joint_box = self.__joint_to_arm_box()
self.arm_box = enlarge_box_tblr(
self.arm_joint_box, pad_img, ratio=0.1)

x_flows = []
y_flows = []
multi_bbox = []

for scale in self.flow_scales: # better for metric
scale_value = float(scale.split('_')[-1])

arm_box = copy.deepcopy(self.arm_box)

if arm_box[0] is None:
arm_box = self.human_box

arm_box_height = arm_box[1] - arm_box[0]
arm_box_width = arm_box[3] - arm_box[2]

roi_bbox = None

if arm_box_width < human_box_width * 0.1 or arm_box_height < human_box_height * 0.1:
roi_bbox = self.human_box
else:
arm_box = enlarge_box_tblr(
arm_box, pad_img, ratio=scale_value)
if scale == 'upper_0.2':
arm_box[0] = min(arm_box[0], int(self.joints[0][1]))
if scale.startswith('upper'):
roi_bbox = [
max(self.human_box[0], arm_box[0]),
min(self.human_box[1], arm_box[1]),
max(self.human_box[2], arm_box[2]),
min(self.human_box[3], arm_box[3])
]
if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[
3] - roi_bbox[2] < 1:
continue

elif scale.startswith('lower'):
roi_bbox = [
max(self.human_box[0], self.leg_box[0]),
min(self.human_box[1], self.leg_box[1]),
max(self.human_box[2], self.leg_box[2]),
min(self.human_box[3], self.leg_box[3])
]

if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[
3] - roi_bbox[2] < 1:
continue

skel_map, roi_bbox = gen_skeleton_map(
self.joints, 'depth', input_roi_box=roi_bbox)

if roi_bbox is None:
continue

if skel_map.dtype != np.float32:
skel_map = skel_map.astype(np.float32)

skel_map -= 1.0 # [0,2] ->[-1,1]

multi_bbox.append(roi_bbox)

roi_bbox_height = roi_bbox[1] - roi_bbox[0]
roi_bbox_width = roi_bbox[3] - roi_bbox[2]

assert skel_map.shape[0] == roi_bbox_height
assert skel_map.shape[1] == roi_bbox_width
roi_height_pad = roi_bbox_height // self.divider
roi_width_pad = roi_bbox_width // self.divider
paded_roi_h = roi_bbox_height + 2 * roi_height_pad
paded_roi_w = roi_bbox_width + 2 * roi_width_pad

roi_height_pad_joint = skel_map.shape[0] // self.divider
roi_width_pad_joint = skel_map.shape[1] // self.divider
skel_map = np.pad(
skel_map,
((roi_height_pad_joint, roi_height_pad_joint),
(roi_width_pad_joint, roi_width_pad_joint), (0, 0)),
'constant',
constant_values=-1)

skel_map_resized = cv2.resize(
skel_map, (self.network_input_W, self.network_input_H))

skel_map_resized[skel_map_resized < 0] = -1.0
skel_map_resized[skel_map_resized > -0.5] = 1.0
skel_map_transformed = torch.from_numpy(
skel_map_resized.transpose((2, 0, 1)))

roi_npy = pad_img[roi_bbox[0]:roi_bbox[1],
roi_bbox[2]:roi_bbox[3], :].copy()
if roi_npy.dtype != np.float32:
roi_npy = roi_npy.astype(np.float32)

roi_npy = np.pad(roi_npy,
((roi_height_pad, roi_height_pad),
(roi_width_pad, roi_width_pad), (0, 0)),
'edge')

roi_npy = roi_npy[:, :, ::-1]

roi_npy = cv2.resize(
roi_npy, (self.network_input_W, self.network_input_H))

roi_npy *= 1.0 / 255
roi_npy -= 0.5
roi_npy *= 2

rgb_tensor = torch.from_numpy(roi_npy.transpose((2, 0, 1)))

rgb_tensor = rgb_tensor.unsqueeze(0).to(device)
skel_map_tensor = skel_map_transformed.unsqueeze(0).to(device)
warped_img_val, flow_field_val = flow_net(
rgb_tensor, skel_map_tensor
) # inference, connectivity_mask [1,12,16,16]
flow_field_val = flow_field_val.detach().squeeze().cpu().numpy(
)

flow_field_val = cv2.resize(
flow_field_val, (paded_roi_w, paded_roi_h),
interpolation=cv2.INTER_LINEAR)
flow_field_val[..., 0] = flow_field_val[
..., 0] * paded_roi_w * 0.5 * 2 * self.coeff
flow_field_val[..., 1] = flow_field_val[
..., 1] * paded_roi_h * 0.5 * 2 * self.coeff

# remove pad areas
flow_field_val = flow_field_val[
roi_height_pad:flow_field_val.shape[0] - roi_height_pad,
roi_width_pad:flow_field_val.shape[1] - roi_width_pad, :]

diffuse_width = max(roi_bbox_width // 3, 1)
diffuse_height = max(roi_bbox_height // 3, 1)
assert roi_bbox_width == flow_field_val.shape[1]
assert roi_bbox_height == flow_field_val.shape[0]

origin_flow = np.zeros(
(pad_img.shape[0] + 2 * diffuse_height,
pad_img.shape[1] + 2 * diffuse_width, 2),
dtype=np.float32)

flow_field_val = np.pad(flow_field_val,
((diffuse_height, diffuse_height),
(diffuse_width, diffuse_width),
(0, 0)), 'linear_ramp')

origin_flow[roi_bbox[0]:roi_bbox[1] + 2 * diffuse_height,
roi_bbox[2]:roi_bbox[3]
+ 2 * diffuse_width] = flow_field_val

origin_flow = origin_flow[diffuse_height:-diffuse_height,
diffuse_width:-diffuse_width, :]

x_flows.append(origin_flow[..., 0])
y_flows.append(origin_flow[..., 1])

if len(x_flows) == 0:
return {
'rDx': np.zeros(canvas.shape[:2], dtype=np.float32),
'rDy': np.zeros(canvas.shape[:2], dtype=np.float32),
'multi_bbox': multi_bbox,
'x_fusion_map':
np.ones(canvas.shape[:2], dtype=np.float32),
'y_fusion_map':
np.ones(canvas.shape[:2], dtype=np.float32)
}
else:
origin_rDx, origin_rDy, x_fusion_map, y_fusion_map = self.blend_multiscale_flow(
x_flows, y_flows, device=device)

return {
'rDx': origin_rDx,
'rDy': origin_rDy,
'multi_bbox': multi_bbox,
'x_fusion_map': x_fusion_map,
'y_fusion_map': y_fusion_map
}

@staticmethod
def blend_multiscale_flow(x_flows, y_flows, device=None):
scale_num = len(x_flows)
if scale_num == 1:
return x_flows[0], y_flows[0], np.ones_like(
x_flows[0]), np.ones_like(x_flows[0])

origin_rDx = np.zeros((x_flows[0].shape[0], x_flows[0].shape[1]),
dtype=np.float32)
origin_rDy = np.zeros((y_flows[0].shape[0], y_flows[0].shape[1]),
dtype=np.float32)

x_fusion_map, x_acc_map = get_map_fusion_map_cuda(
x_flows, 1, device=device)
y_fusion_map, y_acc_map = get_map_fusion_map_cuda(
y_flows, 1, device=device)

x_flow_map = 1.0 / x_fusion_map
y_flow_map = 1.0 / y_fusion_map

all_acc_map = x_acc_map + y_acc_map
all_acc_map = all_acc_map.astype(np.uint8)
roi_box = get_mask_bbox(all_acc_map, threshold=1)

if roi_box[0] is None or roi_box[1] - roi_box[0] <= 0 or roi_box[
3] - roi_box[2] <= 0:
roi_box = [0, x_flow_map.shape[0], 0, x_flow_map.shape[1]]

roi_x_flow_map = x_flow_map[roi_box[0]:roi_box[1],
roi_box[2]:roi_box[3]]
roi_y_flow_map = y_flow_map[roi_box[0]:roi_box[1],
roi_box[2]:roi_box[3]]

roi_width = roi_x_flow_map.shape[1]
roi_height = roi_x_flow_map.shape[0]

roi_x_flow_map, scale = resize_on_long_side(roi_x_flow_map, 320)
roi_y_flow_map, scale = resize_on_long_side(roi_y_flow_map, 320)

roi_x_flow_map = cv2.blur(roi_x_flow_map, (55, 55))
roi_y_flow_map = cv2.blur(roi_y_flow_map, (55, 55))

roi_x_flow_map = cv2.resize(roi_x_flow_map, (roi_width, roi_height))
roi_y_flow_map = cv2.resize(roi_y_flow_map, (roi_width, roi_height))

x_flow_map[roi_box[0]:roi_box[1],
roi_box[2]:roi_box[3]] = roi_x_flow_map
y_flow_map[roi_box[0]:roi_box[1],
roi_box[2]:roi_box[3]] = roi_y_flow_map

for i in range(scale_num):
origin_rDx += x_flows[i]
origin_rDy += y_flows[i]

origin_rDx *= x_flow_map
origin_rDy *= y_flow_map

return origin_rDx, origin_rDy, x_flow_map, y_flow_map

def __joint_to_body_box(self):
joint_left = int(np.min(self.joints, axis=0)[0])
joint_right = int(np.max(self.joints, axis=0)[0])
joint_top = int(np.min(self.joints, axis=0)[1])
joint_bottom = int(np.max(self.joints, axis=0)[1])
return [joint_top, joint_bottom, joint_left, joint_right]

def __joint_to_leg_box(self):
leg_joints = self.joints[8:, :]
if np.max(leg_joints, axis=0)[2] < 0.05:
return [0, 0, 0, 0]
joint_left = int(np.min(leg_joints, axis=0)[0])
joint_right = int(np.max(leg_joints, axis=0)[0])
joint_top = int(np.min(leg_joints, axis=0)[1])
joint_bottom = int(np.max(leg_joints, axis=0)[1])
return [joint_top, joint_bottom, joint_left, joint_right]

def __joint_to_arm_box(self):
arm_joints = self.joints[2:8, :]
if np.max(arm_joints, axis=0)[2] < 0.05:
return [0, 0, 0, 0]
joint_left = int(np.min(arm_joints, axis=0)[0])
joint_right = int(np.max(arm_joints, axis=0)[0])
joint_top = int(np.min(arm_joints, axis=0)[1])
joint_bottom = int(np.max(arm_joints, axis=0)[1])
return [joint_top, joint_bottom, joint_left, joint_right]
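Sketch of the intended call pattern (assumptions: `body_joints` is the (N, 18, 3) array of x, y, confidence values from the pose estimator, and `flow_net` is the FlowGenerator above):

# person-level flow prediction for a single detected body
person = PersonInfo(body_joints[0])             # one row of (N, 18, 3)
pred = person.pred_flow(img, flow_net, device)  # img: HxWx3 uint8 image
dx, dy = pred['rDx'], pred['rDy']               # per-pixel displacement fields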

+0 -0 modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py


+272 -0 modelscope/models/cv/image_body_reshaping/pose_estimator/body.py

@@ -0,0 +1,272 @@
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose.

import math

import cv2
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter

from .model import BodyposeModel
from .util import pad_rightdown_corner, transfer


class Body(object):

def __init__(self, model_path, device):
self.model = BodyposeModel().to(device)
model_dict = transfer(self.model, torch.load(model_path))
self.model.load_state_dict(model_dict)
self.model.eval()

def __call__(self, oriImg):
scale_search = [0.5]
boxsize = 368
stride = 8
padValue = 128
thre1 = 0.1
thre2 = 0.05
bodyparts = 18
multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

for m in range(len(multiplier)):
scale = multiplier[m]
imageToTest = cv2.resize(
oriImg, (0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_CUBIC)
imageToTest_padded, pad = pad_rightdown_corner(
imageToTest, stride, padValue)
im = np.transpose(
np.float32(imageToTest_padded[:, :, :, np.newaxis]),
(3, 2, 0, 1)) / 256 - 0.5
im = np.ascontiguousarray(im)

data = torch.from_numpy(im).float()
if torch.cuda.is_available():
data = data.cuda()
with torch.no_grad():
Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

# extract outputs, resize, and remove padding
heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2),
(1, 2, 0)) # output 1 is heatmaps
heatmap = cv2.resize(
heatmap, (0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC)
heatmap = heatmap[:imageToTest_padded.shape[0]
- pad[2], :imageToTest_padded.shape[1]
- pad[3], :]
heatmap = cv2.resize(
heatmap, (oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC)

paf = np.transpose(np.squeeze(Mconv7_stage6_L1),
(1, 2, 0)) # output 0 is PAFs
paf = cv2.resize(
paf, (0, 0),
fx=stride,
fy=stride,
interpolation=cv2.INTER_CUBIC)
paf = paf[:imageToTest_padded.shape[0]
- pad[2], :imageToTest_padded.shape[1] - pad[3], :]
paf = cv2.resize(
paf, (oriImg.shape[1], oriImg.shape[0]),
interpolation=cv2.INTER_CUBIC)

heatmap_avg += heatmap / len(multiplier)
paf_avg += paf / len(multiplier)

all_peaks = []
peak_counter = 0

for part in range(bodyparts):
map_ori = heatmap_avg[:, :, part]
one_heatmap = gaussian_filter(map_ori, sigma=3)

map_left = np.zeros(one_heatmap.shape)
map_left[1:, :] = one_heatmap[:-1, :]
map_right = np.zeros(one_heatmap.shape)
map_right[:-1, :] = one_heatmap[1:, :]
map_up = np.zeros(one_heatmap.shape)
map_up[:, 1:] = one_heatmap[:, :-1]
map_down = np.zeros(one_heatmap.shape)
map_down[:, :-1] = one_heatmap[:, 1:]

peaks_binary = np.logical_and.reduce(
(one_heatmap >= map_left, one_heatmap >= map_right,
one_heatmap >= map_up, one_heatmap >= map_down,
one_heatmap > thre1))
peaks = list(
zip(np.nonzero(peaks_binary)[1],
np.nonzero(peaks_binary)[0])) # note reverse
peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks]
peak_id = range(peak_counter, peak_counter + len(peaks))
peaks_with_score_and_id = [
peaks_with_score[i] + (peak_id[i], )
for i in range(len(peak_id))
]

all_peaks.append(peaks_with_score_and_id)
peak_counter += len(peaks)

# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9],
[9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1],
[1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
# the middle joints heatmap correspondence
mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44],
[19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30],
[47, 48], [49, 50], [53, 54], [51, 52], [55, 56], [37, 38],
[45, 46]]

connection_all = []
special_k = []
mid_num = 10

for k in range(len(mapIdx)):
score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
candA = all_peaks[limbSeq[k][0] - 1]
candB = all_peaks[limbSeq[k][1] - 1]
nA = len(candA)
nB = len(candB)
if (nA != 0 and nB != 0):
connection_candidate = []
for i in range(nA):
for j in range(nB):
vec = np.subtract(candB[j][:2], candA[i][:2])
norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
norm = max(0.001, norm)
vec = np.divide(vec, norm)

startend = list(
zip(
np.linspace(
candA[i][0], candB[j][0], num=mid_num),
np.linspace(
candA[i][1], candB[j][1], num=mid_num)))

vec_x = np.array([
score_mid[int(round(startend[item][1])),
int(round(startend[item][0])), 0]
for item in range(len(startend))
])
vec_y = np.array([
score_mid[int(round(startend[item][1])),
int(round(startend[item][0])), 1]
for item in range(len(startend))
])

score_midpts = np.multiply(
vec_x, vec[0]) + np.multiply(vec_y, vec[1])
temp1 = sum(score_midpts) / len(score_midpts)
temp2 = min(0.5 * oriImg.shape[0] / norm - 1, 0)
score_with_dist_prior = temp1 + temp2
criterion1 = len(np.nonzero(
score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
criterion2 = score_with_dist_prior > 0
if criterion1 and criterion2:
connection_candidate.append([
i, j, score_with_dist_prior,
score_with_dist_prior + candA[i][2]
+ candB[j][2]
])

connection_candidate = sorted(
connection_candidate, key=lambda x: x[2], reverse=True)
connection = np.zeros((0, 5))
for c in range(len(connection_candidate)):
i, j, s = connection_candidate[c][0:3]
if (i not in connection[:, 3]
and j not in connection[:, 4]):
connection = np.vstack(
[connection, [candA[i][3], candB[j][3], s, i, j]])
if (len(connection) >= min(nA, nB)):
break

connection_all.append(connection)
else:
special_k.append(k)
connection_all.append([])

# last number in each row is the total parts number of that person
# the second last number in each row is the score of the overall configuration
subset = -1 * np.ones((0, 20))
candidate = np.array(
[item for sublist in all_peaks for item in sublist])

for k in range(len(mapIdx)):
if k not in special_k:
partAs = connection_all[k][:, 0]
partBs = connection_all[k][:, 1]
indexA, indexB = np.array(limbSeq[k]) - 1

for i in range(len(connection_all[k])): # = 1:size(temp,1)
found = 0
subset_idx = [-1, -1]
for j in range(len(subset)): # 1:size(subset,1):
if subset[j][indexA] == partAs[i] or subset[j][
indexB] == partBs[i]:
subset_idx[found] = j
found += 1

if found == 1:
j = subset_idx[0]
if subset[j][indexB] != partBs[i]:
subset[j][indexB] = partBs[i]
subset[j][-1] += 1
subset[j][-2] += candidate[
partBs[i].astype(int),
2] + connection_all[k][i][2]
elif found == 2: # if found 2 and disjoint, merge them
j1, j2 = subset_idx
tmp1 = (subset[j1] >= 0).astype(int)
tmp2 = (subset[j2] >= 0).astype(int)
membership = (tmp1 + tmp2)[:-2]
if len(np.nonzero(membership == 2)[0]) == 0: # merge
subset[j1][:-2] += (subset[j2][:-2] + 1)
subset[j1][-2:] += subset[j2][-2:]
subset[j1][-2] += connection_all[k][i][2]
subset = np.delete(subset, j2, 0)
else: # as like found == 1
subset[j1][indexB] = partBs[i]
subset[j1][-1] += 1
subset[j1][-2] += candidate[
partBs[i].astype(int),
2] + connection_all[k][i][2]

# if find no partA in the subset, create a new subset
elif not found and k < 17:
row = -1 * np.ones(20)
row[indexA] = partAs[i]
row[indexB] = partBs[i]
row[-1] = 2
row[-2] = sum(
candidate[connection_all[k][i, :2].astype(int),
2]) + connection_all[k][i][2]
subset = np.vstack([subset, row])
# delete rows of subset which have too few parts
deleteIdx = []
for i in range(len(subset)):
if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
deleteIdx.append(i)
subset = np.delete(subset, deleteIdx, axis=0)

# subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
# candidate: x, y, score, id
count = subset.shape[0]
joints = np.zeros(shape=(count, bodyparts, 3))

for i in range(count):
for j in range(bodyparts):
joints[i, j, :3] = candidate[int(subset[i, j]), :3]
confidence = 1.0 if subset[i, j] >= 0 else 0.0
joints[i, j, 2] *= confidence
return joints
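A usage sketch for the estimator (the checkpoint name follows the loading code in image_body_reshaping.py above; the device choice mirrors the cuda check inside __call__):

import cv2
import torch
from modelscope.models.cv.image_body_reshaping.pose_estimator.body import Body

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
body = Body('body_pose_model.pth', device)  # path is illustrative
img = cv2.imread('data/test/images/image_body_reshaping.jpg')
joints = body(img)   # (num_people, 18, 3): x, y, confidence per joint
print(joints.shape)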

+141 -0 modelscope/models/cv/image_body_reshaping/pose_estimator/model.py

@@ -0,0 +1,141 @@
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose.

from collections import OrderedDict

import torch
import torch.nn as nn


def make_layers(block, no_relu_layers):
layers = []
for layer_name, v in block.items():
if 'pool' in layer_name:
layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
layers.append((layer_name, layer))
else:
conv2d = nn.Conv2d(
in_channels=v[0],
out_channels=v[1],
kernel_size=v[2],
stride=v[3],
padding=v[4])
layers.append((layer_name, conv2d))
if layer_name not in no_relu_layers:
layers.append(('relu_' + layer_name, nn.ReLU(inplace=True)))

return nn.Sequential(OrderedDict(layers))


class BodyposeModel(nn.Module):

def __init__(self):
super(BodyposeModel, self).__init__()

# these layers have no relu layer
no_relu_layers = [
'conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2'
]
blocks = {}
block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]),
('conv1_2', [64, 64, 3, 1, 1]),
('pool1_stage1', [2, 2, 0]),
('conv2_1', [64, 128, 3, 1, 1]),
('conv2_2', [128, 128, 3, 1, 1]),
('pool2_stage1', [2, 2, 0]),
('conv3_1', [128, 256, 3, 1, 1]),
('conv3_2', [256, 256, 3, 1, 1]),
('conv3_3', [256, 256, 3, 1, 1]),
('conv3_4', [256, 256, 3, 1, 1]),
('pool3_stage1', [2, 2, 0]),
('conv4_1', [256, 512, 3, 1, 1]),
('conv4_2', [512, 512, 3, 1, 1]),
('conv4_3_CPM', [512, 256, 3, 1, 1]),
('conv4_4_CPM', [256, 128, 3, 1, 1])])

# Stage 1
block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
('conv5_5_CPM_L1', [512, 38, 1, 1, 0])])

block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
('conv5_5_CPM_L2', [512, 19, 1, 1, 0])])
blocks['block1_1'] = block1_1
blocks['block1_2'] = block1_2

self.model0 = make_layers(block0, no_relu_layers)

# Stages 2 - 6
for i in range(2, 7):
blocks['block%d_1' % i] = OrderedDict([
('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
])

blocks['block%d_2' % i] = OrderedDict([
('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
])

for k in blocks.keys():
blocks[k] = make_layers(blocks[k], no_relu_layers)

self.model1_1 = blocks['block1_1']
self.model2_1 = blocks['block2_1']
self.model3_1 = blocks['block3_1']
self.model4_1 = blocks['block4_1']
self.model5_1 = blocks['block5_1']
self.model6_1 = blocks['block6_1']

self.model1_2 = blocks['block1_2']
self.model2_2 = blocks['block2_2']
self.model3_2 = blocks['block3_2']
self.model4_2 = blocks['block4_2']
self.model5_2 = blocks['block5_2']
self.model6_2 = blocks['block6_2']

def forward(self, x):

out1 = self.model0(x)

out1_1 = self.model1_1(out1)
out1_2 = self.model1_2(out1)
out2 = torch.cat([out1_1, out1_2, out1], 1)

out2_1 = self.model2_1(out2)
out2_2 = self.model2_2(out2)
out3 = torch.cat([out2_1, out2_2, out1], 1)

out3_1 = self.model3_1(out3)
out3_2 = self.model3_2(out3)
out4 = torch.cat([out3_1, out3_2, out1], 1)

out4_1 = self.model4_1(out4)
out4_2 = self.model4_2(out4)
out5 = torch.cat([out4_1, out4_2, out1], 1)

out5_1 = self.model5_1(out5)
out5_2 = self.model5_2(out5)
out6 = torch.cat([out5_1, out5_2, out1], 1)

out6_1 = self.model6_1(out6)
out6_2 = self.model6_2(out6)

return out6_1, out6_2
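A shape sketch for the two-branch network (following the inference settings in body.py: boxsize 368 and stride 8, so a 368x368 input yields 46x46 maps):

import torch
from modelscope.models.cv.image_body_reshaping.pose_estimator.model import BodyposeModel

net = BodyposeModel().eval()
x = torch.randn(1, 3, 368, 368)
with torch.no_grad():
    paf, heatmap = net(x)   # L1 branch: PAFs, L2 branch: heatmaps
print(paf.shape, heatmap.shape)  # (1, 38, 46, 46) and (1, 19, 46, 46)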

+33 -0 modelscope/models/cv/image_body_reshaping/pose_estimator/util.py

@@ -0,0 +1,33 @@
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose.
import numpy as np


def pad_rightdown_corner(img, stride, padValue):
h = img.shape[0]
w = img.shape[1]

pad = 4 * [None]
pad[0] = 0 # up
pad[1] = 0 # left
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right

img_padded = img
pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
img_padded = np.concatenate((pad_up, img_padded), axis=0)
pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
img_padded = np.concatenate((pad_left, img_padded), axis=1)
pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
img_padded = np.concatenate((img_padded, pad_down), axis=0)
pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
img_padded = np.concatenate((img_padded, pad_right), axis=1)

return img_padded, pad


def transfer(model, model_weights):
transfered_model_weights = {}
for weights_name in model.state_dict().keys():
transfered_model_weights[weights_name] = model_weights['.'.join(
weights_name.split('.')[1:])]
return transfered_model_weights
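A quick check of the padding helper: it pads only the bottom and right so both spatial dimensions become multiples of `stride`:

import numpy as np

img = np.zeros((100, 75, 3), dtype=np.uint8)
img_padded, pad = pad_rightdown_corner(img, stride=8, padValue=128)
print(img_padded.shape, pad)  # (104, 80, 3) [0, 0, 4, 5]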

+507 -0 modelscope/models/cv/image_body_reshaping/slim_utils.py

@@ -0,0 +1,507 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import math
import os
import random

import cv2
import numba
import numpy as np
import torch


def resize_on_long_side(img, long_side=800):
src_height = img.shape[0]
src_width = img.shape[1]

if src_height > src_width:
scale = long_side * 1.0 / src_height
_img = cv2.resize(
img, (int(src_width * scale), long_side),
interpolation=cv2.INTER_LINEAR)
else:
scale = long_side * 1.0 / src_width
_img = cv2.resize(
img, (long_side, int(src_height * scale)),
interpolation=cv2.INTER_LINEAR)

return _img, scale


def point_in_box(pt, box):
pt_x = pt[0]
pt_y = pt[1]

if pt_x >= box[0] and pt_x <= box[0] + box[2] and pt_y >= box[
1] and pt_y <= box[1] + box[3]:
return True
else:
return False


def enlarge_box_tblr(roi_bbox, mask, ratio=0.4, use_long_side=True):
if roi_bbox is None or None in roi_bbox:
return [None, None, None, None]

top = roi_bbox[0]
bottom = roi_bbox[1]
left = roi_bbox[2]
right = roi_bbox[3]

roi_width = roi_bbox[3] - roi_bbox[2]
roi_height = roi_bbox[1] - roi_bbox[0]
right = left + roi_width
bottom = top + roi_height

long_side = roi_width if roi_width > roi_height else roi_height

if use_long_side:
new_left = left - int(long_side * ratio)
else:
new_left = left - int(roi_width * ratio)
new_left = 1 if new_left < 0 else new_left

if use_long_side:
new_top = top - int(long_side * ratio)
else:
new_top = top - int(roi_height * ratio)
new_top = 1 if new_top < 0 else new_top

if use_long_side:
new_right = right + int(long_side * ratio)
else:
new_right = right + int(roi_width * ratio)
new_right = mask.shape[1] - 2 if new_right > mask.shape[1] else new_right

if use_long_side:
new_bottom = bottom + int(long_side * ratio)
else:
new_bottom = bottom + int(roi_height * ratio)
new_bottom = mask.shape[0] - 2 if new_bottom > mask.shape[0] else new_bottom

bbox = [new_top, new_bottom, new_left, new_right]
return bbox


def gen_PAF(image, joints):

assert joints.shape[0] == 18
assert joints.shape[1] == 3

org_h = image.shape[0]
org_w = image.shape[1]
small_image, resize_scale = resize_on_long_side(image, 120)

joints[:, :2] = joints[:, :2] * resize_scale

joint_left = int(np.min(joints, axis=0)[0])
joint_right = int(np.max(joints, axis=0)[0])
joint_top = int(np.min(joints, axis=0)[1])
joint_bottom = int(np.max(joints, axis=0)[1])

limb_width = min(
abs(joint_right - joint_left), abs(joint_bottom - joint_top)) // 6

if limb_width % 2 == 0:
limb_width += 1
kernel_size = limb_width

part_orders = [(5, 11), (2, 8), (5, 6), (6, 7), (2, 3), (3, 4), (11, 12),
(12, 13), (8, 9), (9, 10)]

map_list = []
mask_list = []
PAF_all = np.zeros(
shape=(small_image.shape[0], small_image.shape[1], 2),
dtype=np.float32)
for c, pair in enumerate(part_orders):
idx_a_name = pair[0]
idx_b_name = pair[1]

jointa = joints[idx_a_name]
jointb = joints[idx_b_name]

confidence_threshold = 0.05
if jointa[2] > confidence_threshold and jointb[
2] > confidence_threshold:
canvas = np.zeros(
shape=(small_image.shape[0], small_image.shape[1]),
dtype=np.uint8)

canvas = cv2.line(canvas, (int(jointa[0]), int(jointa[1])),
(int(jointb[0]), int(jointb[1])),
(255, 255, 255), 5)

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
(kernel_size, kernel_size))

canvas = cv2.dilate(canvas, kernel, 1)
canvas = cv2.GaussianBlur(canvas, (kernel_size, kernel_size), 0)
canvas = canvas.astype(np.float32) / 255
PAF = np.zeros(
shape=(small_image.shape[0], small_image.shape[1], 2),
dtype=np.float32)
PAF[..., 0] = jointb[0] - jointa[0]
PAF[..., 1] = jointb[1] - jointa[1]
mag, ang = cv2.cartToPolar(PAF[..., 0], PAF[..., 1])
PAF /= (np.dstack((mag, mag)) + 1e-5)

single_PAF = PAF * np.dstack((canvas, canvas))
map_list.append(
cv2.GaussianBlur(single_PAF,
(kernel_size * 3, kernel_size * 3), 0))

mask_list.append(
cv2.GaussianBlur(canvas.copy(),
(kernel_size * 3, kernel_size * 3), 0))
PAF_all = PAF_all * (1.0 - np.dstack(
(canvas, canvas))) + single_PAF

PAF_all = cv2.GaussianBlur(PAF_all, (kernel_size * 3, kernel_size * 3), 0)
PAF_all = cv2.resize(
PAF_all, (org_w, org_h), interpolation=cv2.INTER_LINEAR)
map_list.append(PAF_all)
return PAF_all, map_list, mask_list


def gen_skeleton_map(joints, stack_mode='column', input_roi_box=None):
if type(joints) == list:
joints = np.array(joints)
assert stack_mode == 'column' or stack_mode == 'depth'

part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3),
(3, 4), (11, 12), (12, 13), (8, 9), (9, 10)]

def link(img, a, b, color, line_width, scale=1.0, x_offset=0, y_offset=0):
jointa = joints[a]
jointb = joints[b]

temp1 = int((jointa[0] - x_offset) * scale)
temp2 = int((jointa[1] - y_offset) * scale)
temp3 = int((jointb[0] - x_offset) * scale)
temp4 = int((jointb[1] - y_offset) * scale)

cv2.line(img, (temp1, temp2), (temp3, temp4), color, line_width)

roi_box = input_roi_box

roi_box_width = roi_box[3] - roi_box[2]
roi_box_height = roi_box[1] - roi_box[0]
short_side_length = min(roi_box_width, roi_box_height)
line_width = short_side_length // 30

line_width = max(line_width, 2)

map_cube = np.zeros(
shape=(roi_box_height, roi_box_width, len(part_orders) + 1),
dtype=np.float32)

use_line_width = min(5, line_width)
fx = use_line_width * 1.0 / line_width  # fx is at most 1.0

if fx < 0.99:
map_cube = cv2.resize(map_cube, (0, 0), fx=fx, fy=fx)

for c, pair in enumerate(part_orders):
tmp = map_cube[..., c].copy()
link(
tmp,
pair[0],
pair[1], (2.0, 2.0, 2.0),
use_line_width,
scale=fx,
x_offset=roi_box[2],
y_offset=roi_box[0])
map_cube[..., c] = tmp

tmp = map_cube[..., -1].copy()
link(
tmp,
pair[0],
pair[1], (2.0, 2.0, 2.0),
use_line_width,
scale=fx,
x_offset=roi_box[2],
y_offset=roi_box[0])
map_cube[..., -1] = tmp

map_cube = cv2.resize(map_cube, (roi_box_width, roi_box_height))

if stack_mode == 'depth':
return map_cube, roi_box
elif stack_mode == 'column':
joint_maps = []
for c in range(len(part_orders) + 1):
joint_maps.append(map_cube[..., c])
joint_map = np.column_stack(joint_maps)

return joint_map, roi_box


def plot_one_box(x, img, color=None, label=None, line_thickness=None):
tl = line_thickness or round(
0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
color = color or [random.randint(0, 255) for _ in range(3)]
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
if label:
tf = max(tl - 1, 1) # font thickness
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
cv2.putText(
img,
label, (c1[0], c1[1] - 2),
0,
tl / 3, [225, 255, 255],
thickness=tf,
lineType=cv2.LINE_AA)


def draw_line(im, points, color, stroke_size=2, closed=False):
points = points.astype(np.int32)
for i in range(len(points) - 1):
cv2.line(im, tuple(points[i]), tuple(points[i + 1]), color,
stroke_size)
if closed:
cv2.line(im, tuple(points[0]), tuple(points[-1]), color, stroke_size)


def enlarged_bbox(bbox, img_width, img_height, enlarge_ratio=0.2):
left = bbox[0]
top = bbox[1]

right = bbox[2]
bottom = bbox[3]

roi_width = right - left
roi_height = bottom - top

new_left = left - int(roi_width * enlarge_ratio)
new_left = 0 if new_left < 0 else new_left

new_top = top - int(roi_height * enlarge_ratio)
new_top = 0 if new_top < 0 else new_top

new_right = right + int(roi_width * enlarge_ratio)
new_right = img_width if new_right > img_width else new_right

new_bottom = bottom + int(roi_height * enlarge_ratio)
new_bottom = img_height if new_bottom > img_height else new_bottom

bbox = [new_left, new_top, new_right, new_bottom]

bbox = [int(x) for x in bbox]

return bbox


def get_map_fusion_map_cuda(map_list, threshold=1, device=torch.device('cpu')):
map_list_cuda = [torch.from_numpy(x).to(device) for x in map_list]
map_concat = torch.stack(tuple(map_list_cuda), dim=-1)

map_concat = torch.abs(map_concat)

map_concat[map_concat < threshold] = 0
map_concat[map_concat > 1e-5] = 1.0

sum_map = torch.sum(map_concat, dim=2)
a = torch.ones_like(sum_map)
acc_map = torch.where(sum_map > 0, a * 2.0, torch.zeros_like(sum_map))

fusion_map = torch.where(sum_map < 0.5, a * 1.5, sum_map)

fusion_map = fusion_map.float()
acc_map = acc_map.float()

fusion_map = fusion_map.cpu().numpy().astype(np.float32)
acc_map = acc_map.cpu().numpy().astype(np.float32)

return fusion_map, acc_map


def gen_border_shade(height, width, height_band, width_band):
height_ratio = height_band * 1.0 / height
width_ratio = width_band * 1.0 / width

_height_band = int(256 * height_ratio)
_width_band = int(256 * width_ratio)

canvas = np.zeros((256, 256), dtype=np.float32)

canvas[_height_band // 2:-_height_band // 2,
_width_band // 2:-_width_band // 2] = 1.0

canvas = cv2.blur(canvas, (_height_band, _width_band))

canvas = cv2.resize(canvas, (width, height))

return canvas


def get_mask_bbox(mask, threshold=127):
ret, mask = cv2.threshold(mask, threshold, 1, 0)

if cv2.countNonZero(mask) == 0:
return [None, None, None, None]

col_acc = np.sum(mask, 0)
row_acc = np.sum(mask, 1)

col_acc = col_acc.tolist()
row_acc = row_acc.tolist()

for x in range(len(col_acc)):
if col_acc[x] > 0:
left = x
break

for x in range(1, len(col_acc)):
if col_acc[-x] > 0:
right = len(col_acc) - x
break

for x in range(len(row_acc)):
if row_acc[x] > 0:
top = x
break

for x in range(1, len(row_acc)):
if row_acc[-x] > 0:
bottom = len(row_acc) - x
break
return [top, bottom, left, right]


def visualize_flow(flow):
h, w = flow.shape[:2]
hsv = np.zeros((h, w, 3), np.uint8)
mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])

hsv[..., 0] = ang * 180 / np.pi / 2
hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
hsv[..., 2] = 255
bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
bgr = bgr * 1.0 / 255
return bgr.astype(np.float32)


def vis_joints(image, joints, color, show_text=True, confidence_threshold=0.1):

part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3),
(3, 4), (11, 12), (12, 13), (8, 9), (9, 10)]

abandon_idxs = [0, 1, 14, 15, 16, 17]
# draw joints
for i, joint in enumerate(joints):
if i in abandon_idxs:
continue
if joint[-1] > confidence_threshold:

cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2)
if show_text:
cv2.putText(image,
str(i) + '[{:.2f}]'.format(joint[-1]),
(int(joint[0]), int(joint[1])),
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
# draw link
for pair in part_orders:
if joints[pair[0]][-1] > confidence_threshold and joints[
pair[1]][-1] > confidence_threshold:
cv2.line(image, (int(joints[pair[0]][0]), int(joints[pair[0]][1])),
(int(joints[pair[1]][0]), int(joints[pair[1]][1])), color,
2)
return image


def get_heatmap_cv(img, magn, max_flow_mag):
min_flow_mag = .5
cv_magn = np.clip(
255 * (magn - min_flow_mag) / (max_flow_mag - min_flow_mag + 1e-7),
a_min=0,
a_max=255).astype(np.uint8)
if img.dtype != np.uint8:
img = (255 * img).astype(np.uint8)

heatmap_img = cv2.applyColorMap(cv_magn, cv2.COLORMAP_JET)
heatmap_img = heatmap_img[..., ::-1]

h, w = magn.shape
img_alpha = np.ones((h, w), dtype=np.double)[:, :, None]
heatmap_alpha = np.clip(
magn / (max_flow_mag + 1e-7), a_min=1e-7, a_max=1)[:, :, None]**.7
heatmap_alpha[heatmap_alpha < .2] **= .5
pm_hm = heatmap_img * heatmap_alpha
pm_img = img * img_alpha
cv_out = pm_hm + pm_img * (1 - heatmap_alpha)
cv_out = np.clip(cv_out, a_min=0, a_max=255).astype(np.uint8)

return cv_out


def save_heatmap_cv(img, flow, supression=2):

flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
flow_magn -= supression
flow_magn[flow_magn <= 0] = 0
cv_out = get_heatmap_cv(img, flow_magn, np.max(flow_magn) * 1.3)
return cv_out


@numba.jit(nopython=True, parallel=False)
def bilinear_interp(x, y, v11, v12, v21, v22):
temp1 = (v11 * (1 - y) + v12 * y) * (1 - x)
temp2 = (v21 * (1 - y) + v22 * y) * x
result = temp1 + temp2
return result


@numba.jit(nopython=True, parallel=False)
def image_warp_grid1(rDx, rDy, oriImg, transRatio, width_expand,
height_expand):
srcW = oriImg.shape[1]
srcH = oriImg.shape[0]

newImg = oriImg.copy()

for i in range(srcH):
for j in range(srcW):
_i = i
_j = j

deltaX = rDx[_i, _j]
deltaY = rDy[_i, _j]

nx = _j + deltaX * transRatio
ny = _i + deltaY * transRatio

if nx >= srcW - width_expand - 1:
if nx > srcW - 1:
nx = srcW - 1

if ny >= srcH - height_expand - 1:
if ny > srcH - 1:
ny = srcH - 1

if nx < width_expand:
if nx < 0:
nx = 0

if ny < height_expand:
if ny < 0:
ny = 0

nxi = int(math.floor(nx))
nyi = int(math.floor(ny))
nxi1 = int(math.ceil(nx))
nyi1 = int(math.ceil(ny))

for ll in range(3):
newImg[_i, _j,
ll] = bilinear_interp(ny - nyi, nx - nxi,
oriImg[nyi, nxi,
ll], oriImg[nyi, nxi1, ll],
oriImg[nyi1, nxi,
ll], oriImg[nyi1, nxi1,
ll])
return newImg
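A zero-flow sanity check for the numba-compiled warp: with rDx = rDy = 0 every pixel maps to itself, so the warped image equals the input:

import numpy as np

h, w = 64, 48
img = np.random.randint(0, 255, size=(h, w, 3)).astype(np.uint8)
zero_flow = np.zeros((h, w), dtype=np.float32)
out = image_warp_grid1(zero_flow, zero_flow, img, 1.0, 0, 0)
assert np.array_equal(out, img)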

+1 -0 modelscope/outputs.py

@@ -184,6 +184,7 @@ TASK_OUTPUTS = {
Tasks.image_to_image_translation: [OutputKeys.OUTPUT_IMG],
Tasks.image_style_transfer: [OutputKeys.OUTPUT_IMG],
Tasks.image_portrait_stylization: [OutputKeys.OUTPUT_IMG],
Tasks.image_body_reshaping: [OutputKeys.OUTPUT_IMG],


# live category recognition result for single video
# {


+2 -0 modelscope/pipelines/builder.py

@@ -75,6 +75,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/nlp_bart_text-error-correction_chinese'),
Tasks.image_captioning: (Pipelines.image_captioning,
'damo/ofa_image-caption_coco_large_en'),
Tasks.image_body_reshaping: (Pipelines.image_body_reshaping,
'damo/cv_flow-based-body-reshaping_damo'),
Tasks.image_portrait_stylization:
(Pipelines.person_image_cartoon,
'damo/cv_unet_person-image-cartoon_compound-models'),


+40 -0 modelscope/pipelines/cv/image_body_reshaping_pipeline.py

@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_body_reshaping, module_name=Pipelines.image_body_reshaping)
class ImageBodyReshapingPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
Use `model` to create an image body reshaping pipeline for prediction.
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
logger.info('body reshaping model init done')

def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input)
result = {'img': img}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
output = self.model.inference(input['img'])
result = {'outputs': output}
return result

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
output_img = inputs['outputs']
return {OutputKeys.OUTPUT_IMG: output_img}
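End-to-end usage mirrors the test below; a minimal sketch using the default model registered in builder.py:

import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

reshaper = pipeline(Tasks.image_body_reshaping,
                    model='damo/cv_flow-based-body-reshaping_damo')
result = reshaper('data/test/images/image_body_reshaping.jpg')
cv2.imwrite('result_body_reshaping.png', result[OutputKeys.OUTPUT_IMG])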

+1 -1 modelscope/utils/constant.py

@@ -60,7 +60,7 @@ class CVTasks(object):
image_to_image_generation = 'image-to-image-generation'
image_style_transfer = 'image-style-transfer'
image_portrait_stylization = 'image-portrait-stylization'
image_body_reshaping = 'image-body-reshaping'
image_embedding = 'image-embedding'


product_retrieval_embedding = 'product-retrieval-embedding'


+1 -0 requirements/cv.txt

@@ -13,6 +13,7 @@ ml_collections
mmcls>=0.21.0
mmdet>=2.25.0
networkx>=2.5
numba
onnxruntime>=1.10
pai-easycv>=0.6.3.6
pandas


+58 -0 tests/pipelines/test_image_body_reshaping.py

@@ -0,0 +1,58 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import unittest

import cv2

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class ImageBodyReshapingTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = Tasks.image_body_reshaping
self.model_id = 'damo/cv_flow-based-body-reshaping_damo'
self.test_image = 'data/test/images/image_body_reshaping.jpg'

def pipeline_inference(self, pipeline: Pipeline, input_location: str):
result = pipeline(input_location)
if result is not None:
cv2.imwrite('result_body_reshaping.png',
result[OutputKeys.OUTPUT_IMG])
print(
f'Output written to {osp.abspath("result_body_reshaping.png")}'
)
else:
raise Exception('Testing failed: invalid output')

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
model_dir = snapshot_download(self.model_id)
image_body_reshaping = pipeline(
Tasks.image_body_reshaping, model=model_dir)
self.pipeline_inference(image_body_reshaping, self.test_image)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
image_body_reshaping = pipeline(
Tasks.image_body_reshaping, model=self.model_id)
self.pipeline_inference(image_body_reshaping, self.test_image)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
image_body_reshaping = pipeline(Tasks.image_body_reshaping)
self.pipeline_inference(image_body_reshaping, self.test_image)

@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()


if __name__ == '__main__':
unittest.main()
