|
|
@@ -1,10 +1,19 @@ |
|
|
|
import os |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
|
|
|
|
import datetime |
|
|
|
import os.path as osp |
|
|
|
import tempfile |
|
|
|
from typing import Any, Dict, List, Union |
|
|
|
|
|
|
|
import cv2 |
|
|
|
import matplotlib |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import mpl_toolkits.mplot3d.axes3d as p3 |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
from matplotlib import animation |
|
|
|
from matplotlib.animation import writers |
|
|
|
from matplotlib.ticker import MultipleLocator |
|
|
|
|
|
|
|
from modelscope.metainfo import Pipelines |
|
|
|
from modelscope.models.cv.body_3d_keypoints.body_3d_pose import ( |
|
|
@@ -16,6 +25,8 @@ from modelscope.pipelines.builder import PIPELINES |
|
|
|
from modelscope.utils.constant import Tasks |
|
|
|
from modelscope.utils.logger import get_logger |
|
|
|
|
|
|
|
# Select matplotlib's 'Agg' backend for the whole module; the pose animation
# below is only ever written to a video file, never shown in a window.
matplotlib.use('Agg') |
|
|
|
|
|
|
|
# Module-level logger, used to report rendering progress.
logger = get_logger() |
|
|
|
|
|
|
|
|
|
|
@@ -121,7 +132,13 @@ class Body3DKeypointsPipeline(Pipeline): |
|
|
|
device='gpu' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
|
|
def preprocess(self, input: Input) -> Dict[str, Any]: |
|
|
|
video_frames = self.read_video_frames(input) |
|
|
|
video_url = input.get('input_video') |
|
|
|
self.output_video_path = input.get('output_video_path') |
|
|
|
if self.output_video_path is None: |
|
|
|
self.output_video_path = tempfile.NamedTemporaryFile( |
|
|
|
suffix='.mp4').name |
|
|
|
|
|
|
|
video_frames = self.read_video_frames(video_url) |
|
|
|
if 0 == len(video_frames): |
|
|
|
res = {'success': False, 'msg': 'get video frame failed.'} |
|
|
|
return res |
|
|
@@ -168,13 +185,21 @@ class Body3DKeypointsPipeline(Pipeline): |
|
|
|
return res |
|
|
|
|
|
|
|
def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
    """Convert the model output into the pipeline result dictionary.

    Args:
        input (Dict[str, Any]): Output of the previous stage. ``input['success']``
            tells whether inference succeeded; on success
            ``input[KeypointsTypes.POSES_CAMERA]`` holds the predicted poses
            as a tensor.
        **kwargs: Unused.

    Returns:
        Dict[str, Any]: ``OutputKeys.POSES`` and ``OutputKeys.TIMESTAMPS``
        are always present (empty lists on failure). When the model config
        has a ``render`` section, ``OutputKeys.OUTPUT_VIDEO`` additionally
        holds the path of the rendered video.
    """
    res = {OutputKeys.POSES: [], OutputKeys.TIMESTAMPS: []}

    # Failure: hand back the empty result untouched.
    if not input['success']:
        return res

    poses = input[KeypointsTypes.POSES_CAMERA]
    # Convert to numpy once and keep the first element only.
    pred_3d_pose = poses.data.cpu().numpy()[
        0]  # [frame_num, joint_num, joint_dim]

    # Rendering is optional and enabled by a 'render' section in the config.
    if 'render' in self.keypoint_model_3d.cfg.keys():
        self.render_prediction(pred_3d_pose)
        res[OutputKeys.OUTPUT_VIDEO] = self.output_video_path

    res[OutputKeys.POSES] = pred_3d_pose
    res[OutputKeys.TIMESTAMPS] = self.timestamps
    return res
|
|
|
|
|
|
|
def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]): |
|
|
@@ -189,7 +214,15 @@ class Body3DKeypointsPipeline(Pipeline): |
|
|
|
Returns: |
|
|
|
[nd.array]: List of video frames. |
|
|
|
""" |
|
|
|
|
|
|
|
def timestamp_format(seconds): |
|
|
|
m, s = divmod(seconds, 60) |
|
|
|
h, m = divmod(m, 60) |
|
|
|
time = '%02d:%02d:%06.3f' % (h, m, s) |
|
|
|
return time |
|
|
|
|
|
|
|
frames = [] |
|
|
|
self.timestamps = [] # for video render |
|
|
|
if isinstance(video_url, str): |
|
|
|
cap = cv2.VideoCapture(video_url) |
|
|
|
if not cap.isOpened(): |
|
|
@@ -199,15 +232,131 @@ class Body3DKeypointsPipeline(Pipeline): |
|
|
|
else: |
|
|
|
cap = video_url |
|
|
|
|
|
|
|
self.fps = cap.get(cv2.CAP_PROP_FPS) |
|
|
|
if self.fps is None or self.fps <= 0: |
|
|
|
raise Exception('modelscope error: %s cannot get video fps info.' % |
|
|
|
(video_url)) |
|
|
|
|
|
|
|
max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME |
|
|
|
frame_idx = 0 |
|
|
|
while True: |
|
|
|
ret, frame = cap.read() |
|
|
|
if not ret: |
|
|
|
break |
|
|
|
self.timestamps.append( |
|
|
|
timestamp_format(seconds=frame_idx / self.fps)) |
|
|
|
frame_idx += 1 |
|
|
|
frames.append(frame) |
|
|
|
if frame_idx >= max_frame_num: |
|
|
|
break |
|
|
|
cap.release() |
|
|
|
return frames |
|
|
|
|
|
|
|
def render_prediction(self, pose3d_cam_rr):
    """Render the predicted 3d poses as a skeleton animation and save it.

    The animation is encoded with matplotlib's 'ffmpeg' writer at
    ``self.fps`` frames per second and written to ``self.output_video_path``.
    The camera direction is read from ``self.keypoint_model_3d.cfg.render``
    (``azim`` and ``elev``).

    Args:
        pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints

    Returns:
        None
    """
    frame_num = pose3d_cam_rr.shape[0]

    left_points = {11, 12, 13, 4, 5, 6}  # joints of left body
    edges = [[0, 1], [0, 4], [0, 7], [1, 2], [4, 5], [5, 6], [2, 3],
             [7, 8], [8, 9], [8, 11], [8, 14], [14, 15], [15, 16],
             [11, 12], [12, 13], [9, 10]]  # connection between joints

    fig = plt.figure()
    try:
        # Create the 3d axes through the figure so they are always attached
        # to it; Axes3D(fig) no longer adds itself on recent matplotlib.
        # The rect matches the full-figure default Axes3D(fig) used.
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], projection='3d')

        # A locator instance is bound to a single axis, so give every axis
        # its own instead of sharing one.
        ax.xaxis.set_major_locator(MultipleLocator(0.5))
        ax.yaxis.set_major_locator(MultipleLocator(0.5))
        ax.zaxis.set_major_locator(MultipleLocator(0.5))
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)

        # view direction
        azim = self.keypoint_model_3d.cfg.render.azim
        elev = self.keypoint_model_3d.cfg.render.elev
        ax.view_init(elev, azim)

        # init plot with the first frame; the artists created here are
        # updated in place for every following frame
        x = pose3d_cam_rr[0, :, 0]
        y = pose3d_cam_rr[0, :, 1]
        z = pose3d_cam_rr[0, :, 2]
        points, = ax.plot(x, y, z, 'r.')

        def render_bones(xs, ys, zs):
            """Draw one line per skeleton edge.

            Args:
                xs (nd.array): [joint_num], x coordinate of every joint
                ys (nd.array): [joint_num], y coordinate of every joint
                zs (nd.array): [joint_num], z coordinate of every joint

            Returns:
                dict: edge index -> line artist of that edge
            """
            bones = {}
            for idx, (index1, index2) in enumerate(edges):
                # edges starting at a left-body joint are red, others blue
                edge_color = 'red' if index1 in left_points else 'blue'
                connect = ax.plot([xs[index1], xs[index2]],
                                  [ys[index1], ys[index2]],
                                  [zs[index1], zs[index2]],
                                  linewidth=2,
                                  color=edge_color)  # plot edge
                bones[idx] = connect[0]
            return bones

        bones = render_bones(x, y, z)

        def update(frame_idx, points, bones):
            """Move the joint and bone artists to the pose of one frame.

            Args:
                frame_idx (int): frame index
                points (mpl_toolkits.mplot3d.art3d.Line3D): skeleton points ploter
                bones (dict[int, mpl_toolkits.mplot3d.art3d.Line3D]): connection ploter

            Returns:
                tuple: points and bones ploter
            """
            xs = pose3d_cam_rr[frame_idx, :, 0]
            ys = pose3d_cam_rr[frame_idx, :, 1]
            zs = pose3d_cam_rr[frame_idx, :, 2]

            # update bones
            for idx, (index1, index2) in enumerate(edges):
                bone = bones[idx]
                bone.set_xdata((xs[index1], xs[index2]))
                bone.set_ydata((ys[index1], ys[index2]))
                bone.set_3d_properties((zs[index1], zs[index2]), 'z')

            # update joints
            points.set_data(xs, ys)
            points.set_3d_properties(zs, 'z')

            # report progress once every 100 frames
            if frame_idx % 100 == 0:
                logger.info(f'rendering {frame_idx}/{frame_num}')
            return points, bones

        ani = animation.FuncAnimation(
            fig=fig,
            func=update,
            frames=frame_num,
            # interval is the delay between frames in milliseconds
            interval=1000.0 / self.fps,
            fargs=(points, bones))

        # save mp4
        Writer = writers['ffmpeg']
        writer = Writer(fps=self.fps, metadata={}, bitrate=4096)
        ani.save(self.output_video_path, writer=writer)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly
        plt.close(fig)