Browse Source

[to #42322933]bugfix : add PIL image type support and model.to(devices) for body_2d_keypoints ipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9684583
master
shouzhou.bx yingda.chen 3 years ago
parent
commit
ac1ba2a0e0
3 changed files with 33 additions and 25 deletions
  1. +9
    -9
      modelscope/outputs.py
  2. +13
    -12
      modelscope/pipelines/cv/body_2d_keypoints_pipeline.py
  3. +11
    -4
      tests/pipelines/test_body_2d_keypoints.py

+ 9
- 9
modelscope/outputs.py View File

@@ -161,19 +161,19 @@ TASK_OUTPUTS = {
# human body keypoints detection result for single sample
# {
# "poses": [
# [x, y],
# [x, y],
# [x, y]
# [[x, y]*15],
# [[x, y]*15],
# [[x, y]*15]
# ]
# "scores": [
# [score],
# [score],
# [score],
# [[score]*15],
# [[score]*15],
# [[score]*15]
# ]
# "boxes": [
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [x1, y1, x2, y2],
# [[x1, y1], [x2, y2]],
# [[x1, y1], [x2, y2]],
# [[x1, y1], [x2, y2]],
# ]
# }
Tasks.body_2d_keypoints:


+ 13
- 12
modelscope/pipelines/cv/body_2d_keypoints_pipeline.py View File

@@ -16,7 +16,7 @@ from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import load_image
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

@@ -29,8 +29,9 @@ class Body2DKeypointsPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
super().__init__(model=model, **kwargs)
self.keypoint_model = KeypointsDetection(model)
self.keypoint_model.eval()
device = torch.device(
f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
self.keypoint_model = KeypointsDetection(model, device)

self.human_detect_model_id = 'damo/cv_resnet18_human-detection'
self.human_detector = pipeline(
@@ -39,12 +40,8 @@ class Body2DKeypointsPipeline(Pipeline):
def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]:
output = self.human_detector(input)

if isinstance(input, str):
image = cv2.imread(input, -1)[:, :, 0:3]
elif isinstance(input, np.ndarray):
if len(input.shape) == 2:
image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
image = image[:, :, 0:3]
image = LoadImage.convert_to_ndarray(input)
image = image[:, :, [2, 1, 0]] # rgb2bgr

return {'image': image, 'output': output}

@@ -88,14 +85,18 @@ class Body2DKeypointsPipeline(Pipeline):

class KeypointsDetection():

def __init__(self, model: str, **kwargs):
def __init__(self, model: str, device: str, **kwargs):
self.model = model
self.device = device
cfg = cfg_128x128_15
self.key_points_model = PoseHighResolutionNetV2(cfg)
pretrained_state_dict = torch.load(
osp.join(self.model, ModelFile.TORCH_MODEL_FILE))
osp.join(self.model, ModelFile.TORCH_MODEL_FILE),
map_location=device)
self.key_points_model.load_state_dict(
pretrained_state_dict, strict=False)
self.key_points_model = self.key_points_model.to(device)
self.key_points_model.eval()

self.input_size = cfg['MODEL']['IMAGE_SIZE']
self.lst_parent_ids = cfg['DATASET']['PARENT_IDS']
@@ -111,7 +112,7 @@ class KeypointsDetection():

def forward(self, input: Tensor) -> Tensor:
with torch.no_grad():
return self.key_points_model.forward(input)
return self.key_points_model.forward(input.to(self.device))

def get_pts(self, heatmaps):
[pts_num, height, width] = heatmaps.shape


+ 11
- 4
tests/pipelines/test_body_2d_keypoints.py View File

@@ -3,6 +3,7 @@ import unittest

import cv2
import numpy as np
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
@@ -68,8 +69,8 @@ class Body2DKeypointsTest(unittest.TestCase):
self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image'
self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg'

def pipeline_inference(self, pipeline: Pipeline):
output = pipeline(self.test_image)
def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
output = pipeline(pipeline_input)
poses = np.array(output[OutputKeys.POSES])
scores = np.array(output[OutputKeys.SCORES])
boxes = np.array(output[OutputKeys.BOXES])
@@ -80,11 +81,17 @@ class Body2DKeypointsTest(unittest.TestCase):
draw_joints(image, np.array(poses[i]), np.array(scores[i]))
cv2.imwrite('pose_keypoint.jpg', image)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_modelhub_with_image_file(self):
body_2d_keypoints = pipeline(
Tasks.body_2d_keypoints, model=self.model_id)
self.pipeline_inference(body_2d_keypoints, self.test_image)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
def test_run_modelhub_with_image_input(self):
body_2d_keypoints = pipeline(
Tasks.body_2d_keypoints, model=self.model_id)
self.pipeline_inference(body_2d_keypoints)
self.pipeline_inference(body_2d_keypoints, Image.open(self.test_image))


if __name__ == '__main__':


Loading…
Cancel
Save