Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10217723
Title: add image_body_reshaping
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b2c1119e3d521cf2e583b1e85fc9c9afd1d44954b433135039a98050a730932d
size 1127557
@@ -43,6 +43,7 @@ class Models(object):
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'
    image_body_reshaping = 'image-body-reshaping'

    # EasyCV models
    yolox = 'YOLOX'
@@ -187,6 +188,7 @@ class Pipelines(object):
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'
    image_body_reshaping = 'flow-based-body-reshaping'

    # nlp tasks
    automatic_post_editing = 'automatic-post-editing'
@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .image_body_reshaping import ImageBodyReshaping
else:
    _import_structure = {'image_body_reshaping': ['ImageBodyReshaping']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
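This follows the lazy-import pattern used throughout modelscope: the submodule is only imported when the symbol is first resolved. Assuming the file sits at `modelscope/models/cv/image_body_reshaping/__init__.py` (a hypothetical path; the registration below suggests it), the deferred import is triggered like this:

```python
# Resolving the attribute goes through LazyImportModule, which only then
# imports image_body_reshaping.py (package path assumed for illustration).
from modelscope.models.cv.image_body_reshaping import ImageBodyReshaping
```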
@@ -0,0 +1,128 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import cv2
import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .model import FlowGenerator
from .person_info import PersonInfo
from .pose_estimator.body import Body
from .slim_utils import image_warp_grid1, resize_on_long_side

logger = get_logger()

__all__ = ['ImageBodyReshaping']


@MODELS.register_module(
    Tasks.image_body_reshaping, module_name=Models.image_body_reshaping)
class ImageBodyReshaping(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the image body reshaping model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.degree = 1.0
        self.reshape_model = FlowGenerator(n_channels=16).to(self.device)

        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        checkpoints = torch.load(model_path, map_location=torch.device('cpu'))
        self.reshape_model.load_state_dict(
            checkpoints['state_dict'], strict=True)
        self.reshape_model.eval()
        logger.info('load body reshaping model done')

        pose_model_ckpt = os.path.join(model_dir, 'body_pose_model.pth')
        self.pose_esti = Body(pose_model_ckpt, self.device)
        logger.info('load pose model done')

    def pred_joints(self, img):
        """Detect body joints on a downscaled copy and map them back."""
        if img is None:
            return None
        small_src, resize_scale = resize_on_long_side(img, 300)
        body_joints = self.pose_esti(small_src)
        if body_joints.shape[0] >= 1:
            body_joints[:, :, :2] = body_joints[:, :, :2] / resize_scale
        return body_joints

    def pred_flow(self, img):
        body_joints = self.pred_joints(img)

        small_size = 1200
        if img.shape[0] > small_size or img.shape[1] > small_size:
            _img, _scale = resize_on_long_side(img, small_size)
            body_joints[:, :, :2] = body_joints[:, :, :2] * _scale
        else:
            _img = img

        # We only reshape a single person.
        if body_joints.shape[0] != 1:
            return None

        person = PersonInfo(body_joints[0])

        with torch.no_grad():
            person_pred = person.pred_flow(_img, self.reshape_model,
                                           self.device)

        # Scale the predicted flow back to the original resolution.
        flow = np.dstack((person_pred['rDx'], person_pred['rDy']))
        scale = img.shape[0] * 1.0 / flow.shape[0]
        flow = cv2.resize(flow, (img.shape[1], img.shape[0]))
        flow *= scale
        return flow

    def warp(self, src_img, flow):
        X_flow = flow[..., 0]
        Y_flow = flow[..., 1]
        X_flow = np.ascontiguousarray(X_flow)
        Y_flow = np.ascontiguousarray(Y_flow)
        pred = image_warp_grid1(X_flow, Y_flow, src_img, 1.0, 0, 0)
        return pred

    def inference(self, img):
        img = img.cpu().numpy()
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        flow = self.pred_flow(img)
        if flow is None:
            return img
        assert flow.shape[:2] == img.shape[:2]

        # Suppress small flow magnitudes so static regions stay untouched.
        mag, ang = cv2.cartToPolar(flow[..., 0] + 1e-8, flow[..., 1] + 1e-8)
        mag -= 3
        mag[mag <= 0] = 0
        x, y = cv2.polarToCart(mag, ang, angleInDegrees=False)
        flow = np.dstack((x, y))
        flow *= self.degree
        pred = self.warp(img, flow)
        out_img = np.clip(pred, 0, 255)
        logger.info('model inference done')
        return out_img.astype(np.uint8)
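A minimal sketch of exercising this class directly, assuming a local model directory laid out as the loader above expects; the directory and image paths are hypothetical, and the pipeline added later in this change is the intended entry point:

```python
import cv2
import torch

# Hypothetical local copy of the model files (the Torch checkpoint plus
# body_pose_model.pth, as loaded in __init__ above).
model = ImageBodyReshaping(model_dir='./cv_flow-based-body-reshaping_damo')

# inference() takes a BGR image wrapped in a tensor, converts it to RGB
# internally, and returns a uint8 image of the same size.
bgr = cv2.imread('person.jpg')
out = model.inference(torch.from_numpy(bgr))
cv2.imwrite('reshaped.jpg', cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
```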
@@ -0,0 +1,189 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLayer(nn.Module):

    def __init__(self, in_ch, out_ch):
        super(ConvLayer, self).__init__()
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=0),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x


class SASA(nn.Module):

    def __init__(self, in_dim):
        super(SASA, self).__init__()
        self.chanel_in = in_dim
        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.mag_conv = nn.Conv2d(
            in_channels=5, out_channels=in_dim // 32, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()

    def structure_encoder(self, paf_mag, target_height, target_width):
        # Collapse the per-limb channels into coarse body-part masks.
        torso_mask = torch.sum(paf_mag[:, 1:3, :, :], dim=1, keepdim=True)
        torso_mask = torch.clamp(torso_mask, 0, 1)
        arms_mask = torch.sum(paf_mag[:, 4:8, :, :], dim=1, keepdim=True)
        arms_mask = torch.clamp(arms_mask, 0, 1)
        legs_mask = torch.sum(paf_mag[:, 8:12, :, :], dim=1, keepdim=True)
        legs_mask = torch.clamp(legs_mask, 0, 1)
        fg_mask = paf_mag[:, 12, :, :].unsqueeze(1)
        bg_mask = 1 - fg_mask
        Y = torch.cat((arms_mask, torso_mask, legs_mask, fg_mask, bg_mask),
                      dim=1)
        Y = F.interpolate(Y, size=(target_height, target_width), mode='area')
        return Y

    def forward(self, X, PAF_mag):
        """Extract self-attention features.

        Args:
            X: input feature maps (B x C x H x W).
            PAF_mag: (B x C x H x W), 1 denotes connectivity, 0 denotes
                non-connectivity.

        Returns:
            out: self-attention value + input feature.
            Y: downsampled structure masks (B x 5 x H x W).
        """
        m_batchsize, C, height, width = X.size()
        Y = self.structure_encoder(PAF_mag, height, width)
        connectivity_mask_vec = self.mag_conv(Y).view(m_batchsize, -1,
                                                      width * height)
        affinity = torch.bmm(
            connectivity_mask_vec.permute(0, 2, 1), connectivity_mask_vec)
        affinity_centered = affinity - torch.mean(affinity)
        affinity_sigmoid = self.sigmoid(affinity_centered)

        proj_query = self.query_conv(X).view(m_batchsize, -1,
                                             width * height).permute(0, 2, 1)
        proj_key = self.key_conv(X).view(m_batchsize, -1, width * height)
        selfatten_map = torch.bmm(proj_query, proj_key)
        selfatten_centered = selfatten_map - torch.mean(
            selfatten_map)  # centering
        selfatten_sigmoid = self.sigmoid(selfatten_centered)

        # Modulate self-attention with the structure-affinity prior.
        SASA_map = selfatten_sigmoid * affinity_sigmoid
        proj_value = self.value_conv(X).view(m_batchsize, -1, width * height)

        out = torch.bmm(proj_value, SASA_map.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + X
        return out, Y
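For shape intuition: the block flattens the feature map to `N = H*W` tokens and builds an `N x N` attention map gated by the structure affinity. A quick smoke test with random inputs (not part of the patch; the 13-channel PAF magnitude map matches what `structure_encoder` slices, and the 16x16 feature size mirrors a 256-pixel input after four poolings):

```python
import torch

sasa = SASA(in_dim=1024)
X = torch.rand(1, 1024, 16, 16)        # encoder features, N = 16 * 16 tokens
paf_mag = torch.rand(1, 13, 256, 256)  # dilated 13-channel skeleton map
out, Y = sasa(X, paf_mag)
print(out.shape, Y.shape)  # (1, 1024, 16, 16) and (1, 5, 16, 16)
```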
class FlowGenerator(nn.Module):

    def __init__(self, n_channels, deep_supervision=False):
        super(FlowGenerator, self).__init__()
        self.deep_supervision = deep_supervision
        self.Encoder = nn.Sequential(
            ConvLayer(n_channels, 64),
            ConvLayer(64, 64),
            nn.MaxPool2d(2),
            ConvLayer(64, 128),
            ConvLayer(128, 128),
            nn.MaxPool2d(2),
            ConvLayer(128, 256),
            ConvLayer(256, 256),
            nn.MaxPool2d(2),
            ConvLayer(256, 512),
            ConvLayer(512, 512),
            nn.MaxPool2d(2),
            ConvLayer(512, 1024),
            ConvLayer(1024, 1024),
            ConvLayer(1024, 1024),
            ConvLayer(1024, 1024),
            ConvLayer(1024, 1024),
        )
        self.SASA = SASA(in_dim=1024)
        self.Decoder = nn.Sequential(
            ConvLayer(1024, 1024),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ConvLayer(1024, 512),
            ConvLayer(512, 512),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ConvLayer(512, 256),
            ConvLayer(256, 256),
            ConvLayer(256, 128),
            ConvLayer(128, 64),
            ConvLayer(64, 32),
            nn.Conv2d(32, 2, kernel_size=1, padding=0),
            nn.Tanh(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
        )

        dilation_ksize = 17
        self.dilation = torch.nn.MaxPool2d(
            kernel_size=dilation_ksize,
            stride=1,
            padding=int((dilation_ksize - 1) / 2))

    def warp(self, x, flow, mode='bilinear', padding_mode='zeros', coff=0.2):
        n, c, h, w = x.size()
        # Build an identity sampling grid normalized to [-1, 1], then offset
        # it by the scaled flow.
        yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)])
        xv = xv.float() / (w - 1) * 2.0 - 1
        yv = yv.float() / (h - 1) * 2.0 - 1
        grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), -1).unsqueeze(0)
        grid = grid.to(flow.device)
        grid_x = grid + 2 * flow * coff
        warp_x = F.grid_sample(x, grid_x, mode=mode, padding_mode=padding_mode)
        return warp_x

    def forward(self, img, skeleton_map, coef=0.2):
        """Predict a deformation flow and warp the input image with it.

        Args:
            img: input image tensor.
            skeleton_map: skeleton map of the input image.
            coef: warp degree.

        Returns:
            warp_x: warped image.
            flow: predicted flow.
        """
        img_concat = torch.cat((img, skeleton_map), dim=1)
        X = self.Encoder(img_concat)
        _, _, height, width = X.size()

        # Directly get the PAF magnitude from skeleton maps via dilation.
        PAF_mag = self.dilation((skeleton_map + 1.0) * 0.5)
        out, Y = self.SASA(X, PAF_mag)
        flow = self.Decoder(out)
        flow = flow.permute(0, 2, 3, 1)  # [n, 2, h, w] ==> [n, h, w, 2]

        warp_x = self.warp(img, flow, coff=coef)
        warp_x = torch.clamp(warp_x, min=-1.0, max=1.0)
        return warp_x, flow
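The `warp` method relies on the `F.grid_sample` convention that a corner-aligned grid in [-1, 1] is the identity, so a zero flow reproduces the input. A self-contained check of that convention (not part of the patch; `align_corners=True` is set explicitly here to match the corner-aligned grid construction, whereas the patch leaves it at the framework default):

```python
import torch
import torch.nn.functional as F

n, c, h, w = 1, 3, 8, 8
x = torch.rand(n, c, h, w)

# Identity grid: pixel coordinates mapped to [-1, 1] with corners aligned.
yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)])
xv = xv.float() / (w - 1) * 2.0 - 1
yv = yv.float() / (h - 1) * 2.0 - 1
grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), -1).unsqueeze(0)

# With zero flow the warp reduces to the identity.
warped = F.grid_sample(x, grid, mode='bilinear', align_corners=True)
assert torch.allclose(warped, x, atol=1e-5)
```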
@@ -0,0 +1,339 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy

import cv2
import numpy as np
import torch

from .slim_utils import (enlarge_box_tblr, gen_skeleton_map,
                         get_map_fusion_map_cuda, get_mask_bbox,
                         resize_on_long_side)


class PersonInfo(object):

    def __init__(self, joints):
        self.joints = joints
        self.flow = None
        self.pad_boder = False
        self.height_expand = 0
        self.width_expand = 0
        self.coeff = 0.2
        self.network_input_W = 256
        self.network_input_H = 256
        self.divider = 20
        self.flow_scales = ['upper_2']

    def update_attribute(self, pad_boder, height_expand, width_expand):
        self.pad_boder = pad_boder
        self.height_expand = height_expand
        self.width_expand = width_expand
        if pad_boder:
            self.joints[:, 0] += width_expand
            self.joints[:, 1] += height_expand

    def pred_flow(self, img, flow_net, device):
        with torch.no_grad():
            if img is None:
                print('image is none')
                self.flow = None
                return None
            if len(img.shape) == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if self.pad_boder:
                height_expand = self.height_expand
                width_expand = self.width_expand
                pad_img = cv2.copyMakeBorder(
                    img,
                    height_expand,
                    height_expand,
                    width_expand,
                    width_expand,
                    cv2.BORDER_CONSTANT,
                    value=(127, 127, 127))
            else:
                height_expand = 0
                width_expand = 0
                pad_img = img.copy()

            canvas = np.zeros(
                shape=(pad_img.shape[0], pad_img.shape[1]), dtype=np.float32)

            self.human_joint_box = self.__joint_to_body_box()
            self.human_box = enlarge_box_tblr(
                self.human_joint_box, pad_img, ratio=0.25)
            human_box_height = self.human_box[1] - self.human_box[0]
            human_box_width = self.human_box[3] - self.human_box[2]

            self.leg_joint_box = self.__joint_to_leg_box()
            self.leg_box = enlarge_box_tblr(
                self.leg_joint_box, pad_img, ratio=0.25)

            self.arm_joint_box = self.__joint_to_arm_box()
            self.arm_box = enlarge_box_tblr(
                self.arm_joint_box, pad_img, ratio=0.1)

            x_flows = []
            y_flows = []
            multi_bbox = []
            for scale in self.flow_scales:  # better for metric
                scale_value = float(scale.split('_')[-1])
                arm_box = copy.deepcopy(self.arm_box)
                if arm_box[0] is None:
                    arm_box = self.human_box
                arm_box_height = arm_box[1] - arm_box[0]
                arm_box_width = arm_box[3] - arm_box[2]

                roi_bbox = None
                if arm_box_width < human_box_width * 0.1 or arm_box_height < human_box_height * 0.1:
                    roi_bbox = self.human_box
                else:
                    arm_box = enlarge_box_tblr(
                        arm_box, pad_img, ratio=scale_value)
                    if scale == 'upper_0.2':
                        arm_box[0] = min(arm_box[0], int(self.joints[0][1]))
                    if scale.startswith('upper'):
                        roi_bbox = [
                            max(self.human_box[0], arm_box[0]),
                            min(self.human_box[1], arm_box[1]),
                            max(self.human_box[2], arm_box[2]),
                            min(self.human_box[3], arm_box[3])
                        ]
                        if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[
                                3] - roi_bbox[2] < 1:
                            continue
                    elif scale.startswith('lower'):
                        roi_bbox = [
                            max(self.human_box[0], self.leg_box[0]),
                            min(self.human_box[1], self.leg_box[1]),
                            max(self.human_box[2], self.leg_box[2]),
                            min(self.human_box[3], self.leg_box[3])
                        ]
                        if roi_bbox[1] - roi_bbox[0] < 1 or roi_bbox[
                                3] - roi_bbox[2] < 1:
                            continue

                skel_map, roi_bbox = gen_skeleton_map(
                    self.joints, 'depth', input_roi_box=roi_bbox)
                if roi_bbox is None:
                    continue

                if skel_map.dtype != np.float32:
                    skel_map = skel_map.astype(np.float32)
                skel_map -= 1.0  # [0, 2] -> [-1, 1]

                multi_bbox.append(roi_bbox)
                roi_bbox_height = roi_bbox[1] - roi_bbox[0]
                roi_bbox_width = roi_bbox[3] - roi_bbox[2]
                assert skel_map.shape[0] == roi_bbox_height
                assert skel_map.shape[1] == roi_bbox_width

                roi_height_pad = roi_bbox_height // self.divider
                roi_width_pad = roi_bbox_width // self.divider
                paded_roi_h = roi_bbox_height + 2 * roi_height_pad
                paded_roi_w = roi_bbox_width + 2 * roi_width_pad

                roi_height_pad_joint = skel_map.shape[0] // self.divider
                roi_width_pad_joint = skel_map.shape[1] // self.divider
                skel_map = np.pad(
                    skel_map,
                    ((roi_height_pad_joint, roi_height_pad_joint),
                     (roi_width_pad_joint, roi_width_pad_joint), (0, 0)),
                    'constant',
                    constant_values=-1)
                skel_map_resized = cv2.resize(
                    skel_map, (self.network_input_W, self.network_input_H))
                skel_map_resized[skel_map_resized < 0] = -1.0
                skel_map_resized[skel_map_resized > -0.5] = 1.0
                skel_map_transformed = torch.from_numpy(
                    skel_map_resized.transpose((2, 0, 1)))

                roi_npy = pad_img[roi_bbox[0]:roi_bbox[1],
                                  roi_bbox[2]:roi_bbox[3], :].copy()
                if roi_npy.dtype != np.float32:
                    roi_npy = roi_npy.astype(np.float32)
                roi_npy = np.pad(roi_npy,
                                 ((roi_height_pad, roi_height_pad),
                                  (roi_width_pad, roi_width_pad), (0, 0)),
                                 'edge')

                roi_npy = roi_npy[:, :, ::-1]
                roi_npy = cv2.resize(
                    roi_npy, (self.network_input_W, self.network_input_H))

                # Normalize to [-1, 1].
                roi_npy *= 1.0 / 255
                roi_npy -= 0.5
                roi_npy *= 2
                rgb_tensor = torch.from_numpy(roi_npy.transpose((2, 0, 1)))
                rgb_tensor = rgb_tensor.unsqueeze(0).to(device)
                skel_map_tensor = skel_map_transformed.unsqueeze(0).to(device)

                warped_img_val, flow_field_val = flow_net(
                    rgb_tensor, skel_map_tensor
                )  # inference, connectivity_mask [1,12,16,16]

                flow_field_val = flow_field_val.detach().squeeze().cpu().numpy()
                flow_field_val = cv2.resize(
                    flow_field_val, (paded_roi_w, paded_roi_h),
                    interpolation=cv2.INTER_LINEAR)

                # Convert the normalized flow to pixel units.
                flow_field_val[..., 0] = flow_field_val[
                    ..., 0] * paded_roi_w * 0.5 * 2 * self.coeff
                flow_field_val[..., 1] = flow_field_val[
                    ..., 1] * paded_roi_h * 0.5 * 2 * self.coeff

                # Remove pad areas.
                flow_field_val = flow_field_val[
                    roi_height_pad:flow_field_val.shape[0] - roi_height_pad,
                    roi_width_pad:flow_field_val.shape[1] - roi_width_pad, :]

                diffuse_width = max(roi_bbox_width // 3, 1)
                diffuse_height = max(roi_bbox_height // 3, 1)

                assert roi_bbox_width == flow_field_val.shape[1]
                assert roi_bbox_height == flow_field_val.shape[0]

                origin_flow = np.zeros(
                    (pad_img.shape[0] + 2 * diffuse_height,
                     pad_img.shape[1] + 2 * diffuse_width, 2),
                    dtype=np.float32)

                # Ramp the flow down to zero around the ROI so it blends
                # smoothly into the untouched background.
                flow_field_val = np.pad(flow_field_val,
                                        ((diffuse_height, diffuse_height),
                                         (diffuse_width, diffuse_width),
                                         (0, 0)), 'linear_ramp')

                origin_flow[roi_bbox[0]:roi_bbox[1] + 2 * diffuse_height,
                            roi_bbox[2]:roi_bbox[3]
                            + 2 * diffuse_width] = flow_field_val

                origin_flow = origin_flow[diffuse_height:-diffuse_height,
                                          diffuse_width:-diffuse_width, :]
                x_flows.append(origin_flow[..., 0])
                y_flows.append(origin_flow[..., 1])

            if len(x_flows) == 0:
                return {
                    'rDx': np.zeros(canvas.shape[:2], dtype=np.float32),
                    'rDy': np.zeros(canvas.shape[:2], dtype=np.float32),
                    'multi_bbox': multi_bbox,
                    'x_fusion_map': np.ones(canvas.shape[:2], dtype=np.float32),
                    'y_fusion_map': np.ones(canvas.shape[:2], dtype=np.float32)
                }
            else:
                origin_rDx, origin_rDy, x_fusion_map, y_fusion_map = self.blend_multiscale_flow(
                    x_flows, y_flows, device=device)
                return {
                    'rDx': origin_rDx,
                    'rDy': origin_rDy,
                    'multi_bbox': multi_bbox,
                    'x_fusion_map': x_fusion_map,
                    'y_fusion_map': y_fusion_map
                }

    @staticmethod
    def blend_multiscale_flow(x_flows, y_flows, device=None):
        scale_num = len(x_flows)
        if scale_num == 1:
            return x_flows[0], y_flows[0], np.ones_like(
                x_flows[0]), np.ones_like(x_flows[0])

        origin_rDx = np.zeros((x_flows[0].shape[0], x_flows[0].shape[1]),
                              dtype=np.float32)
        origin_rDy = np.zeros((y_flows[0].shape[0], y_flows[0].shape[1]),
                              dtype=np.float32)

        x_fusion_map, x_acc_map = get_map_fusion_map_cuda(
            x_flows, 1, device=device)
        y_fusion_map, y_acc_map = get_map_fusion_map_cuda(
            y_flows, 1, device=device)

        # Per-pixel weights: dividing the summed flows by the contribution
        # count averages the overlapping scales.
        x_flow_map = 1.0 / x_fusion_map
        y_flow_map = 1.0 / y_fusion_map

        all_acc_map = x_acc_map + y_acc_map
        all_acc_map = all_acc_map.astype(np.uint8)
        roi_box = get_mask_bbox(all_acc_map, threshold=1)
        if roi_box[0] is None or roi_box[1] - roi_box[0] <= 0 or roi_box[
                3] - roi_box[2] <= 0:
            roi_box = [0, x_flow_map.shape[0], 0, x_flow_map.shape[1]]

        roi_x_flow_map = x_flow_map[roi_box[0]:roi_box[1],
                                    roi_box[2]:roi_box[3]]
        roi_y_flow_map = y_flow_map[roi_box[0]:roi_box[1],
                                    roi_box[2]:roi_box[3]]
        roi_width = roi_x_flow_map.shape[1]
        roi_height = roi_x_flow_map.shape[0]

        # Smooth the weight maps at low resolution to avoid visible seams.
        roi_x_flow_map, scale = resize_on_long_side(roi_x_flow_map, 320)
        roi_y_flow_map, scale = resize_on_long_side(roi_y_flow_map, 320)
        roi_x_flow_map = cv2.blur(roi_x_flow_map, (55, 55))
        roi_y_flow_map = cv2.blur(roi_y_flow_map, (55, 55))
        roi_x_flow_map = cv2.resize(roi_x_flow_map, (roi_width, roi_height))
        roi_y_flow_map = cv2.resize(roi_y_flow_map, (roi_width, roi_height))

        x_flow_map[roi_box[0]:roi_box[1],
                   roi_box[2]:roi_box[3]] = roi_x_flow_map
        y_flow_map[roi_box[0]:roi_box[1],
                   roi_box[2]:roi_box[3]] = roi_y_flow_map

        for i in range(scale_num):
            origin_rDx += x_flows[i]
            origin_rDy += y_flows[i]

        origin_rDx *= x_flow_map
        origin_rDy *= y_flow_map
        return origin_rDx, origin_rDy, x_flow_map, y_flow_map

    def __joint_to_body_box(self):
        joint_left = int(np.min(self.joints, axis=0)[0])
        joint_right = int(np.max(self.joints, axis=0)[0])
        joint_top = int(np.min(self.joints, axis=0)[1])
        joint_bottom = int(np.max(self.joints, axis=0)[1])
        return [joint_top, joint_bottom, joint_left, joint_right]

    def __joint_to_leg_box(self):
        leg_joints = self.joints[8:, :]
        if np.max(leg_joints, axis=0)[2] < 0.05:
            return [0, 0, 0, 0]
        joint_left = int(np.min(leg_joints, axis=0)[0])
        joint_right = int(np.max(leg_joints, axis=0)[0])
        joint_top = int(np.min(leg_joints, axis=0)[1])
        joint_bottom = int(np.max(leg_joints, axis=0)[1])
        return [joint_top, joint_bottom, joint_left, joint_right]

    def __joint_to_arm_box(self):
        arm_joints = self.joints[2:8, :]
        if np.max(arm_joints, axis=0)[2] < 0.05:
            return [0, 0, 0, 0]
        joint_left = int(np.min(arm_joints, axis=0)[0])
        joint_right = int(np.max(arm_joints, axis=0)[0])
        joint_top = int(np.min(arm_joints, axis=0)[1])
        joint_bottom = int(np.max(arm_joints, axis=0)[1])
        return [joint_top, joint_bottom, joint_left, joint_right]
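The multi-scale blend above effectively averages flows where scales overlap: `get_map_fusion_map_cuda` counts non-zero contributions per pixel, and the `1.0 / fusion_map` weighting divides the summed flow by that count. A toy numpy illustration of the idea (simplified; it ignores the blurring and ROI handling):

```python
import numpy as np

flows = [np.full((4, 4), 2.0, dtype=np.float32),
         np.zeros((4, 4), dtype=np.float32)]
flows[1][1:3, 1:3] = 4.0

# Count how many scales contribute a non-zero flow at each pixel.
count = sum((np.abs(f) > 0).astype(np.float32) for f in flows)
count[count < 0.5] = 1.5  # mirrors the fusion-map default for empty pixels

blended = sum(flows) / count
assert blended[1, 1] == 3.0  # overlapping pixels average 2.0 and 4.0
```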
@@ -0,0 +1,272 @@
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose.
import math

import cv2
import numpy as np
import torch
from scipy.ndimage.filters import gaussian_filter

from .model import BodyposeModel
from .util import pad_rightdown_corner, transfer


class Body(object):

    def __init__(self, model_path, device):
        self.model = BodyposeModel().to(device)
        model_dict = transfer(self.model, torch.load(model_path))
        self.model.load_state_dict(model_dict)
        self.model.eval()

    def __call__(self, oriImg):
        scale_search = [0.5]
        boxsize = 368
        stride = 8
        padValue = 128
        thre1 = 0.1
        thre2 = 0.05
        bodyparts = 18
        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))

        for m in range(len(multiplier)):
            scale = multiplier[m]
            imageToTest = cv2.resize(
                oriImg, (0, 0),
                fx=scale,
                fy=scale,
                interpolation=cv2.INTER_CUBIC)
            imageToTest_padded, pad = pad_rightdown_corner(
                imageToTest, stride, padValue)
            im = np.transpose(
                np.float32(imageToTest_padded[:, :, :, np.newaxis]),
                (3, 2, 0, 1)) / 256 - 0.5
            im = np.ascontiguousarray(im)

            data = torch.from_numpy(im).float()
            if torch.cuda.is_available():
                data = data.cuda()
            with torch.no_grad():
                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
            Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
            Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()

            # extract outputs, resize, and remove padding
            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2),
                                   (1, 2, 0))  # output 1 is heatmaps
            heatmap = cv2.resize(
                heatmap, (0, 0),
                fx=stride,
                fy=stride,
                interpolation=cv2.INTER_CUBIC)
            heatmap = heatmap[:imageToTest_padded.shape[0]
                              - pad[2], :imageToTest_padded.shape[1]
                              - pad[3], :]
            heatmap = cv2.resize(
                heatmap, (oriImg.shape[1], oriImg.shape[0]),
                interpolation=cv2.INTER_CUBIC)

            paf = np.transpose(np.squeeze(Mconv7_stage6_L1),
                               (1, 2, 0))  # output 0 is PAFs
            paf = cv2.resize(
                paf, (0, 0),
                fx=stride,
                fy=stride,
                interpolation=cv2.INTER_CUBIC)
            paf = paf[:imageToTest_padded.shape[0]
                      - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
            paf = cv2.resize(
                paf, (oriImg.shape[1], oriImg.shape[0]),
                interpolation=cv2.INTER_CUBIC)

            # Average predictions over all test scales.
            heatmap_avg += heatmap / len(multiplier)
            paf_avg += paf / len(multiplier)

        all_peaks = []
        peak_counter = 0

        for part in range(bodyparts):
            map_ori = heatmap_avg[:, :, part]
            one_heatmap = gaussian_filter(map_ori, sigma=3)

            map_left = np.zeros(one_heatmap.shape)
            map_left[1:, :] = one_heatmap[:-1, :]
            map_right = np.zeros(one_heatmap.shape)
            map_right[:-1, :] = one_heatmap[1:, :]
            map_up = np.zeros(one_heatmap.shape)
            map_up[:, 1:] = one_heatmap[:, :-1]
            map_down = np.zeros(one_heatmap.shape)
            map_down[:, :-1] = one_heatmap[:, 1:]

            # A peak is a local maximum above the confidence threshold.
            peaks_binary = np.logical_and.reduce(
                (one_heatmap >= map_left, one_heatmap >= map_right,
                 one_heatmap >= map_up, one_heatmap >= map_down,
                 one_heatmap > thre1))
            peaks = list(
                zip(np.nonzero(peaks_binary)[1],
                    np.nonzero(peaks_binary)[0]))  # note reverse
            peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks]
            peak_id = range(peak_counter, peak_counter + len(peaks))
            peaks_with_score_and_id = [
                peaks_with_score[i] + (peak_id[i], )
                for i in range(len(peak_id))
            ]
            all_peaks.append(peaks_with_score_and_id)
            peak_counter += len(peaks)

        # find connection in the specified sequence, center 29 is in the position 15
        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9],
                   [9, 10], [10, 11], [2, 12], [12, 13], [13, 14], [2, 1],
                   [1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
        # the middle joints heatmap correspondence
        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44],
                  [19, 20], [21, 22], [23, 24], [25, 26], [27, 28], [29, 30],
                  [47, 48], [49, 50], [53, 54], [51, 52], [55, 56], [37, 38],
                  [45, 46]]

        connection_all = []
        special_k = []
        mid_num = 10

        for k in range(len(mapIdx)):
            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
            candA = all_peaks[limbSeq[k][0] - 1]
            candB = all_peaks[limbSeq[k][1] - 1]
            nA = len(candA)
            nB = len(candB)
            if (nA != 0 and nB != 0):
                connection_candidate = []
                for i in range(nA):
                    for j in range(nB):
                        vec = np.subtract(candB[j][:2], candA[i][:2])
                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
                        norm = max(0.001, norm)
                        vec = np.divide(vec, norm)

                        # Sample the PAF along the segment joining the two
                        # candidate joints.
                        startend = list(
                            zip(
                                np.linspace(
                                    candA[i][0], candB[j][0], num=mid_num),
                                np.linspace(
                                    candA[i][1], candB[j][1], num=mid_num)))

                        vec_x = np.array([
                            score_mid[int(round(startend[item][1])),
                                      int(round(startend[item][0])), 0]
                            for item in range(len(startend))
                        ])
                        vec_y = np.array([
                            score_mid[int(round(startend[item][1])),
                                      int(round(startend[item][0])), 1]
                            for item in range(len(startend))
                        ])

                        score_midpts = np.multiply(
                            vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                        temp1 = sum(score_midpts) / len(score_midpts)
                        temp2 = min(0.5 * oriImg.shape[0] / norm - 1, 0)
                        score_with_dist_prior = temp1 + temp2

                        criterion1 = len(np.nonzero(
                            score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
                        criterion2 = score_with_dist_prior > 0
                        if criterion1 and criterion2:
                            connection_candidate.append([
                                i, j, score_with_dist_prior,
                                score_with_dist_prior + candA[i][2]
                                + candB[j][2]
                            ])

                connection_candidate = sorted(
                    connection_candidate, key=lambda x: x[2], reverse=True)
                connection = np.zeros((0, 5))
                for c in range(len(connection_candidate)):
                    i, j, s = connection_candidate[c][0:3]
                    if (i not in connection[:, 3]
                            and j not in connection[:, 4]):
                        connection = np.vstack(
                            [connection, [candA[i][3], candB[j][3], s, i, j]])
                        if (len(connection) >= min(nA, nB)):
                            break
                connection_all.append(connection)
            else:
                special_k.append(k)
                connection_all.append([])

        # last number in each row is the total parts number of that person
        # the second last number in each row is the score of the overall configuration
        subset = -1 * np.ones((0, 20))
        candidate = np.array(
            [item for sublist in all_peaks for item in sublist])

        for k in range(len(mapIdx)):
            if k not in special_k:
                partAs = connection_all[k][:, 0]
                partBs = connection_all[k][:, 1]
                indexA, indexB = np.array(limbSeq[k]) - 1

                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
                    found = 0
                    subset_idx = [-1, -1]
                    for j in range(len(subset)):  # 1:size(subset,1):
                        if subset[j][indexA] == partAs[i] or subset[j][
                                indexB] == partBs[i]:
                            subset_idx[found] = j
                            found += 1

                    if found == 1:
                        j = subset_idx[0]
                        if subset[j][indexB] != partBs[i]:
                            subset[j][indexB] = partBs[i]
                            subset[j][-1] += 1
                            subset[j][-2] += candidate[
                                partBs[i].astype(int),
                                2] + connection_all[k][i][2]
                    elif found == 2:  # if found 2 and disjoint, merge them
                        j1, j2 = subset_idx
                        tmp1 = (subset[j1] >= 0).astype(int)
                        tmp2 = (subset[j2] >= 0).astype(int)
                        membership = (tmp1 + tmp2)[:-2]
                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
                            subset[j1][:-2] += (subset[j2][:-2] + 1)
                            subset[j1][-2:] += subset[j2][-2:]
                            subset[j1][-2] += connection_all[k][i][2]
                            subset = np.delete(subset, j2, 0)
                        else:  # as like found == 1
                            subset[j1][indexB] = partBs[i]
                            subset[j1][-1] += 1
                            subset[j1][-2] += candidate[
                                partBs[i].astype(int),
                                2] + connection_all[k][i][2]

                    # if find no partA in the subset, create a new subset
                    elif not found and k < 17:
                        row = -1 * np.ones(20)
                        row[indexA] = partAs[i]
                        row[indexB] = partBs[i]
                        row[-1] = 2
                        row[-2] = sum(
                            candidate[connection_all[k][i, :2].astype(int),
                                      2]) + connection_all[k][i][2]
                        subset = np.vstack([subset, row])

        # delete some rows of subset which have few parts
        deleteIdx = []
        for i in range(len(subset)):
            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
                deleteIdx.append(i)
        subset = np.delete(subset, deleteIdx, axis=0)

        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
        # candidate: x, y, score, id
        count = subset.shape[0]
        joints = np.zeros(shape=(count, bodyparts, 3))
        for i in range(count):
            for j in range(bodyparts):
                joints[i, j, :3] = candidate[int(subset[i, j]), :3]
                confidence = 1.0 if subset[i, j] >= 0 else 0.0
                joints[i, j, 2] *= confidence
        return joints
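The detector returns one `(18, 3)` row of `(x, y, confidence)` joints per detected person, indexed in the OpenPose body-18 order. A hedged usage sketch (the checkpoint and image paths are hypothetical):

```python
import cv2
import torch

detector = Body('body_pose_model.pth', torch.device('cpu'))
joints = detector(cv2.imread('person.jpg'))  # shape: (num_people, 18, 3)
for person in joints:
    # Index 1 is the neck in the body-18 keypoint ordering.
    neck_x, neck_y, neck_conf = person[1]
    print(neck_x, neck_y, neck_conf)
```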
@@ -0,0 +1,141 @@
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose.
from collections import OrderedDict

import torch
import torch.nn as nn


def make_layers(block, no_relu_layers):
    layers = []
    for layer_name, v in block.items():
        if 'pool' in layer_name:
            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
            layers.append((layer_name, layer))
        else:
            conv2d = nn.Conv2d(
                in_channels=v[0],
                out_channels=v[1],
                kernel_size=v[2],
                stride=v[3],
                padding=v[4])
            layers.append((layer_name, conv2d))
            if layer_name not in no_relu_layers:
                layers.append(('relu_' + layer_name, nn.ReLU(inplace=True)))
    return nn.Sequential(OrderedDict(layers))


class BodyposeModel(nn.Module):

    def __init__(self):
        super(BodyposeModel, self).__init__()

        # these layers have no relu layer
        no_relu_layers = [
            'conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',
            'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',
            'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',
            'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2'
        ]
        blocks = {}
        block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]),
                              ('conv1_2', [64, 64, 3, 1, 1]),
                              ('pool1_stage1', [2, 2, 0]),
                              ('conv2_1', [64, 128, 3, 1, 1]),
                              ('conv2_2', [128, 128, 3, 1, 1]),
                              ('pool2_stage1', [2, 2, 0]),
                              ('conv3_1', [128, 256, 3, 1, 1]),
                              ('conv3_2', [256, 256, 3, 1, 1]),
                              ('conv3_3', [256, 256, 3, 1, 1]),
                              ('conv3_4', [256, 256, 3, 1, 1]),
                              ('pool3_stage1', [2, 2, 0]),
                              ('conv4_1', [256, 512, 3, 1, 1]),
                              ('conv4_2', [512, 512, 3, 1, 1]),
                              ('conv4_3_CPM', [512, 256, 3, 1, 1]),
                              ('conv4_4_CPM', [256, 128, 3, 1, 1])])

        # Stage 1
        block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
                                ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
                                ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
                                ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
                                ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])])

        block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
                                ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
                                ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
                                ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
                                ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])])
        blocks['block1_1'] = block1_1
        blocks['block1_2'] = block1_2

        self.model0 = make_layers(block0, no_relu_layers)

        # Stages 2 - 6
        for i in range(2, 7):
            blocks['block%d_1' % i] = OrderedDict([
                ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
            ])

            blocks['block%d_2' % i] = OrderedDict([
                ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
                ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
                ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
                ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
            ])

        for k in blocks.keys():
            blocks[k] = make_layers(blocks[k], no_relu_layers)

        self.model1_1 = blocks['block1_1']
        self.model2_1 = blocks['block2_1']
        self.model3_1 = blocks['block3_1']
        self.model4_1 = blocks['block4_1']
        self.model5_1 = blocks['block5_1']
        self.model6_1 = blocks['block6_1']

        self.model1_2 = blocks['block1_2']
        self.model2_2 = blocks['block2_2']
        self.model3_2 = blocks['block3_2']
        self.model4_2 = blocks['block4_2']
        self.model5_2 = blocks['block5_2']
        self.model6_2 = blocks['block6_2']

    def forward(self, x):
        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)
        out2 = torch.cat([out1_1, out1_2, out1], 1)

        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)
        out3 = torch.cat([out2_1, out2_2, out1], 1)

        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)
        out4 = torch.cat([out3_1, out3_2, out1], 1)

        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)
        out5 = torch.cat([out4_1, out4_2, out1], 1)

        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)
        out6 = torch.cat([out5_1, out5_2, out1], 1)

        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)

        return out6_1, out6_2
@@ -0,0 +1,33 @@
# The implementation is based on openpose, available at https://github.com/Hzzone/pytorch-openpose.
import numpy as np


def pad_rightdown_corner(img, stride, padValue):
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad


def transfer(model, model_weights):
    # Strip the first name component (e.g. a wrapping module prefix) so the
    # saved weights map onto this model's state dict keys.
    transfered_model_weights = {}
    for weights_name in model.state_dict().keys():
        transfered_model_weights[weights_name] = model_weights['.'.join(
            weights_name.split('.')[1:])]
    return transfered_model_weights
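`pad_rightdown_corner` only pads the bottom and right edges so both dimensions become multiples of `stride`; `pad` records the four margins for later cropping. A quick illustration:

```python
import numpy as np

img = np.zeros((367, 250, 3), dtype=np.uint8)
padded, pad = pad_rightdown_corner(img, stride=8, padValue=128)
print(padded.shape)  # (368, 256, 3): 367 -> 368, 250 -> 256
print(pad)           # [0, 0, 1, 6]: up, left, down, right margins
```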
| @@ -0,0 +1,507 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import math | |||||
| import os | |||||
| import random | |||||
| import cv2 | |||||
| import numba | |||||
| import numpy as np | |||||
| import torch | |||||
| def resize_on_long_side(img, long_side=800): | |||||
| src_height = img.shape[0] | |||||
| src_width = img.shape[1] | |||||
| if src_height > src_width: | |||||
| scale = long_side * 1.0 / src_height | |||||
| _img = cv2.resize( | |||||
| img, (int(src_width * scale), long_side), | |||||
| interpolation=cv2.INTER_LINEAR) | |||||
| else: | |||||
| scale = long_side * 1.0 / src_width | |||||
| _img = cv2.resize( | |||||
| img, (long_side, int(src_height * scale)), | |||||
| interpolation=cv2.INTER_LINEAR) | |||||
| return _img, scale | |||||
| def point_in_box(pt, box): | |||||
| pt_x = pt[0] | |||||
| pt_y = pt[1] | |||||
| if pt_x >= box[0] and pt_x <= box[0] + box[2] and pt_y >= box[ | |||||
| 1] and pt_y <= box[1] + box[3]: | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| def enlarge_box_tblr(roi_bbox, mask, ratio=0.4, use_long_side=True): | |||||
| if roi_bbox is None or None in roi_bbox: | |||||
| return [None, None, None, None] | |||||
| top = roi_bbox[0] | |||||
| bottom = roi_bbox[1] | |||||
| left = roi_bbox[2] | |||||
| right = roi_bbox[3] | |||||
| roi_width = roi_bbox[3] - roi_bbox[2] | |||||
| roi_height = roi_bbox[1] - roi_bbox[0] | |||||
| right = left + roi_width | |||||
| bottom = top + roi_height | |||||
| long_side = roi_width if roi_width > roi_height else roi_height | |||||
| if use_long_side: | |||||
| new_left = left - int(long_side * ratio) | |||||
| else: | |||||
| new_left = left - int(roi_width * ratio) | |||||
| new_left = 1 if new_left < 0 else new_left | |||||
| if use_long_side: | |||||
| new_top = top - int(long_side * ratio) | |||||
| else: | |||||
| new_top = top - int(roi_height * ratio) | |||||
| new_top = 1 if new_top < 0 else new_top | |||||
| if use_long_side: | |||||
| new_right = right + int(long_side * ratio) | |||||
| else: | |||||
| new_right = right + int(roi_width * ratio) | |||||
| new_right = mask.shape[1] - 2 if new_right > mask.shape[1] else new_right | |||||
| if use_long_side: | |||||
| new_bottom = bottom + int(long_side * ratio) | |||||
| else: | |||||
| new_bottom = bottom + int(roi_height * ratio) | |||||
| new_bottom = mask.shape[0] - 2 if new_bottom > mask.shape[0] else new_bottom | |||||
| bbox = [new_top, new_bottom, new_left, new_right] | |||||
| return bbox | |||||
| def gen_PAF(image, joints): | |||||
| assert joints.shape[0] == 18 | |||||
| assert joints.shape[1] == 3 | |||||
| org_h = image.shape[0] | |||||
| org_w = image.shape[1] | |||||
| small_image, resize_scale = resize_on_long_side(image, 120) | |||||
| joints[:, :2] = joints[:, :2] * resize_scale | |||||
| joint_left = int(np.min(joints, axis=0)[0]) | |||||
| joint_right = int(np.max(joints, axis=0)[0]) | |||||
| joint_top = int(np.min(joints, axis=0)[1]) | |||||
| joint_bottom = int(np.max(joints, axis=0)[1]) | |||||
| limb_width = min( | |||||
| abs(joint_right - joint_left), abs(joint_bottom - joint_top)) // 6 | |||||
| if limb_width % 2 == 0: | |||||
| limb_width += 1 | |||||
| kernel_size = limb_width | |||||
| part_orders = [(5, 11), (2, 8), (5, 6), (6, 7), (2, 3), (3, 4), (11, 12), | |||||
| (12, 13), (8, 9), (9, 10)] | |||||
| map_list = [] | |||||
| mask_list = [] | |||||
| PAF_all = np.zeros( | |||||
| shape=(small_image.shape[0], small_image.shape[1], 2), | |||||
| dtype=np.float32) | |||||
| for c, pair in enumerate(part_orders): | |||||
| idx_a_name = pair[0] | |||||
| idx_b_name = pair[1] | |||||
| jointa = joints[idx_a_name] | |||||
| jointb = joints[idx_b_name] | |||||
| confidence_threshold = 0.05 | |||||
| if jointa[2] > confidence_threshold and jointb[ | |||||
| 2] > confidence_threshold: | |||||
| canvas = np.zeros( | |||||
| shape=(small_image.shape[0], small_image.shape[1]), | |||||
| dtype=np.uint8) | |||||
| canvas = cv2.line(canvas, (int(jointa[0]), int(jointa[1])), | |||||
| (int(jointb[0]), int(jointb[1])), | |||||
| (255, 255, 255), 5) | |||||
| kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, | |||||
| (kernel_size, kernel_size)) | |||||
| canvas = cv2.dilate(canvas, kernel, 1) | |||||
| canvas = cv2.GaussianBlur(canvas, (kernel_size, kernel_size), 0) | |||||
| canvas = canvas.astype(np.float32) / 255 | |||||
| PAF = np.zeros( | |||||
| shape=(small_image.shape[0], small_image.shape[1], 2), | |||||
| dtype=np.float32) | |||||
| PAF[..., 0] = jointb[0] - jointa[0] | |||||
| PAF[..., 1] = jointb[1] - jointa[1] | |||||
| mag, ang = cv2.cartToPolar(PAF[..., 0], PAF[..., 1]) | |||||
| PAF /= (np.dstack((mag, mag)) + 1e-5) | |||||
| single_PAF = PAF * np.dstack((canvas, canvas)) | |||||
| map_list.append( | |||||
| cv2.GaussianBlur(single_PAF, | |||||
| (kernel_size * 3, kernel_size * 3), 0)) | |||||
| mask_list.append( | |||||
| cv2.GaussianBlur(canvas.copy(), | |||||
| (kernel_size * 3, kernel_size * 3), 0)) | |||||
| PAF_all = PAF_all * (1.0 - np.dstack( | |||||
| (canvas, canvas))) + single_PAF | |||||
| PAF_all = cv2.GaussianBlur(PAF_all, (kernel_size * 3, kernel_size * 3), 0) | |||||
| PAF_all = cv2.resize( | |||||
| PAF_all, (org_w, org_h), interpolation=cv2.INTER_LINEAR) | |||||
| map_list.append(PAF_all) | |||||
| return PAF_all, map_list, mask_list | |||||
| def gen_skeleton_map(joints, stack_mode='column', input_roi_box=None): | |||||
| if type(joints) == list: | |||||
| joints = np.array(joints) | |||||
| assert stack_mode == 'column' or stack_mode == 'depth' | |||||
| part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), | |||||
| (3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] | |||||
| def link(img, a, b, color, line_width, scale=1.0, x_offset=0, y_offset=0): | |||||
| jointa = joints[a] | |||||
| jointb = joints[b] | |||||
| temp1 = int((jointa[0] - x_offset) * scale) | |||||
| temp2 = int((jointa[1] - y_offset) * scale) | |||||
| temp3 = int((jointb[0] - x_offset) * scale) | |||||
| temp4 = int((jointb[1] - y_offset) * scale) | |||||
| cv2.line(img, (temp1, temp2), (temp3, temp4), color, line_width) | |||||
| roi_box = input_roi_box | |||||
| roi_box_width = roi_box[3] - roi_box[2] | |||||
| roi_box_height = roi_box[1] - roi_box[0] | |||||
| short_side_length = min(roi_box_width, roi_box_height) | |||||
| line_width = short_side_length // 30 | |||||
| line_width = max(line_width, 2) | |||||
| map_cube = np.zeros( | |||||
| shape=(roi_box_height, roi_box_width, len(part_orders) + 1), | |||||
| dtype=np.float32) | |||||
| use_line_width = min(5, line_width) | |||||
| fx = use_line_width * 1.0 / line_width # fx 最大值为1 | |||||
| if fx < 0.99: | |||||
| map_cube = cv2.resize(map_cube, (0, 0), fx=fx, fy=fx) | |||||
| for c, pair in enumerate(part_orders): | |||||
| tmp = map_cube[..., c].copy() | |||||
| link( | |||||
| tmp, | |||||
| pair[0], | |||||
| pair[1], (2.0, 2.0, 2.0), | |||||
| use_line_width, | |||||
| scale=fx, | |||||
| x_offset=roi_box[2], | |||||
| y_offset=roi_box[0]) | |||||
| map_cube[..., c] = tmp | |||||
| tmp = map_cube[..., -1].copy() | |||||
| link( | |||||
| tmp, | |||||
| pair[0], | |||||
| pair[1], (2.0, 2.0, 2.0), | |||||
| use_line_width, | |||||
| scale=fx, | |||||
| x_offset=roi_box[2], | |||||
| y_offset=roi_box[0]) | |||||
| map_cube[..., -1] = tmp | |||||
| map_cube = cv2.resize(map_cube, (roi_box_width, roi_box_height)) | |||||
| if stack_mode == 'depth': | |||||
| return map_cube, roi_box | |||||
| elif stack_mode == 'column': | |||||
| joint_maps = [] | |||||
| for c in range(len(part_orders) + 1): | |||||
| joint_maps.append(map_cube[..., c]) | |||||
| joint_map = np.column_stack(joint_maps) | |||||
| return joint_map, roi_box | |||||
| def plot_one_box(x, img, color=None, label=None, line_thickness=None): | |||||
| tl = line_thickness or round( | |||||
| 0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness | |||||
| color = color or [random.randint(0, 255) for _ in range(3)] | |||||
| c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3])) | |||||
| cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA) | |||||
| if label: | |||||
| tf = max(tl - 1, 1) # font thickness | |||||
| t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] | |||||
| c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 | |||||
| cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled | |||||
| cv2.putText( | |||||
| img, | |||||
| label, (c1[0], c1[1] - 2), | |||||
| 0, | |||||
| tl / 3, [225, 255, 255], | |||||
| thickness=tf, | |||||
| lineType=cv2.LINE_AA) | |||||
| def draw_line(im, points, color, stroke_size=2, closed=False): | |||||
| points = points.astype(np.int32) | |||||
| for i in range(len(points) - 1): | |||||
| cv2.line(im, tuple(points[i]), tuple(points[i + 1]), color, | |||||
| stroke_size) | |||||
| if closed: | |||||
| cv2.line(im, tuple(points[0]), tuple(points[-1]), color, stroke_size) | |||||
| def enlarged_bbox(bbox, img_width, img_height, enlarge_ratio=0.2): | |||||
| left = bbox[0] | |||||
| top = bbox[1] | |||||
| right = bbox[2] | |||||
| bottom = bbox[3] | |||||
| roi_width = right - left | |||||
| roi_height = bottom - top | |||||
| new_left = left - int(roi_width * enlarge_ratio) | |||||
| new_left = 0 if new_left < 0 else new_left | |||||
| new_top = top - int(roi_height * enlarge_ratio) | |||||
| new_top = 0 if new_top < 0 else new_top | |||||
| new_right = right + int(roi_width * enlarge_ratio) | |||||
| new_right = img_width if new_right > img_width else new_right | |||||
| new_bottom = bottom + int(roi_height * enlarge_ratio) | |||||
| new_bottom = img_height if new_bottom > img_height else new_bottom | |||||
| bbox = [new_left, new_top, new_right, new_bottom] | |||||
| bbox = [int(x) for x in bbox] | |||||
| return bbox | |||||
| def get_map_fusion_map_cuda(map_list, threshold=1, device=torch.device('cpu')): | |||||
| map_list_cuda = [torch.from_numpy(x).to(device) for x in map_list] | |||||
| map_concat = torch.stack(tuple(map_list_cuda), dim=-1) | |||||
| map_concat = torch.abs(map_concat) | |||||
| map_concat[map_concat < threshold] = 0 | |||||
| map_concat[map_concat > 1e-5] = 1.0 | |||||
| sum_map = torch.sum(map_concat, dim=2) | |||||
| a = torch.ones_like(sum_map) | |||||
| acc_map = torch.where(sum_map > 0, a * 2.0, torch.zeros_like(sum_map)) | |||||
| fusion_map = torch.where(sum_map < 0.5, a * 1.5, sum_map) | |||||
| fusion_map = fusion_map.float() | |||||
| acc_map = acc_map.float() | |||||
| fusion_map = fusion_map.cpu().numpy().astype(np.float32) | |||||
| acc_map = acc_map.cpu().numpy().astype(np.float32) | |||||
| return fusion_map, acc_map | |||||
| def gen_border_shade(height, width, height_band, width_band): | |||||
| height_ratio = height_band * 1.0 / height | |||||
| width_ratio = width_band * 1.0 / width | |||||
| _height_band = int(256 * height_ratio) | |||||
| _width_band = int(256 * width_ratio) | |||||
| canvas = np.zeros((256, 256), dtype=np.float32) | |||||
| canvas[_height_band // 2:-_height_band // 2, | |||||
| _width_band // 2:-_width_band // 2] = 1.0 | |||||
| canvas = cv2.blur(canvas, (_height_band, _width_band)) | |||||
| canvas = cv2.resize(canvas, (width, height)) | |||||
| return canvas | |||||
| def get_mask_bbox(mask, threshold=127): | |||||
| ret, mask = cv2.threshold(mask, threshold, 1, 0) | |||||
| if cv2.countNonZero(mask) == 0: | |||||
| return [None, None, None, None] | |||||
| col_acc = np.sum(mask, 0) | |||||
| row_acc = np.sum(mask, 1) | |||||
| col_acc = col_acc.tolist() | |||||
| row_acc = row_acc.tolist() | |||||
| for x in range(len(col_acc)): | |||||
| if col_acc[x] > 0: | |||||
| left = x | |||||
| break | |||||
| for x in range(1, len(col_acc)): | |||||
| if col_acc[-x] > 0: | |||||
| right = len(col_acc) - x | |||||
| break | |||||
| for x in range(len(row_acc)): | |||||
| if row_acc[x] > 0: | |||||
| top = x | |||||
| break | |||||
| for x in range(1, len(row_acc)): | |||||
| if row_acc[-x] > 0: | |||||
| bottom = len(row_acc[::-1]) - x | |||||
| break | |||||
| return [top, bottom, left, right] | |||||
| def visualize_flow(flow): | |||||
| h, w = flow.shape[:2] | |||||
| hsv = np.zeros((h, w, 3), np.uint8) | |||||
| mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) | |||||
| hsv[..., 0] = ang * 180 / np.pi / 2 | |||||
| hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) | |||||
| hsv[..., 2] = 255 | |||||
| bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) | |||||
| bgr = bgr * 1.0 / 255 | |||||
| return bgr.astype(np.float32) | |||||
| def vis_joints(image, joints, color, show_text=True, confidence_threshold=0.1): | |||||
| part_orders = [(2, 5), (5, 11), (2, 8), (8, 11), (5, 6), (6, 7), (2, 3), | |||||
| (3, 4), (11, 12), (12, 13), (8, 9), (9, 10)] | |||||
| abandon_idxs = [0, 1, 14, 15, 16, 17] | |||||
| # draw joints | |||||
| for i, joint in enumerate(joints): | |||||
| if i in abandon_idxs: | |||||
| continue | |||||
| if joint[-1] > confidence_threshold: | |||||
| cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2) | |||||
| if show_text: | |||||
| cv2.putText(image, | |||||
| str(i) + '[{:.2f}]'.format(joint[-1]), | |||||
| (int(joint[0]), int(joint[1])), | |||||
| cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) | |||||
| # draw link | |||||
| for pair in part_orders: | |||||
| if joints[pair[0]][-1] > confidence_threshold and joints[ | |||||
| pair[1]][-1] > confidence_threshold: | |||||
| cv2.line(image, (int(joints[pair[0]][0]), int(joints[pair[0]][1])), | |||||
| (int(joints[pair[1]][0]), int(joints[pair[1]][1])), color, | |||||
| 2) | |||||
| return image | |||||
| def get_heatmap_cv(img, magn, max_flow_mag): | |||||
| min_flow_mag = .5 | |||||
| cv_magn = np.clip( | |||||
| 255 * (magn - min_flow_mag) / (max_flow_mag - min_flow_mag + 1e-7), | |||||
| a_min=0, | |||||
| a_max=255).astype(np.uint8) | |||||
| if img.dtype != np.uint8: | |||||
| img = (255 * img).astype(np.uint8) | |||||
| heatmap_img = cv2.applyColorMap(cv_magn, cv2.COLORMAP_JET) | |||||
| heatmap_img = heatmap_img[..., ::-1] | |||||
| h, w = magn.shape | |||||
| img_alpha = np.ones((h, w), dtype=np.double)[:, :, None] | |||||
| heatmap_alpha = np.clip( | |||||
| magn / (max_flow_mag + 1e-7), a_min=1e-7, a_max=1)[:, :, None]**.7 | |||||
| heatmap_alpha[heatmap_alpha < .2]**.5 | |||||
| pm_hm = heatmap_img * heatmap_alpha | |||||
| pm_img = img * img_alpha | |||||
| cv_out = pm_hm + pm_img * (1 - heatmap_alpha) | |||||
| cv_out = np.clip(cv_out, a_min=0, a_max=255).astype(np.uint8) | |||||
| return cv_out | |||||
| def save_heatmap_cv(img, flow, supression=2): | |||||
| flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2) | |||||
| flow_magn -= supression | |||||
| flow_magn[flow_magn <= 0] = 0 | |||||
| cv_out = get_heatmap_cv(img, flow_magn, np.max(flow_magn) * 1.3) | |||||
| return cv_out | |||||
| @numba.jit(nopython=True, parallel=False) | |||||
| def bilinear_interp(x, y, v11, v12, v21, v22): | |||||
| temp1 = (v11 * (1 - y) + v12 * y) * (1 - x) | |||||
| temp2 = (v21 * (1 - y) + v22 * y) * x | |||||
| result = temp1 + temp2 | |||||
| return result | |||||
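A quick worked check of the interpolation arithmetic; note that x is the row fraction and y the column fraction, matching the call site in image_warp_grid1 below:

    # corners: v11 = f(0,0), v12 = f(0,1), v21 = f(1,0), v22 = f(1,1)
    # with v21 = v22 = 1 and v11 = v12 = 0, the result reduces to the x fraction
    assert bilinear_interp(0.25, 0.5, 0.0, 0.0, 1.0, 1.0) == 0.25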
| @numba.jit(nopython=True, parallel=False) | |||||
| def image_warp_grid1(rDx, rDy, oriImg, transRatio, width_expand, | |||||
| height_expand): | |||||
| srcW = oriImg.shape[1] | |||||
| srcH = oriImg.shape[0] | |||||
| newImg = oriImg.copy() | |||||
| for i in range(srcH): | |||||
| for j in range(srcW): | |||||
| _i = i | |||||
| _j = j | |||||
| deltaX = rDx[_i, _j] | |||||
| deltaY = rDy[_i, _j] | |||||
| nx = _j + deltaX * transRatio | |||||
| ny = _i + deltaY * transRatio | |||||
| if nx >= srcW - width_expand - 1: | |||||
| if nx > srcW - 1: | |||||
| nx = srcW - 1 | |||||
| if ny >= srcH - height_expand - 1: | |||||
| if ny > srcH - 1: | |||||
| ny = srcH - 1 | |||||
| if nx < width_expand: | |||||
| if nx < 0: | |||||
| nx = 0 | |||||
| if ny < height_expand: | |||||
| if ny < 0: | |||||
| ny = 0 | |||||
| nxi = int(math.floor(nx)) | |||||
| nyi = int(math.floor(ny)) | |||||
| nxi1 = int(math.ceil(nx)) | |||||
| nyi1 = int(math.ceil(ny)) | |||||
| for ll in range(3): | |||||
| newImg[_i, _j, ll] = bilinear_interp( | |||||
| ny - nyi, nx - nxi, oriImg[nyi, nxi, ll], | |||||
| oriImg[nyi, nxi1, ll], oriImg[nyi1, nxi, ll], | |||||
| oriImg[nyi1, nxi1, ll]) | |||||
| return newImg | |||||
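Design note: the jitted double loop is backward warping with bilinear sampling, which cv2.remap also provides. A rough vectorized equivalent; the band-edge clamping above is only approximated by BORDER_REPLICATE, so this is a sketch, not a drop-in replacement:

    import cv2
    import numpy as np

    def image_warp_remap(rDx, rDy, ori_img, trans_ratio):
        # backward warp: sample ori_img at (y + rDy * ratio, x + rDx * ratio)
        h, w = ori_img.shape[:2]
        grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
        map_x = (grid_x + rDx * trans_ratio).astype(np.float32)
        map_y = (grid_y + rDy * trans_ratio).astype(np.float32)
        return cv2.remap(ori_img, map_x, map_y, cv2.INTER_LINEAR,
                         borderMode=cv2.BORDER_REPLICATE)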
| @@ -184,6 +184,7 @@ TASK_OUTPUTS = { | |||||
| Tasks.image_to_image_translation: [OutputKeys.OUTPUT_IMG], | Tasks.image_to_image_translation: [OutputKeys.OUTPUT_IMG], | ||||
| Tasks.image_style_transfer: [OutputKeys.OUTPUT_IMG], | Tasks.image_style_transfer: [OutputKeys.OUTPUT_IMG], | ||||
| Tasks.image_portrait_stylization: [OutputKeys.OUTPUT_IMG], | Tasks.image_portrait_stylization: [OutputKeys.OUTPUT_IMG], | ||||
| Tasks.image_body_reshaping: [OutputKeys.OUTPUT_IMG], | |||||
| # live category recognition result for single video | # live category recognition result for single video | ||||
| # { | # { | ||||
| @@ -75,6 +75,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| 'damo/nlp_bart_text-error-correction_chinese'), | 'damo/nlp_bart_text-error-correction_chinese'), | ||||
| Tasks.image_captioning: (Pipelines.image_captioning, | Tasks.image_captioning: (Pipelines.image_captioning, | ||||
| 'damo/ofa_image-caption_coco_large_en'), | 'damo/ofa_image-caption_coco_large_en'), | ||||
| Tasks.image_body_reshaping: (Pipelines.image_body_reshaping, | |||||
| 'damo/cv_flow-based-body-reshaping_damo'), | |||||
| Tasks.image_portrait_stylization: | Tasks.image_portrait_stylization: | ||||
| (Pipelines.person_image_cartoon, | (Pipelines.person_image_cartoon, | ||||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | 'damo/cv_unet_person-image-cartoon_compound-models'), | ||||
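With this default registered, callers may omit the model id entirely. A minimal sketch, assuming the hub model is reachable:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # resolves to 'damo/cv_flow-based-body-reshaping_damo' via DEFAULT_MODEL_FOR_PIPELINE
    body_reshaping = pipeline(Tasks.image_body_reshaping)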
| @@ -0,0 +1,40 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import Any, Dict | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.image_body_reshaping, module_name=Pipelines.image_body_reshaping) | |||||
| class ImageBodyReshapingPipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| Use `model` to create an image body reshaping pipeline for prediction. | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model, **kwargs) | |||||
| logger.info('body reshaping model init done') | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| img = LoadImage.convert_to_ndarray(input) | |||||
| result = {'img': img} | |||||
| return result | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| output = self.model.inference(input['img']) | |||||
| result = {'outputs': output} | |||||
| return result | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| output_img = inputs['outputs'] | |||||
| return {OutputKeys.OUTPUT_IMG: output_img} | |||||
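End to end, the pipeline chains preprocess (image loading), forward (model inference), and postprocess (output dict). Usage mirrors the unit test below; the input path is illustrative:

    import cv2

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    body_reshaping = pipeline(
        Tasks.image_body_reshaping,
        model='damo/cv_flow-based-body-reshaping_damo')
    result = body_reshaping('data/test/images/image_body_reshaping.jpg')
    cv2.imwrite('reshaped.png', result[OutputKeys.OUTPUT_IMG])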
| @@ -60,7 +60,7 @@ class CVTasks(object): | |||||
| image_to_image_generation = 'image-to-image-generation' | image_to_image_generation = 'image-to-image-generation' | ||||
| image_style_transfer = 'image-style-transfer' | image_style_transfer = 'image-style-transfer' | ||||
| image_portrait_stylization = 'image-portrait-stylization' | image_portrait_stylization = 'image-portrait-stylization' | ||||
| image_body_reshaping = 'image-body-reshaping' | |||||
| image_embedding = 'image-embedding' | image_embedding = 'image-embedding' | ||||
| product_retrieval_embedding = 'product-retrieval-embedding' | product_retrieval_embedding = 'product-retrieval-embedding' | ||||
| @@ -13,6 +13,7 @@ ml_collections | |||||
| mmcls>=0.21.0 | mmcls>=0.21.0 | ||||
| mmdet>=2.25.0 | mmdet>=2.25.0 | ||||
| networkx>=2.5 | networkx>=2.5 | ||||
| numba | |||||
| onnxruntime>=1.10 | onnxruntime>=1.10 | ||||
| pai-easycv>=0.6.3.6 | pai-easycv>=0.6.3.6 | ||||
| pandas | pandas | ||||
| @@ -0,0 +1,58 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| import unittest | |||||
| import cv2 | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.pipelines.base import Pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class ImageBodyReshapingTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def setUp(self) -> None: | |||||
| self.task = Tasks.image_body_reshaping | |||||
| self.model_id = 'damo/cv_flow-based-body-reshaping_damo' | |||||
| self.test_image = 'data/test/images/image_body_reshaping.jpg' | |||||
| def pipeline_inference(self, pipeline: Pipeline, input_location: str): | |||||
| result = pipeline(input_location) | |||||
| if result is not None: | |||||
| cv2.imwrite('result_body_reshaping.png', | |||||
| result[OutputKeys.OUTPUT_IMG]) | |||||
| print( | |||||
| f'Output written to {osp.abspath("result_body_reshaping.png")}' | |||||
| ) | |||||
| else: | |||||
| raise Exception('Testing failed: invalid output') | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_by_direct_model_download(self): | |||||
| model_dir = snapshot_download(self.model_id) | |||||
| image_body_reshaping = pipeline( | |||||
| Tasks.image_body_reshaping, model=model_dir) | |||||
| self.pipeline_inference(image_body_reshaping, self.test_image) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| image_body_reshaping = pipeline( | |||||
| Tasks.image_body_reshaping, model=self.model_id) | |||||
| self.pipeline_inference(image_body_reshaping, self.test_image) | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_modelhub_default_model(self): | |||||
| image_body_reshaping = pipeline(Tasks.image_body_reshaping) | |||||
| self.pipeline_inference(image_body_reshaping, self.test_image) | |||||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||||
| def test_demo_compatibility(self): | |||||
| self.compatibility_check() | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||