|
@@ -16,7 +16,7 @@ from modelscope.outputs import OutputKeys |
|
|
from modelscope.pipelines import pipeline |
|
|
from modelscope.pipelines import pipeline |
|
|
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor |
|
|
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor |
|
|
from modelscope.pipelines.builder import PIPELINES |
|
|
from modelscope.pipelines.builder import PIPELINES |
|
|
from modelscope.preprocessors import load_image |
|
|
|
|
|
|
|
|
from modelscope.preprocessors import LoadImage |
|
|
from modelscope.utils.constant import ModelFile, Tasks |
|
|
from modelscope.utils.constant import ModelFile, Tasks |
|
|
from modelscope.utils.logger import get_logger |
|
|
from modelscope.utils.logger import get_logger |
|
|
|
|
|
|
|
@@ -29,8 +29,9 @@ class Body2DKeypointsPipeline(Pipeline): |
|
|
|
|
|
|
|
|
def __init__(self, model: str, **kwargs): |
|
|
def __init__(self, model: str, **kwargs): |
|
|
super().__init__(model=model, **kwargs) |
|
|
super().__init__(model=model, **kwargs) |
|
|
self.keypoint_model = KeypointsDetection(model) |
|
|
|
|
|
self.keypoint_model.eval() |
|
|
|
|
|
|
|
|
device = torch.device( |
|
|
|
|
|
f'cuda:{0}' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
self.keypoint_model = KeypointsDetection(model, device) |
|
|
|
|
|
|
|
|
self.human_detect_model_id = 'damo/cv_resnet18_human-detection' |
|
|
self.human_detect_model_id = 'damo/cv_resnet18_human-detection' |
|
|
self.human_detector = pipeline( |
|
|
self.human_detector = pipeline( |
|
@@ -39,12 +40,8 @@ class Body2DKeypointsPipeline(Pipeline): |
|
|
def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: |
|
|
def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: |
|
|
output = self.human_detector(input) |
|
|
output = self.human_detector(input) |
|
|
|
|
|
|
|
|
if isinstance(input, str): |
|
|
|
|
|
image = cv2.imread(input, -1)[:, :, 0:3] |
|
|
|
|
|
elif isinstance(input, np.ndarray): |
|
|
|
|
|
if len(input.shape) == 2: |
|
|
|
|
|
image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) |
|
|
|
|
|
image = image[:, :, 0:3] |
|
|
|
|
|
|
|
|
image = LoadImage.convert_to_ndarray(input) |
|
|
|
|
|
image = image[:, :, [2, 1, 0]] # rgb2bgr |
|
|
|
|
|
|
|
|
return {'image': image, 'output': output} |
|
|
return {'image': image, 'output': output} |
|
|
|
|
|
|
|
@@ -88,14 +85,18 @@ class Body2DKeypointsPipeline(Pipeline): |
|
|
|
|
|
|
|
|
class KeypointsDetection(): |
|
|
class KeypointsDetection(): |
|
|
|
|
|
|
|
|
def __init__(self, model: str, **kwargs): |
|
|
|
|
|
|
|
|
def __init__(self, model: str, device: str, **kwargs): |
|
|
self.model = model |
|
|
self.model = model |
|
|
|
|
|
self.device = device |
|
|
cfg = cfg_128x128_15 |
|
|
cfg = cfg_128x128_15 |
|
|
self.key_points_model = PoseHighResolutionNetV2(cfg) |
|
|
self.key_points_model = PoseHighResolutionNetV2(cfg) |
|
|
pretrained_state_dict = torch.load( |
|
|
pretrained_state_dict = torch.load( |
|
|
osp.join(self.model, ModelFile.TORCH_MODEL_FILE)) |
|
|
|
|
|
|
|
|
osp.join(self.model, ModelFile.TORCH_MODEL_FILE), |
|
|
|
|
|
map_location=device) |
|
|
self.key_points_model.load_state_dict( |
|
|
self.key_points_model.load_state_dict( |
|
|
pretrained_state_dict, strict=False) |
|
|
pretrained_state_dict, strict=False) |
|
|
|
|
|
self.key_points_model = self.key_points_model.to(device) |
|
|
|
|
|
self.key_points_model.eval() |
|
|
|
|
|
|
|
|
self.input_size = cfg['MODEL']['IMAGE_SIZE'] |
|
|
self.input_size = cfg['MODEL']['IMAGE_SIZE'] |
|
|
self.lst_parent_ids = cfg['DATASET']['PARENT_IDS'] |
|
|
self.lst_parent_ids = cfg['DATASET']['PARENT_IDS'] |
|
@@ -111,7 +112,7 @@ class KeypointsDetection(): |
|
|
|
|
|
|
|
|
def forward(self, input: Tensor) -> Tensor: |
|
|
def forward(self, input: Tensor) -> Tensor: |
|
|
with torch.no_grad(): |
|
|
with torch.no_grad(): |
|
|
return self.key_points_model.forward(input) |
|
|
|
|
|
|
|
|
return self.key_points_model.forward(input.to(self.device)) |
|
|
|
|
|
|
|
|
def get_pts(self, heatmaps): |
|
|
def get_pts(self, heatmaps): |
|
|
[pts_num, height, width] = heatmaps.shape |
|
|
[pts_num, height, width] = heatmaps.shape |
|
|