
[to #42322933]style(license): add license + render result poses with video

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10263904
Branch: master
hanyuan.chy, yingda.chen, 3 years ago
Parent commit: 7f468acca3
5 changed files with 183 additions and 18 deletions
  1. +2   -0  modelscope/models/cv/body_3d_keypoints/body_3d_pose.py
  2. +1   -1  modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py
  3. +15  -6  modelscope/outputs.py
  4. +153 -4  modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
  5. +12  -7  tests/pipelines/test_body_3d_keypoints.py

+2 -0  modelscope/models/cv/body_3d_keypoints/body_3d_pose.py

@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 import logging
 import os.path as osp
 from typing import Any, Dict, List, Union


+1 -1  modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py

@@ -1,4 +1,4 @@
-# The implementation is based on OSTrack, available at https://github.com/facebookresearch/VideoPose3D
+# The implementation is based on VideoPose3D, available at https://github.com/facebookresearch/VideoPose3D
 import torch
 import torch.nn as nn




+15 -6  modelscope/outputs.py

@@ -21,6 +21,7 @@ class OutputKeys(object):
     POLYGONS = 'polygons'
     OUTPUT = 'output'
     OUTPUT_IMG = 'output_img'
+    OUTPUT_VIDEO = 'output_video'
     OUTPUT_PCM = 'output_pcm'
     IMG_EMBEDDING = 'img_embedding'
     SPO_LIST = 'spo_list'

@@ -218,13 +219,21 @@ TASK_OUTPUTS = {

     # 3D human body keypoints detection result for single sample
     # {
-    #   "poses": [
-    #       [[x, y, z]*17],
-    #       [[x, y, z]*17],
-    #       [[x, y, z]*17]
-    #   ]
+    #   "poses": [               # 3D pose coordinates in the camera coordinate frame
+    #       [[x, y, z]*17],      # 17 joints per frame
+    #       [[x, y, z]*17],
+    #       ...
+    #   ],
+    #   "timestamps": [          # timestamps of all frames
+    #       "00:00:0.230",
+    #       "00:00:0.560",
+    #       "00:00:0.690",
+    #   ],
+    #   "output_video": "path_to_rendered_video"  # optional, only available
+    #                                             # when the "render" option is enabled
     # }
-    Tasks.body_3d_keypoints: [OutputKeys.POSES],
+    Tasks.body_3d_keypoints:
+    [OutputKeys.POSES, OutputKeys.TIMESTAMPS, OutputKeys.OUTPUT_VIDEO],

     # 2D hand keypoints result for single sample
     # {
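With this change the 3D body keypoints task returns three keys instead of one. A minimal consumption sketch, assuming a pipeline built from some published body-3d-keypoints model (the model id and file names below are placeholders, not taken from this commit):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

body_3d = pipeline(Tasks.body_3d_keypoints, model='<model_id>')  # placeholder id
result = body_3d({
    'input_video': 'input.mp4',          # path/URL string or a cv2.VideoCapture
    'output_video_path': './result.mp4'  # optional
})
poses = result[OutputKeys.POSES]             # [frame_num, 17, 3] camera-space joints
stamps = result[OutputKeys.TIMESTAMPS]       # one timestamp string per frame
video = result.get(OutputKeys.OUTPUT_VIDEO)  # set only when the model config enables 'render'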


+153 -4  modelscope/pipelines/cv/body_3d_keypoints_pipeline.py

@@ -1,10 +1,19 @@
-import os
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import datetime
 import os.path as osp
+import tempfile
 from typing import Any, Dict, List, Union

 import cv2
+import matplotlib
+import matplotlib.pyplot as plt
+import mpl_toolkits.mplot3d.axes3d as p3
 import numpy as np
 import torch
+from matplotlib import animation
+from matplotlib.animation import writers
+from matplotlib.ticker import MultipleLocator

 from modelscope.metainfo import Pipelines
 from modelscope.models.cv.body_3d_keypoints.body_3d_pose import (
@@ -16,6 +25,8 @@ from modelscope.pipelines.builder import PIPELINES
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger

+matplotlib.use('Agg')
+
 logger = get_logger()




@@ -121,7 +132,13 @@ class Body3DKeypointsPipeline(Pipeline):
             device='gpu' if torch.cuda.is_available() else 'cpu')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        video_frames = self.read_video_frames(input)
+        video_url = input.get('input_video')
+        self.output_video_path = input.get('output_video_path')
+        if self.output_video_path is None:
+            self.output_video_path = tempfile.NamedTemporaryFile(
+                suffix='.mp4').name
+
+        video_frames = self.read_video_frames(video_url)
         if 0 == len(video_frames):
             res = {'success': False, 'msg': 'get video frame failed.'}
             return res
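A side note on the fallback above: NamedTemporaryFile(...).name only borrows a generated path. In CPython the file object is reclaimed immediately, the backing file is deleted (delete=True is the default), and just the path string survives for the renderer to write to later. A standalone sketch of that behavior:

import tempfile

# The temporary file is created, then removed as soon as the object is
# garbage-collected; only the generated name remains for later use.
path = tempfile.NamedTemporaryFile(suffix='.mp4').name
print(path)  # e.g. /tmp/tmpab12cd34.mp4 (actual name varies by platform)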
@@ -168,13 +185,21 @@ class Body3DKeypointsPipeline(Pipeline):
         return res

     def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        res = {OutputKeys.POSES: []}
+        res = {OutputKeys.POSES: [], OutputKeys.TIMESTAMPS: []}

         if not input['success']:
             pass
         else:
             poses = input[KeypointsTypes.POSES_CAMERA]
-            res = {OutputKeys.POSES: poses.data.cpu().numpy()}
+            pred_3d_pose = poses.data.cpu().numpy()[
+                0]  # [frame_num, joint_num, joint_dim]
+
+            if 'render' in self.keypoint_model_3d.cfg.keys():
+                self.render_prediction(pred_3d_pose)
+                res[OutputKeys.OUTPUT_VIDEO] = self.output_video_path
+
+            res[OutputKeys.POSES] = pred_3d_pose
+            res[OutputKeys.TIMESTAMPS] = self.timestamps
         return res
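The [0] index in postprocess drops what looks like a leading batch dimension before the poses are returned. A dummy-data shape sketch (the dimensions are assumed for illustration, not verified against the model):

import numpy as np

poses = np.zeros((1, 50, 17, 3))  # assumed layout: [batch, frame_num, joint_num, joint_dim]
pred_3d_pose = poses[0]           # what postprocess stores under OutputKeys.POSES
print(pred_3d_pose.shape)         # (50, 17, 3)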


@@ -189,7 +214,15 @@ def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]):
         Returns:
             [nd.array]: List of video frames.
         """
+
+        def timestamp_format(seconds):
+            m, s = divmod(seconds, 60)
+            h, m = divmod(m, 60)
+            time = '%02d:%02d:%06.3f' % (h, m, s)
+            return time
+
         frames = []
+        self.timestamps = []  # for video render
         if isinstance(video_url, str):
             cap = cv2.VideoCapture(video_url)
             if not cap.isOpened():
@@ -199,15 +232,131 @@ def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]):
         else:
             cap = video_url

+        self.fps = cap.get(cv2.CAP_PROP_FPS)
+        if self.fps is None or self.fps <= 0:
+            raise Exception('modelscope error: %s cannot get video fps info.' %
+                            (video_url))
+
         max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
         frame_idx = 0
         while True:
             ret, frame = cap.read()
             if not ret:
                 break
+            self.timestamps.append(
+                timestamp_format(seconds=frame_idx / self.fps))
             frame_idx += 1
             frames.append(frame)
             if frame_idx >= max_frame_num:
                 break
         cap.release()
         return frames
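The nested timestamp_format helper is a pure function, so its behavior is easy to verify standalone. Note that the %06.3f field pads the seconds to two integer digits, so real output looks like '00:00:01.200' (the '00:00:0.230' examples in the outputs.py comment are slightly abbreviated):

def timestamp_format(seconds):
    # reproduced from the diff above
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    return '%02d:%02d:%06.3f' % (h, m, s)

print(timestamp_format(30 / 25))  # '00:00:01.200' -> frame 30 of a 25 fps video
print(timestamp_format(3661.5))   # '01:01:01.500'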

+    def render_prediction(self, pose3d_cam_rr):
+        """Render the predicted 3D poses to the output video.
+
+        Args:
+            pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3D pose joints.
+        """
+        frame_num = pose3d_cam_rr.shape[0]
+
+        left_points = [11, 12, 13, 4, 5, 6]  # joints on the left side of the body
+        edges = [[0, 1], [0, 4], [0, 7], [1, 2], [4, 5], [5, 6], [2, 3],
+                 [7, 8], [8, 9], [8, 11], [8, 14], [14, 15], [15, 16],
+                 [11, 12], [12, 13], [9, 10]]  # connections between joints
+
+        fig = plt.figure()
+        ax = p3.Axes3D(fig)
+        x_major_locator = MultipleLocator(0.5)
+
+        ax.xaxis.set_major_locator(x_major_locator)
+        ax.yaxis.set_major_locator(x_major_locator)
+        ax.zaxis.set_major_locator(x_major_locator)
+        ax.set_xlabel('X')
+        ax.set_ylabel('Y')
+        ax.set_zlabel('Z')
+        ax.set_xlim(-1, 1)
+        ax.set_ylim(-1, 1)
+        ax.set_zlim(-1, 1)
+        # view direction
+        azim = self.keypoint_model_3d.cfg.render.azim
+        elev = self.keypoint_model_3d.cfg.render.elev
+        ax.view_init(elev, azim)
+
+        # initial plot of the first frame's joints
+        x = pose3d_cam_rr[0, :, 0]
+        y = pose3d_cam_rr[0, :, 1]
+        z = pose3d_cam_rr[0, :, 2]
+        points, = ax.plot(x, y, z, 'r.')
+
+        def renderBones(xs, ys, zs):
+            """Render the bones of the skeleton.
+
+            Args:
+                xs (nd.array): [joint_num, joint_channel]
+                ys (nd.array): [joint_num, joint_channel]
+                zs (nd.array): [joint_num, joint_channel]
+
+            Returns:
+                dict[int, mpl_toolkits.mplot3d.art3d.Line3D]: one line artist per edge
+            """
+            bones = {}
+            for idx, edge in enumerate(edges):
+                index1, index2 = edge[0], edge[1]
+                if index1 in left_points:
+                    edge_color = 'red'
+                else:
+                    edge_color = 'blue'
+                connect = ax.plot([xs[index1], xs[index2]],
+                                  [ys[index1], ys[index2]],
+                                  [zs[index1], zs[index2]],
+                                  linewidth=2,
+                                  color=edge_color)  # plot edge
+                bones[idx] = connect[0]
+            return bones
+
+        bones = renderBones(x, y, z)
+
+        def update(frame_idx, points, bones):
+            """Update one animation frame.
+
+            Args:
+                frame_idx (int): frame index
+                points (mpl_toolkits.mplot3d.art3d.Line3D): skeleton joint plotter
+                bones (dict[int, mpl_toolkits.mplot3d.art3d.Line3D]): bone plotters
+
+            Returns:
+                tuple: joint and bone plotters
+            """
+            xs = pose3d_cam_rr[frame_idx, :, 0]
+            ys = pose3d_cam_rr[frame_idx, :, 1]
+            zs = pose3d_cam_rr[frame_idx, :, 2]
+
+            # update bones
+            for idx, edge in enumerate(edges):
+                index1, index2 = edge[0], edge[1]
+                x1x2 = (xs[index1], xs[index2])
+                y1y2 = (ys[index1], ys[index2])
+                z1z2 = (zs[index1], zs[index2])
+                bones[idx].set_xdata(x1x2)
+                bones[idx].set_ydata(y1y2)
+                bones[idx].set_3d_properties(z1z2, 'z')
+
+            # update joints
+            points.set_data(xs, ys)
+            points.set_3d_properties(zs, 'z')
+            if 0 == frame_idx % 100:  # log progress every 100 frames
+                logger.info(f'rendering {frame_idx}/{frame_num}')
+            return points, bones
+
+        ani = animation.FuncAnimation(
+            fig=fig,
+            func=update,
+            frames=frame_num,
+            interval=self.fps,
+            fargs=(points, bones))
+
+        # save mp4
+        Writer = writers['ffmpeg']
+        writer = Writer(fps=self.fps, metadata={}, bitrate=4096)
+        ani.save(self.output_video_path, writer=writer)
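The save path relies on matplotlib's ffmpeg writer together with the Agg backend selected at import time. FuncAnimation's interval argument is the delay in milliseconds between frames for interactive playback, while the writer's fps governs the saved video, so interval=self.fps above is harmless for the mp4 output. A self-contained sketch of the same mechanism (2D for brevity; assumes an ffmpeg binary on PATH):

import matplotlib
matplotlib.use('Agg')  # headless backend, as the pipeline sets
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

fig, ax = plt.subplots()
line, = ax.plot([], [], 'r.')
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1.1, 1.1)
xs = np.linspace(0, 2 * np.pi, 100)

def update(i):
    # FuncAnimation calls this once per frame index; return the updated artists
    line.set_data(xs[:i + 1], np.sin(xs[:i + 1]))
    return line,

fps = 25
ani = animation.FuncAnimation(fig, update, frames=100, interval=1000 / fps)
writer = animation.writers['ffmpeg'](fps=fps, metadata={}, bitrate=4096)
ani.save('demo.mp4', writer=writer)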

+12 -7  tests/pipelines/test_body_3d_keypoints.py

@@ -28,7 +28,12 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
     def test_run_modelhub_with_video_file(self):
         body_3d_keypoints = pipeline(
             Tasks.body_3d_keypoints, model=self.model_id)
-        self.pipeline_inference(body_3d_keypoints, self.test_video)
+        pipeline_input = {
+            'input_video': self.test_video,
+            'output_video_path': './result.mp4'
+        }
+        self.pipeline_inference(
+            body_3d_keypoints, pipeline_input=pipeline_input)

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_modelhub_with_video_stream(self):
@@ -37,12 +42,12 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
         if not cap.isOpened():
             raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
                             % (self.test_video))
-        self.pipeline_inference(body_3d_keypoints, cap)
-
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run_modelhub_default_model(self):
-        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
-        self.pipeline_inference(body_3d_keypoints, self.test_video)
+        pipeline_input = {
+            'input_video': cap,
+            'output_video_path': './result.mp4'
+        }
+        self.pipeline_inference(
+            body_3d_keypoints, pipeline_input=pipeline_input)

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_demo_compatibility(self):

