From ff69439c4f48bbd1ca3e3f81a3c921925f8e3ca5 Mon Sep 17 00:00:00 2001 From: "ryan.yy" Date: Mon, 10 Oct 2022 17:42:41 +0800 Subject: [PATCH] [to #42322933]add image_body_reshaping code Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10217723 * add image_body_reshaping code --- data/test/images/image_body_reshaping.jpg | 3 + modelscope/metainfo.py | 2 + .../cv/image_body_reshaping/__init__.py | 20 + .../image_body_reshaping.py | 128 +++++ .../models/cv/image_body_reshaping/model.py | 189 +++++++ .../cv/image_body_reshaping/person_info.py | 339 ++++++++++++ .../pose_estimator/__init__.py | 0 .../pose_estimator/body.py | 272 ++++++++++ .../pose_estimator/model.py | 141 +++++ .../pose_estimator/util.py | 33 ++ .../cv/image_body_reshaping/slim_utils.py | 507 ++++++++++++++++++ modelscope/outputs.py | 1 + modelscope/pipelines/builder.py | 2 + .../cv/image_body_reshaping_pipeline.py | 40 ++ modelscope/utils/constant.py | 2 +- requirements/cv.txt | 1 + tests/pipelines/test_image_body_reshaping.py | 58 ++ 17 files changed, 1737 insertions(+), 1 deletion(-) create mode 100644 data/test/images/image_body_reshaping.jpg create mode 100644 modelscope/models/cv/image_body_reshaping/__init__.py create mode 100644 modelscope/models/cv/image_body_reshaping/image_body_reshaping.py create mode 100644 modelscope/models/cv/image_body_reshaping/model.py create mode 100644 modelscope/models/cv/image_body_reshaping/person_info.py create mode 100644 modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py create mode 100644 modelscope/models/cv/image_body_reshaping/pose_estimator/body.py create mode 100644 modelscope/models/cv/image_body_reshaping/pose_estimator/model.py create mode 100644 modelscope/models/cv/image_body_reshaping/pose_estimator/util.py create mode 100644 modelscope/models/cv/image_body_reshaping/slim_utils.py create mode 100644 modelscope/pipelines/cv/image_body_reshaping_pipeline.py create mode 100644 tests/pipelines/test_image_body_reshaping.py diff --git a/data/test/images/image_body_reshaping.jpg b/data/test/images/image_body_reshaping.jpg new file mode 100644 index 00000000..d78acb8f --- /dev/null +++ b/data/test/images/image_body_reshaping.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d +size 1127557 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 28804ce6..1b8c4720 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -43,6 +43,7 @@ class Models(object): face_human_hand_detection = 'face-human-hand-detection' face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' + image_body_reshaping = 'image-body-reshaping' # EasyCV models yolox = 'YOLOX' @@ -187,6 +188,7 @@ class Pipelines(object): face_human_hand_detection = 'face-human-hand-detection' face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' + image_body_reshaping = 'flow-based-body-reshaping' # nlp tasks automatic_post_editing = 'automatic-post-editing' diff --git a/modelscope/models/cv/image_body_reshaping/__init__.py b/modelscope/models/cv/image_body_reshaping/__init__.py new file mode 100644 index 00000000..a04f110d --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .image_body_reshaping import ImageBodyReshaping + +else: + _import_structure = {'image_body_reshaping': ['ImageBodyReshaping']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py b/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py new file mode 100644 index 00000000..4aed8d98 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/image_body_reshaping.py @@ -0,0 +1,128 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .model import FlowGenerator +from .person_info import PersonInfo +from .pose_estimator.body import Body +from .slim_utils import image_warp_grid1, resize_on_long_side + +logger = get_logger() + +__all__ = ['ImageBodyReshaping'] + + +@MODELS.register_module( + Tasks.image_body_reshaping, module_name=Models.image_body_reshaping) +class ImageBodyReshaping(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the image body reshaping model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + + self.degree = 1.0 + self.reshape_model = FlowGenerator(n_channels=16).to(self.device) + model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + checkpoints = torch.load(model_path, map_location=torch.device('cpu')) + self.reshape_model.load_state_dict( + checkpoints['state_dict'], strict=True) + self.reshape_model.eval() + logger.info('load body reshaping model done') + + pose_model_ckpt = os.path.join(model_dir, 'body_pose_model.pth') + self.pose_esti = Body(pose_model_ckpt, self.device) + logger.info('load pose model done') + + def pred_joints(self, img): + if img is None: + return None + small_src, resize_scale = resize_on_long_side(img, 300) + body_joints = self.pose_esti(small_src) + + if body_joints.shape[0] >= 1: + body_joints[:, :, :2] = body_joints[:, :, :2] / resize_scale + + return body_joints + + def pred_flow(self, img): + + body_joints = self.pred_joints(img) + small_size = 1200 + + if img.shape[0] > small_size or img.shape[1] > small_size: + _img, _scale = resize_on_long_side(img, small_size) + body_joints[:, :, :2] = body_joints[:, :, :2] * _scale + else: + _img = img + + # We only reshape one person + if body_joints.shape[0] < 1 or body_joints.shape[0] > 1: + return None + + person = PersonInfo(body_joints[0]) + + with torch.no_grad(): + person_pred = person.pred_flow(_img, self.reshape_model, + self.device) + + flow = np.dstack((person_pred['rDx'], person_pred['rDy'])) + + scale = img.shape[0] * 1.0 / flow.shape[0] + + flow = cv2.resize(flow, (img.shape[1], img.shape[0])) + flow *= scale + + return flow + + def warp(self, src_img, flow): + + X_flow = flow[..., 0] + Y_flow = flow[..., 1] + + X_flow = np.ascontiguousarray(X_flow) + Y_flow = np.ascontiguousarray(Y_flow) + + pred = image_warp_grid1(X_flow, Y_flow, src_img, 1.0, 0, 0) + return pred + + def inference(self, img): + img = img.cpu().numpy() + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + flow = self.pred_flow(img) + + if flow is None: + return img + + assert flow.shape[:2] == img.shape[:2] + + mag, ang = cv2.cartToPolar(flow[..., 0] + 1e-8, flow[..., 1] + 1e-8) + mag -= 3 + mag[mag <= 0] = 0 + + x, y = cv2.polarToCart(mag, ang, angleInDegrees=False) + flow = np.dstack((x, y)) + + flow *= self.degree + pred = self.warp(img, flow) + out_img = np.clip(pred, 0, 255) + logger.info('model inference done') + + return out_img.astype(np.uint8) diff --git a/modelscope/models/cv/image_body_reshaping/model.py b/modelscope/models/cv/image_body_reshaping/model.py new file mode 100644 index 00000000..174428a1 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/model.py @@ -0,0 +1,189 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvLayer(nn.Module): + + def __init__(self, in_ch, out_ch): + super(ConvLayer, self).__init__() + + self.conv = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0), + nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class SASA(nn.Module): + + def __init__(self, in_dim): + super(SASA, self).__init__() + self.chanel_in = in_dim + + self.query_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.key_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) + self.value_conv = nn.Conv2d( + in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.mag_conv = nn.Conv2d( + in_channels=5, out_channels=in_dim // 32, kernel_size=1) + + self.gamma = nn.Parameter(torch.zeros(1)) + + self.softmax = nn.Softmax(dim=-1) # + self.sigmoid = nn.Sigmoid() + + def structure_encoder(self, paf_mag, target_height, target_width): + torso_mask = torch.sum(paf_mag[:, 1:3, :, :], dim=1, keepdim=True) + torso_mask = torch.clamp(torso_mask, 0, 1) + + arms_mask = torch.sum(paf_mag[:, 4:8, :, :], dim=1, keepdim=True) + arms_mask = torch.clamp(arms_mask, 0, 1) + + legs_mask = torch.sum(paf_mag[:, 8:12, :, :], dim=1, keepdim=True) + legs_mask = torch.clamp(legs_mask, 0, 1) + + fg_mask = paf_mag[:, 12, :, :].unsqueeze(1) + bg_mask = 1 - fg_mask + Y = torch.cat((arms_mask, torso_mask, legs_mask, fg_mask, bg_mask), + dim=1) + Y = F.interpolate(Y, size=(target_height, target_width), mode='area') + return Y + + def forward(self, X, PAF_mag): + """extract self-attention features. + Args: + X : input feature maps( B x C x H x W) + PAF_mag : ( B x C x H x W), 1 denotes connectivity, 0 denotes non-connectivity + + Returns: + out : self attention value + input feature + Y: B X N X N (N is Width*Height) + """ + + m_batchsize, C, height, width = X.size() + + Y = self.structure_encoder(PAF_mag, height, width) + + connectivity_mask_vec = self.mag_conv(Y).view(m_batchsize, -1, + width * height) + affinity = torch.bmm( + connectivity_mask_vec.permute(0, 2, 1), connectivity_mask_vec) + affinity_centered = affinity - torch.mean(affinity) + affinity_sigmoid = self.sigmoid(affinity_centered) + + proj_query = self.query_conv(X).view(m_batchsize, -1, + width * height).permute(0, 2, 1) + proj_key = self.key_conv(X).view(m_batchsize, -1, width * height) + selfatten_map = torch.bmm(proj_query, proj_key) + selfatten_centered = selfatten_map - torch.mean( + selfatten_map) # centering + selfatten_sigmoid = self.sigmoid(selfatten_centered) + + SASA_map = selfatten_sigmoid * affinity_sigmoid + + proj_value = self.value_conv(X).view(m_batchsize, -1, width * height) + + out = torch.bmm(proj_value, SASA_map.permute(0, 2, 1)) + out = out.view(m_batchsize, C, height, width) + + out = self.gamma * out + X + return out, Y + + +class FlowGenerator(nn.Module): + + def __init__(self, n_channels, deep_supervision=False): + super(FlowGenerator, self).__init__() + self.deep_supervision = deep_supervision + + self.Encoder = nn.Sequential( + ConvLayer(n_channels, 64), + ConvLayer(64, 64), + nn.MaxPool2d(2), + ConvLayer(64, 128), + ConvLayer(128, 128), + nn.MaxPool2d(2), + ConvLayer(128, 256), + ConvLayer(256, 256), + nn.MaxPool2d(2), + ConvLayer(256, 512), + ConvLayer(512, 512), + nn.MaxPool2d(2), + ConvLayer(512, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ConvLayer(1024, 1024), + ) + + self.SASA = SASA(in_dim=1024) + + self.Decoder = nn.Sequential( + ConvLayer(1024, 1024), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ConvLayer(1024, 512), + ConvLayer(512, 512), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + ConvLayer(512, 256), + ConvLayer(256, 256), + ConvLayer(256, 128), + ConvLayer(128, 64), + ConvLayer(64, 32), + nn.Conv2d(32, 2, kernel_size=1, padding=0), + nn.Tanh(), + nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True), + ) + + dilation_ksize = 17 + self.dilation = torch.nn.MaxPool2d( + kernel_size=dilation_ksize, + stride=1, + padding=int((dilation_ksize - 1) / 2)) + + def warp(self, x, flow, mode='bilinear', padding_mode='zeros', coff=0.2): + n, c, h, w = x.size() + yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)]) + xv = xv.float() / (w - 1) * 2.0 - 1 + yv = yv.float() / (h - 1) * 2.0 - 1 + grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), -1).unsqueeze(0) + grid = grid.to(flow.device) + grid_x = grid + 2 * flow * coff + warp_x = F.grid_sample(x, grid_x, mode=mode, padding_mode=padding_mode) + return warp_x + + def forward(self, img, skeleton_map, coef=0.2): + """extract self-attention features. + Args: + img : input numpy image + skeleton_map : skeleton map of input image + coef: warp degree + + Returns: + warp_x : warped image + flow: predicted flow + """ + + img_concat = torch.cat((img, skeleton_map), dim=1) + X = self.Encoder(img_concat) + + _, _, height, width = X.size() + + # directly get PAF magnitude from skeleton maps via dilation + PAF_mag = self.dilation((skeleton_map + 1.0) * 0.5) + + out, Y = self.SASA(X, PAF_mag) + flow = self.Decoder(out) + + flow = flow.permute(0, 2, 3, 1) # [n, 2, h, w] ==> [n, h, w, 2] + + warp_x = self.warp(img, flow, coff=coef) + warp_x = torch.clamp(warp_x, min=-1.0, max=1.0) + + return warp_x, flow diff --git a/modelscope/models/cv/image_body_reshaping/person_info.py b/modelscope/models/cv/image_body_reshaping/person_info.py new file mode 100644 index 00000000..509a2ce3 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/person_info.py @@ -0,0 +1,339 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import copy + +import cv2 +import numpy as np +import torch + +from .slim_utils import (enlarge_box_tblr, gen_skeleton_map, + get_map_fusion_map_cuda, get_mask_bbox, + resize_on_long_side) + + +class PersonInfo(object): + + def __init__(self, joints): + self.joints = joints + self.flow = None + self.pad_boder = False + self.height_expand = 0 + self.width_expand = 0 + self.coeff = 0.2 + self.network_input_W = 256 + self.network_input_H = 256 + self.divider = 20 + self.flow_scales = ['upper_2'] + + def update_attribute(self, pad_boder, height_expand, width_expand): + self.pad_boder = pad_boder + self.height_expand = height_expand + self.width_expand = width_expand + if pad_boder: + self.joints[:, 0] += width_expand + self.joints[:, 1] += height_expand + + def pred_flow(self, img, flow_net, device): + with torch.no_grad(): + if img is None: + print('image is none') + self.flow = None + + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if self.pad_boder: + height_expand = self.height_expand + width_expand = self.width_expand + pad_img = cv2.copyMakeBorder( + img, + height_expand, + height_expand, + width_expand, + width_expand, + cv2.BORDER_CONSTANT, + value=(127, 127, 127)) + + else: + height_expand = 0 + width_expand = 0 + pad_img = img.copy() + + canvas = np.zeros( + shape=(pad_img.shape[0], pad_img.shape[1]), dtype=np.float32) + + self.human_joint_box = self.__joint_to_body_box() + + self.human_box = enlarge_box_tblr( + self.human_joint_box, pad_img, ratio=0.25) + human_box_height = self.human_box[1] - self.human_box[0] + human_box_width = self.human_box[3] - self.human_box[2] + + self.leg_joint_box = self.__joint_to_leg_box() + self.leg_box = enlarge_box_tblr( + self.leg_joint_box, pad_img, ratio=0.25) + + self.arm_joint_box = self.__joint_to_arm_box() + self.arm_box = enlarge_box_tblr( + self.arm_joint_box, pad_img, ratio=0.1) + + x_flows = [] + y_flows = [] + multi_bbox = [] + + for scale in self.flow_scales: # better for metric + scale_value = float(scale.split('_')[-1]) + + arm_box = copy.deepcopy(self.arm_box) + + if arm_box[0] is None: + arm_box = self.human_box + + arm_box_height = arm_box[1] - arm_box[0] + arm_box_width = arm_box[3] - arm_box[2] + + roi_bbox = None + + if arm_box_width < human_box_width * 0.1 or arm_box_height < human_box_height * 0.1: + roi_bbox = self.human_box + else: + arm_box = enlarge_box_tblr( + arm_box, pad_img, ratio=scale_value) + if scale == 'upper_0.2': + arm_box[0] = min(arm_box[0], int(self.joints[0][1])) + if scale.startswith('upper'): + roi_bbox = [ + max(self.human_box[0], arm_box[0]), + min(self.human_box[1], arm_box[1]), + max(self.human_box[2], arm_box[2]), + min(self.human_box[3], arm_box[3]) + ] + if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[ + 3] - roi_bbox[2] < 1: + continue + + elif scale.startswith('lower'): + roi_bbox = [ + max(self.human_box[0], self.leg_box[0]), + min(self.human_box[1], self.leg_box[1]), + max(self.human_box[2], self.leg_box[2]), + min(self.human_box[3], self.leg_box[3]) + ] + + if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[ + 3] - roi_bbox[2] < 1: + continue + + skel_map, roi_bbox = gen_skeleton_map( + self.joints, 'depth', input_roi_box=roi_bbox) + + if roi_bbox is None: + continue + + if skel_map.dtype != np.float32: + skel_map = skel_map.astype(np.float32) + + skel_map -= 1.0 # [0,2] ->[-1,1] + + multi_bbox.append(roi_bbox) + + roi_bbox_height = roi_bbox[1] - roi_bbox[0] + roi_bbox_width = roi_bbox[3] - roi_bbox[2] + + assert skel_map.shape[0] == roi_bbox_height + assert skel_map.shape[1] == roi_bbox_width + roi_height_pad = roi_bbox_height // self.divider + roi_width_pad = roi_bbox_width // self.divider + paded_roi_h = roi_bbox_height + 2 * roi_height_pad + paded_roi_w = roi_bbox_width + 2 * roi_width_pad + + roi_height_pad_joint = skel_map.shape[0] // self.divider + roi_width_pad_joint = skel_map.shape[1] // self.divider + skel_map = np.pad( + skel_map, + ((roi_height_pad_joint, roi_height_pad_joint), + (roi_width_pad_joint, roi_width_pad_joint), (0, 0)), + 'constant', + constant_values=-1) + + skel_map_resized = cv2.resize( + skel_map, (self.network_input_W, self.network_input_H)) + + skel_map_resized[skel_map_resized < 0] = -1.0 + skel_map_resized[skel_map_resized > -0.5] = 1.0 + skel_map_transformed = torch.from_numpy( + skel_map_resized.transpose((2, 0, 1))) + + roi_npy = pad_img[roi_bbox[0]:roi_bbox[1], + roi_bbox[2]:roi_bbox[3], :].copy() + if roi_npy.dtype != np.float32: + roi_npy = roi_npy.astype(np.float32) + + roi_npy = np.pad(roi_npy, + ((roi_height_pad, roi_height_pad), + (roi_width_pad, roi_width_pad), (0, 0)), + 'edge') + + roi_npy = roi_npy[:, :, ::-1] + + roi_npy = cv2.resize( + roi_npy, (self.network_input_W, self.network_input_H)) + + roi_npy *= 1.0 / 255 + roi_npy -= 0.5 + roi_npy *= 2 + + rgb_tensor = torch.from_numpy(roi_npy.transpose((2, 0, 1))) + + rgb_tensor = rgb_tensor.unsqueeze(0).to(device) + skel_map_tensor = skel_map_transformed.unsqueeze(0).to(device) + warped_img_val, flow_field_val = flow_net( + rgb_tensor, skel_map_tensor + ) # inference, connectivity_mask [1,12,16,16] + flow_field_val = flow_field_val.detach().squeeze().cpu().numpy( + ) + + flow_field_val = cv2.resize( + flow_field_val, (paded_roi_w, paded_roi_h), + interpolation=cv2.INTER_LINEAR) + flow_field_val[..., 0] = flow_field_val[ + ..., 0] * paded_roi_w * 0.5 * 2 * self.coeff + flow_field_val[..., 1] = flow_field_val[ + ..., 1] * paded_roi_h * 0.5 * 2 * self.coeff + + # remove pad areas + flow_field_val = flow_field_val[ + roi_height_pad:flow_field_val.shape[0] - roi_height_pad, + roi_width_pad:flow_field_val.shape[1] - roi_width_pad, :] + + diffuse_width = max(roi_bbox_width // 3, 1) + diffuse_height = max(roi_bbox_height // 3, 1) + assert roi_bbox_width == flow_field_val.shape[1] + assert roi_bbox_height == flow_field_val.shape[0] + + origin_flow = np.zeros( + (pad_img.shape[0] + 2 * diffuse_height, + pad_img.shape[1] + 2 * diffuse_width, 2), + dtype=np.float32) + + flow_field_val = np.pad(flow_field_val, + ((diffuse_height, diffuse_height), + (diffuse_width, diffuse_width), + (0, 0)), 'linear_ramp') + + origin_flow[roi_bbox[0]:roi_bbox[1] + 2 * diffuse_height, + roi_bbox[2]:roi_bbox[3] + + 2 * diffuse_width] = flow_field_val + + origin_flow = origin_flow[diffuse_height:-diffuse_height, + diffuse_width:-diffuse_width, :] + + x_flows.append(origin_flow[..., 0]) + y_flows.append(origin_flow[..., 1]) + + if len(x_flows) == 0: + return { + 'rDx': np.zeros(canvas.shape[:2], dtype=np.float32), + 'rDy': np.zeros(canvas.shape[:2], dtype=np.float32), + 'multi_bbox': multi_bbox, + 'x_fusion_map': + np.ones(canvas.shape[:2], dtype=np.float32), + 'y_fusion_map': + np.ones(canvas.shape[:2], dtype=np.float32) + } + else: + origin_rDx, origin_rDy, x_fusion_map, y_fusion_map = self.blend_multiscale_flow( + x_flows, y_flows, device=device) + + return { + 'rDx': origin_rDx, + 'rDy': origin_rDy, + 'multi_bbox': multi_bbox, + 'x_fusion_map': x_fusion_map, + 'y_fusion_map': y_fusion_map + } + + @staticmethod + def blend_multiscale_flow(x_flows, y_flows, device=None): + scale_num = len(x_flows) + if scale_num == 1: + return x_flows[0], y_flows[0], np.ones_like( + x_flows[0]), np.ones_like(x_flows[0]) + + origin_rDx = np.zeros((x_flows[0].shape[0], x_flows[0].shape[1]), + dtype=np.float32) + origin_rDy = np.zeros((y_flows[0].shape[0], y_flows[0].shape[1]), + dtype=np.float32) + + x_fusion_map, x_acc_map = get_map_fusion_map_cuda( + x_flows, 1, device=device) + y_fusion_map, y_acc_map = get_map_fusion_map_cuda( + y_flows, 1, device=device) + + x_flow_map = 1.0 / x_fusion_map + y_flow_map = 1.0 / y_fusion_map + + all_acc_map = x_acc_map + y_acc_map + all_acc_map = all_acc_map.astype(np.uint8) + roi_box = get_mask_bbox(all_acc_map, threshold=1) + + if roi_box[0] is None or roi_box[1] - roi_box[0] <= 0 or roi_box[ + 3] - roi_box[2] <= 0: + roi_box = [0, x_flow_map.shape[0], 0, x_flow_map.shape[1]] + + roi_x_flow_map = x_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] + roi_y_flow_map = y_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] + + roi_width = roi_x_flow_map.shape[1] + roi_height = roi_x_flow_map.shape[0] + + roi_x_flow_map, scale = resize_on_long_side(roi_x_flow_map, 320) + roi_y_flow_map, scale = resize_on_long_side(roi_y_flow_map, 320) + + roi_x_flow_map = cv2.blur(roi_x_flow_map, (55, 55)) + roi_y_flow_map = cv2.blur(roi_y_flow_map, (55, 55)) + + roi_x_flow_map = cv2.resize(roi_x_flow_map, (roi_width, roi_height)) + roi_y_flow_map = cv2.resize(roi_y_flow_map, (roi_width, roi_height)) + + x_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] = roi_x_flow_map + y_flow_map[roi_box[0]:roi_box[1], + roi_box[2]:roi_box[3]] = roi_y_flow_map + + for i in range(scale_num): + origin_rDx += x_flows[i] + origin_rDy += y_flows[i] + + origin_rDx *= x_flow_map + origin_rDy *= y_flow_map + + return origin_rDx, origin_rDy, x_flow_map, y_flow_map + + def __joint_to_body_box(self): + joint_left = int(np.min(self.joints, axis=0)[0]) + joint_right = int(np.max(self.joints, axis=0)[0]) + joint_top = int(np.min(self.joints, axis=0)[1]) + joint_bottom = int(np.max(self.joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] + + def __joint_to_leg_box(self): + leg_joints = self.joints[8:, :] + if np.max(leg_joints, axis=0)[2] < 0.05: + return [0, 0, 0, 0] + joint_left = int(np.min(leg_joints, axis=0)[0]) + joint_right = int(np.max(leg_joints, axis=0)[0]) + joint_top = int(np.min(leg_joints, axis=0)[1]) + joint_bottom = int(np.max(leg_joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] + + def __joint_to_arm_box(self): + arm_joints = self.joints[2:8, :] + if np.max(arm_joints, axis=0)[2] < 0.05: + return [0, 0, 0, 0] + joint_left = int(np.min(arm_joints, axis=0)[0]) + joint_right = int(np.max(arm_joints, axis=0)[0]) + joint_top = int(np.min(arm_joints, axis=0)[1]) + joint_bottom = int(np.max(arm_joints, axis=0)[1]) + return [joint_top, joint_bottom, joint_left, joint_right] diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py new file mode 100644 index 00000000..45b02724 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/body.py @@ -0,0 +1,272 @@ +# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. + +import math + +import cv2 +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter + +from .model import BodyposeModel +from .util import pad_rightdown_corner, transfer + + +class Body(object): + + def __init__(self, model_path, device): + self.model = BodyposeModel().to(device) + model_dict = transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImg): + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + bodyparts = 18 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = cv2.resize( + oriImg, (0, 0), + fx=scale, + fy=scale, + interpolation=cv2.INTER_CUBIC) + imageToTest_padded, pad = pad_rightdown_corner( + imageToTest, stride, padValue) + im = np.transpose( + np.float32(imageToTest_padded[:, :, :, np.newaxis]), + (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + with torch.no_grad(): + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), + (1, 2, 0)) # output 1 is heatmaps + heatmap = cv2.resize( + heatmap, (0, 0), + fx=stride, + fy=stride, + interpolation=cv2.INTER_CUBIC) + heatmap = heatmap[:imageToTest_padded.shape[0] + - pad[2], :imageToTest_padded.shape[1] + - pad[3], :] + heatmap = cv2.resize( + heatmap, (oriImg.shape[1], oriImg.shape[0]), + interpolation=cv2.INTER_CUBIC) + + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), + (1, 2, 0)) # output 0 is PAFs + paf = cv2.resize( + paf, (0, 0), + fx=stride, + fy=stride, + interpolation=cv2.INTER_CUBIC) + paf = paf[:imageToTest_padded.shape[0] + - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = cv2.resize( + paf, (oriImg.shape[1], oriImg.shape[0]), + interpolation=cv2.INTER_CUBIC) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += +paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(bodyparts): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, + one_heatmap >= map_up, one_heatmap >= map_down, + one_heatmap > thre1)) + peaks = list( + zip(np.nonzero(peaks_binary)[1], + np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [ + peaks_with_score[i] + (peak_id[i], ) + for i in range(len(peak_id)) + ] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], + [9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], + [1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], + [19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30], + [47, 48], [49, 50], [53, 54], [51, 52], [55, 56], [37, 38], + [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list( + zip( + np.linspace( + candA[i][0], candB[j][0], num=mid_num), + np.linspace( + candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([ + score_mid[int(round(startend[item][1])), + int(round(startend[item][0])), 0] + for item in range(len(startend)) + ]) + vec_y = np.array([ + score_mid[int(round(startend[item][1])), + int(round(startend[item][0])), 1] + for item in range(len(startend)) + ]) + + score_midpts = np.multiply( + vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + temp1 = sum(score_midpts) / len(score_midpts) + temp2 = min(0.5 * oriImg.shape[0] / norm - 1, 0) + score_with_dist_prior = temp1 + temp2 + criterion1 = len(np.nonzero( + score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append([ + i, j, score_with_dist_prior, + score_with_dist_prior + candA[i][2] + + candB[j][2] + ]) + + connection_candidate = sorted( + connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] + and j not in connection[:, 4]): + connection = np.vstack( + [connection, [candA[i][3], candB[j][3], s, i, j]]) + if (len(connection) >= min(nA, nB)): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array( + [item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][ + indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[ + partBs[i].astype(int), + 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + tmp1 = (subset[j1] >= 0).astype(int) + tmp2 = (subset[j2] >= 0).astype(int) + membership = (tmp1 + tmp2)[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][-2:] += subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[ + partBs[i].astype(int), + 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum( + candidate[connection_all[k][i, :2].astype(int), + 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + count = subset.shape[0] + joints = np.zeros(shape=(count, bodyparts, 3)) + + for i in range(count): + for j in range(bodyparts): + joints[i, j, :3] = candidate[int(subset[i, j]), :3] + confidence = 1.0 if subset[i, j] >= 0 else 0.0 + joints[i, j, 2] *= confidence + return joints diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py new file mode 100644 index 00000000..12f6e84d --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/model.py @@ -0,0 +1,141 @@ +# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. + +from collections import OrderedDict + +import torch +import torch.nn as nn + + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d( + in_channels=v[0], + out_channels=v[1], + kernel_size=v[2], + stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_' + layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + + +class BodyposeModel(nn.Module): + + def __init__(self): + super(BodyposeModel, self).__init__() + + # these layers have no relu layer + no_relu_layers = [ + 'conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1', + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2', + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1', + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1' + ] + blocks = {} + block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1])]) + + # Stage 1 + block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])]) + + block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 diff --git a/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py b/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py new file mode 100644 index 00000000..13a42074 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/pose_estimator/util.py @@ -0,0 +1,33 @@ +# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose. +import numpy as np + + +def pad_rightdown_corner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join( + weights_name.split('.')[1:])] + return transfered_model_weights diff --git a/modelscope/models/cv/image_body_reshaping/slim_utils.py b/modelscope/models/cv/image_body_reshaping/slim_utils.py new file mode 100644 index 00000000..23d5a741 --- /dev/null +++ b/modelscope/models/cv/image_body_reshaping/slim_utils.py @@ -0,0 +1,507 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os +import random + +import cv2 +import numba +import numpy as np +import torch + + +def resize_on_long_side(img, long_side=800): + src_height = img.shape[0] + src_width = img.shape[1] + + if src_height > src_width: + scale = long_side * 1.0 / src_height + _img = cv2.resize( + img, (int(src_width * scale), long_side), + interpolation=cv2.INTER_LINEAR) + else: + scale = long_side * 1.0 / src_width + _img = cv2.resize( + img, (long_side, int(src_height * scale)), + interpolation=cv2.INTER_LINEAR) + + return _img, scale + + +def point_in_box(pt, box): + pt_x = pt[0] + pt_y = pt[1] + + if pt_x >= box[0] and pt_x <= box[0] + box[2] and pt_y >= box[ + 1] and pt_y <= box[1] + box[3]: + return True + else: + return False + + +def enlarge_box_tblr(roi_bbox, mask, ratio=0.4, use_long_side=True): + if roi_bbox is None or None in roi_bbox: + return [None, None, None, None] + + top = roi_bbox[0] + bottom = roi_bbox[1] + left = roi_bbox[2] + right = roi_bbox[3] + + roi_width = roi_bbox[3] - roi_bbox[2] + roi_height = roi_bbox[1] - roi_bbox[0] + right = left + roi_width + bottom = top + roi_height + + long_side = roi_width if roi_width > roi_height else roi_height + + if use_long_side: + new_left = left - int(long_side * ratio) + else: + new_left = left - int(roi_width * ratio) + new_left = 1 if new_left < 0 else new_left + + if use_long_side: + new_top = top - int(long_side * ratio) + else: + new_top = top - int(roi_height * ratio) + new_top = 1 if new_top < 0 else new_top + + if use_long_side: + new_right = right + int(long_side * ratio) + else: + new_right = right + int(roi_width * ratio) + new_right = mask.shape[1] - 2 if new_right > mask.shape[1] else new_right + + if use_long_side: + new_bottom = bottom + int(long_side * ratio) + else: + new_bottom = bottom + int(roi_height * ratio) + new_bottom = mask.shape[0] - 2 if new_bottom > mask.shape[0] else new_bottom + + bbox = [new_top, new_bottom, new_left, new_right] + return bbox + + +def gen_PAF(image, joints): + + assert joints.shape[0] == 18 + assert joints.shape[1] == 3 + + org_h = image.shape[0] + org_w = image.shape[1] + small_image, resize_scale = resize_on_long_side(image, 120) + + joints[:, :2] = joints[:, :2] * resize_scale + + joint_left = int(np.min(joints, axis=0)[0]) + joint_right = int(np.max(joints, axis=0)[0]) + joint_top = int(np.min(joints, axis=0)[1]) + joint_bottom = int(np.max(joints, axis=0)[1]) + + limb_width = min( + abs(joint_right - joint_left), abs(joint_bottom - joint_top)) // 6 + + if limb_width % 2 == 0: + limb_width += 1 + kernel_size = limb_width + + part_orders = [(5, 11), (2, 8), (5, 6), (6, 7), (2, 3), (3, 4), (11, 12), + (12, 13), (8, 9), (9, 10)] + + map_list = [] + mask_list = [] + PAF_all = np.zeros( + shape=(small_image.shape[0], small_image.shape[1], 2), + dtype=np.float32) + for c, pair in enumerate(part_orders): + idx_a_name = pair[0] + idx_b_name = pair[1] + + jointa = joints[idx_a_name] + jointb = joints[idx_b_name] + + confidence_threshold = 0.05 + if jointa[2] > confidence_threshold and jointb[ + 2] > confidence_threshold: + canvas = np.zeros( + shape=(small_image.shape[0], small_image.shape[1]), + dtype=np.uint8) + + canvas = cv2.line(canvas, (int(jointa[0]), int(jointa[1])), + (int(jointb[0]), int(jointb[1])), + (255, 255, 255), 5) + + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (kernel_size, kernel_size)) + + canvas = cv2.dilate(canvas, kernel, 1) + canvas = cv2.GaussianBlur(canvas, (kernel_size, kernel_size), 0) + canvas = canvas.astype(np.float32) / 255 + PAF = np.zeros( + shape=(small_image.shape[0], small_image.shape[1], 2), + dtype=np.float32) + PAF[..., 0] = jointb[0] - jointa[0] + PAF[..., 1] = jointb[1] - jointa[1] + mag, ang = cv2.cartToPolar(PAF[..., 0], PAF[..., 1]) + PAF /= (np.dstack((mag, mag)) + 1e-5) + + single_PAF = PAF * np.dstack((canvas, canvas)) + map_list.append( + cv2.GaussianBlur(single_PAF, + (kernel_size * 3, kernel_size * 3), 0)) + + mask_list.append( + cv2.GaussianBlur(canvas.copy(), + (kernel_size * 3, kernel_size * 3), 0)) + PAF_all = PAF_all * (1.0 - np.dstack( + (canvas, canvas))) + single_PAF + + PAF_all = cv2.GaussianBlur(PAF_all, (kernel_size * 3, kernel_size * 3), 0) + PAF_all = cv2.resize( + PAF_all, (org_w, org_h), interpolation=cv2.INTER_LINEAR) + map_list.append(PAF_all) + return PAF_all, map_list, mask_list + + +def gen_skeleton_map(joints, stack_mode='column', input_roi_box=None): + if type(joints) == list: + joints = np.array(joints) + assert stack_mode == 'column' or stack_mode == 'depth' + + part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), + (3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] + + def link(img, a, b, color, line_width, scale=1.0, x_offset=0, y_offset=0): + jointa = joints[a] + jointb = joints[b] + + temp1 = int((jointa[0] - x_offset) * scale) + temp2 = int((jointa[1] - y_offset) * scale) + temp3 = int((jointb[0] - x_offset) * scale) + temp4 = int((jointb[1] - y_offset) * scale) + + cv2.line(img, (temp1, temp2), (temp3, temp4), color, line_width) + + roi_box = input_roi_box + + roi_box_width = roi_box[3] - roi_box[2] + roi_box_height = roi_box[1] - roi_box[0] + short_side_length = min(roi_box_width, roi_box_height) + line_width = short_side_length // 30 + + line_width = max(line_width, 2) + + map_cube = np.zeros( + shape=(roi_box_height, roi_box_width, len(part_orders) + 1), + dtype=np.float32) + + use_line_width = min(5, line_width) + fx = use_line_width * 1.0 / line_width # fx 最大值为1 + + if fx < 0.99: + map_cube = cv2.resize(map_cube, (0, 0), fx=fx, fy=fx) + + for c, pair in enumerate(part_orders): + tmp = map_cube[..., c].copy() + link( + tmp, + pair[0], + pair[1], (2.0, 2.0, 2.0), + use_line_width, + scale=fx, + x_offset=roi_box[2], + y_offset=roi_box[0]) + map_cube[..., c] = tmp + + tmp = map_cube[..., -1].copy() + link( + tmp, + pair[0], + pair[1], (2.0, 2.0, 2.0), + use_line_width, + scale=fx, + x_offset=roi_box[2], + y_offset=roi_box[0]) + map_cube[..., -1] = tmp + + map_cube = cv2.resize(map_cube, (roi_box_width, roi_box_height)) + + if stack_mode == 'depth': + return map_cube, roi_box + elif stack_mode == 'column': + joint_maps = [] + for c in range(len(part_orders) + 1): + joint_maps.append(map_cube[..., c]) + joint_map = np.column_stack(joint_maps) + + return joint_map, roi_box + + +def plot_one_box(x, img, color=None, label=None, line_thickness=None): + tl = line_thickness or round( + 0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness + color = color or [random.randint(0, 255) for _ in range(3)] + c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) + cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) + if label: + tf = max(tl - 1, 1) # font thickness + t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] + c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 + cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + img, + label, (c1[0], c1[1] - 2), + 0, + tl / 3, [225, 255, 255], + thickness=tf, + lineType=cv2.LINE_AA) + + +def draw_line(im, points, color, stroke_size=2, closed=False): + points = points.astype(np.int32) + for i in range(len(points) - 1): + cv2.line(im, tuple(points[i]), tuple(points[i + 1]), color, + stroke_size) + if closed: + cv2.line(im, tuple(points[0]), tuple(points[-1]), color, stroke_size) + + +def enlarged_bbox(bbox, img_width, img_height, enlarge_ratio=0.2): + left = bbox[0] + top = bbox[1] + + right = bbox[2] + bottom = bbox[3] + + roi_width = right - left + roi_height = bottom - top + + new_left = left - int(roi_width * enlarge_ratio) + new_left = 0 if new_left < 0 else new_left + + new_top = top - int(roi_height * enlarge_ratio) + new_top = 0 if new_top < 0 else new_top + + new_right = right + int(roi_width * enlarge_ratio) + new_right = img_width if new_right > img_width else new_right + + new_bottom = bottom + int(roi_height * enlarge_ratio) + new_bottom = img_height if new_bottom > img_height else new_bottom + + bbox = [new_left, new_top, new_right, new_bottom] + + bbox = [int(x) for x in bbox] + + return bbox + + +def get_map_fusion_map_cuda(map_list, threshold=1, device=torch.device('cpu')): + map_list_cuda = [torch.from_numpy(x).to(device) for x in map_list] + map_concat = torch.stack(tuple(map_list_cuda), dim=-1) + + map_concat = torch.abs(map_concat) + + map_concat[map_concat < threshold] = 0 + map_concat[map_concat > 1e-5] = 1.0 + + sum_map = torch.sum(map_concat, dim=2) + a = torch.ones_like(sum_map) + acc_map = torch.where(sum_map > 0, a * 2.0, torch.zeros_like(sum_map)) + + fusion_map = torch.where(sum_map < 0.5, a * 1.5, sum_map) + + fusion_map = fusion_map.float() + acc_map = acc_map.float() + + fusion_map = fusion_map.cpu().numpy().astype(np.float32) + acc_map = acc_map.cpu().numpy().astype(np.float32) + + return fusion_map, acc_map + + +def gen_border_shade(height, width, height_band, width_band): + height_ratio = height_band * 1.0 / height + width_ratio = width_band * 1.0 / width + + _height_band = int(256 * height_ratio) + _width_band = int(256 * width_ratio) + + canvas = np.zeros((256, 256), dtype=np.float32) + + canvas[_height_band // 2:-_height_band // 2, + _width_band // 2:-_width_band // 2] = 1.0 + + canvas = cv2.blur(canvas, (_height_band, _width_band)) + + canvas = cv2.resize(canvas, (width, height)) + + return canvas + + +def get_mask_bbox(mask, threshold=127): + ret, mask = cv2.threshold(mask, threshold, 1, 0) + + if cv2.countNonZero(mask) == 0: + return [None, None, None, None] + + col_acc = np.sum(mask, 0) + row_acc = np.sum(mask, 1) + + col_acc = col_acc.tolist() + row_acc = row_acc.tolist() + + for x in range(len(col_acc)): + if col_acc[x] > 0: + left = x + break + + for x in range(1, len(col_acc)): + if col_acc[-x] > 0: + right = len(col_acc) - x + break + + for x in range(len(row_acc)): + if row_acc[x] > 0: + top = x + break + + for x in range(1, len(row_acc)): + if row_acc[-x] > 0: + bottom = len(row_acc[::-1]) - x + break + return [top, bottom, left, right] + + +def visualize_flow(flow): + h, w = flow.shape[:2] + hsv = np.zeros((h, w, 3), np.uint8) + mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + hsv[..., 2] = 255 + bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + bgr = bgr * 1.0 / 255 + return bgr.astype(np.float32) + + +def vis_joints(image, joints, color, show_text=True, confidence_threshold=0.1): + + part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), + (3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] + + abandon_idxs = [0, 1, 14, 15, 16, 17] + # draw joints + for i, joint in enumerate(joints): + if i in abandon_idxs: + continue + if joint[-1] > confidence_threshold: + + cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2) + if show_text: + cv2.putText(image, + str(i) + '[{:.2f}]'.format(joint[-1]), + (int(joint[0]), int(joint[1])), + cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) + # draw link + for pair in part_orders: + if joints[pair[0]][-1] > confidence_threshold and joints[ + pair[1]][-1] > confidence_threshold: + cv2.line(image, (int(joints[pair[0]][0]), int(joints[pair[0]][1])), + (int(joints[pair[1]][0]), int(joints[pair[1]][1])), color, + 2) + return image + + +def get_heatmap_cv(img, magn, max_flow_mag): + min_flow_mag = .5 + cv_magn = np.clip( + 255 * (magn - min_flow_mag) / (max_flow_mag - min_flow_mag + 1e-7), + a_min=0, + a_max=255).astype(np.uint8) + if img.dtype != np.uint8: + img = (255 * img).astype(np.uint8) + + heatmap_img = cv2.applyColorMap(cv_magn, cv2.COLORMAP_JET) + heatmap_img = heatmap_img[..., ::-1] + + h, w = magn.shape + img_alpha = np.ones((h, w), dtype=np.double)[:, :, None] + heatmap_alpha = np.clip( + magn / (max_flow_mag + 1e-7), a_min=1e-7, a_max=1)[:, :, None]**.7 + heatmap_alpha[heatmap_alpha < .2]**.5 + pm_hm = heatmap_img * heatmap_alpha + pm_img = img * img_alpha + cv_out = pm_hm + pm_img * (1 - heatmap_alpha) + cv_out = np.clip(cv_out, a_min=0, a_max=255).astype(np.uint8) + + return cv_out + + +def save_heatmap_cv(img, flow, supression=2): + + flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2) + flow_magn -= supression + flow_magn[flow_magn <= 0] = 0 + cv_out = get_heatmap_cv(img, flow_magn, np.max(flow_magn) * 1.3) + return cv_out + + +@numba.jit(nopython=True, parallel=False) +def bilinear_interp(x, y, v11, v12, v21, v22): + temp1 = (v11 * (1 - y) + v12 * y) * (1 - x) + temp2 = (v21 * (1 - y) + v22 * y) * x + result = temp1 + temp2 + return result + + +@numba.jit(nopython=True, parallel=False) +def image_warp_grid1(rDx, rDy, oriImg, transRatio, width_expand, + height_expand): + srcW = oriImg.shape[1] + srcH = oriImg.shape[0] + + newImg = oriImg.copy() + + for i in range(srcH): + for j in range(srcW): + _i = i + _j = j + + deltaX = rDx[_i, _j] + deltaY = rDy[_i, _j] + + nx = _j + deltaX * transRatio + ny = _i + deltaY * transRatio + + if nx >= srcW - width_expand - 1: + if nx > srcW - 1: + nx = srcW - 1 + + if ny >= srcH - height_expand - 1: + if ny > srcH - 1: + ny = srcH - 1 + + if nx < width_expand: + if nx < 0: + nx = 0 + + if ny < height_expand: + if ny < 0: + ny = 0 + + nxi = int(math.floor(nx)) + nyi = int(math.floor(ny)) + nxi1 = int(math.ceil(nx)) + nyi1 = int(math.ceil(ny)) + + for ll in range(3): + newImg[_i, _j, + ll] = bilinear_interp(ny - nyi, nx - nxi, + oriImg[nyi, nxi, + ll], oriImg[nyi, nxi1, ll], + oriImg[nyi1, nxi, + ll], oriImg[nyi1, nxi1, + ll]) + return newImg diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 717ff4dd..c16e256e 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -184,6 +184,7 @@ TASK_OUTPUTS = { Tasks.image_to_image_translation: [OutputKeys.OUTPUT_IMG], Tasks.image_style_transfer: [OutputKeys.OUTPUT_IMG], Tasks.image_portrait_stylization: [OutputKeys.OUTPUT_IMG], + Tasks.image_body_reshaping: [OutputKeys.OUTPUT_IMG], # live category recognition result for single video # { diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 7fa66b5f..c9a70d14 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -75,6 +75,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/nlp_bart_text-error-correction_chinese'), Tasks.image_captioning: (Pipelines.image_captioning, 'damo/ofa_image-caption_coco_large_en'), + Tasks.image_body_reshaping: (Pipelines.image_body_reshaping, + 'damo/cv_flow-based-body-reshaping_damo'), Tasks.image_portrait_stylization: (Pipelines.person_image_cartoon, 'damo/cv_unet_person-image-cartoon_compound-models'), diff --git a/modelscope/pipelines/cv/image_body_reshaping_pipeline.py b/modelscope/pipelines/cv/image_body_reshaping_pipeline.py new file mode 100644 index 00000000..c3600eb5 --- /dev/null +++ b/modelscope/pipelines/cv/image_body_reshaping_pipeline.py @@ -0,0 +1,40 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_body_reshaping, module_name=Pipelines.image_body_reshaping) +class ImageBodyReshapingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image body reshaping pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + logger.info('body reshaping model init done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_ndarray(input) + result = {'img': img} + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + output = self.model.inference(input['img']) + result = {'outputs': output} + return result + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_img = inputs['outputs'] + return {OutputKeys.OUTPUT_IMG: output_img} diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 5bc27c03..2331dc85 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -60,7 +60,7 @@ class CVTasks(object): image_to_image_generation = 'image-to-image-generation' image_style_transfer = 'image-style-transfer' image_portrait_stylization = 'image-portrait-stylization' - + image_body_reshaping = 'image-body-reshaping' image_embedding = 'image-embedding' product_retrieval_embedding = 'product-retrieval-embedding' diff --git a/requirements/cv.txt b/requirements/cv.txt index 5a2d7763..f907256d 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -13,6 +13,7 @@ ml_collections mmcls>=0.21.0 mmdet>=2.25.0 networkx>=2.5 +numba onnxruntime>=1.10 pai-easycv>=0.6.3.6 pandas diff --git a/tests/pipelines/test_image_body_reshaping.py b/tests/pipelines/test_image_body_reshaping.py new file mode 100644 index 00000000..e1955e94 --- /dev/null +++ b/tests/pipelines/test_image_body_reshaping.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os.path as osp +import unittest + +import cv2 + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageBodyReshapingTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_body_reshaping + self.model_id = 'damo/cv_flow-based-body-reshaping_damo' + self.test_image = 'data/test/images/image_body_reshaping.jpg' + + def pipeline_inference(self, pipeline: Pipeline, input_location: str): + result = pipeline(input_location) + if result is not None: + cv2.imwrite('result_bodyreshaping.png', + result[OutputKeys.OUTPUT_IMG]) + print( + f'Output written to {osp.abspath("result_body_reshaping.png")}' + ) + else: + raise Exception('Testing failed: invalid output') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id) + image_body_reshaping = pipeline( + Tasks.image_body_reshaping, model=model_dir) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + image_body_reshaping = pipeline( + Tasks.image_body_reshaping, model=self.model_id) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + image_body_reshaping = pipeline(Tasks.image_body_reshaping) + self.pipeline_inference(image_body_reshaping, self.test_image) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + self.compatibility_check() + + +if __name__ == '__main__': + unittest.main()