Add a new action-detection task. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9898947 (master)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b7c3bc7c82ea5fee9d83130041df01046d89143ff77058b04577455ff6fdc92
size 3191059
@@ -133,6 +133,7 @@ class Pipelines(object):
    skin_retouching = 'unet-skin-retouching'
    tinynas_classification = 'tinynas-classification'
    crowd_counting = 'hrnet-crowd-counting'
    action_detection = 'ResNetC3D-action-detection'
    video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
    image_panoptic_segmentation = 'image-panoptic-segmentation'
    video_summarization = 'googlenet_pgl_video_summarization'
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .action_detection_onnx import ActionDetONNX
else:
    _import_structure = {'action_detection_onnx': ['ActionDetONNX']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,177 @@
import os
import os.path as osp
import shutil
import subprocess

import cv2
import numpy as np
import onnxruntime as rt

from modelscope.models import Model
from modelscope.utils.constant import Devices
from modelscope.utils.device import verify_device


class ActionDetONNX(Model):

    def __init__(self, model_dir, config, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        model_file = config['model_file']
        device_type, device_id = verify_device(self._device_name)
        options = rt.SessionOptions()
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        if device_type == Devices.gpu:
            sess = rt.InferenceSession(
                model_file,
                providers=['CUDAExecutionProvider'],
                sess_options=options,
                provider_options=[{
                    'device_id': device_id
                }])
        else:
            sess = rt.InferenceSession(
                model_file,
                providers=['CPUExecutionProvider'],
                sess_options=options)
        self.input_name = sess.get_inputs()[0].name
        self.sess = sess
        self.num_stride = len(config['fpn_strides'])
        self.score_thresh = np.asarray(
            config['pre_nms_thresh'], dtype='float32').reshape((1, -1))
        self.size_divisibility = config['size_divisibility']
        self.nms_threshold = config['nms_thresh']
        self.tmp_dir = config['tmp_dir']
        self.temporal_stride = config['step']
        self.input_data_type = config['input_type']
        self.action_names = config['action_names']
        self.video_length_limit = config['video_length_limit']
    def resize_box(self, det, height, width, scale_h, scale_w):
        # det is a (boxes, scores, label_indices) tuple from post_nms; boxes are
        # rescaled from the network input size back to the original frame size.
        bboxs = det[0]
        bboxs[:, [0, 2]] *= scale_w
        bboxs[:, [1, 3]] *= scale_h
        bboxs[:, [0, 2]] = bboxs[:, [0, 2]].clip(0, width - 1)
        bboxs[:, [1, 3]] = bboxs[:, [1, 3]].clip(0, height - 1)
        result = {
            'boxes': bboxs.round().astype('int32').tolist(),
            'scores': det[1].tolist(),
            'labels': [self.action_names[i] for i in det[2].tolist()]
        }
        return result

    def parse_frames(self, frame_names):
        # read frames as RGB and stack them into a (1, c, t, h, w) clip tensor
        imgs = [cv2.imread(name)[:, :, ::-1] for name in frame_names]
        imgs = np.stack(imgs).astype(self.input_data_type).transpose(
            (3, 0, 1, 2))  # c,t,h,w
        imgs = imgs[None]
        return imgs
    def forward_img(self, imgs, h, w):
        pred = self.sess.run(None, {
            self.input_name: imgs,
            'height': np.asarray(h),
            'width': np.asarray(w)
        })
        dets = self.post_nms(
            pred,
            score_threshold=self.score_thresh,
            nms_threshold=self.nms_threshold)
        return dets
    def forward_video(self, video_name, scale):
        min_size, max_size = self._get_sizes(scale)
        tmp_dir = osp.join(self.tmp_dir, osp.basename(video_name)[:-4])
        if osp.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        os.makedirs(tmp_dir)
        # extract frames at 2 fps with ffmpeg, up to video_length_limit seconds
        frame_rate = 2
        cmd = f'ffmpeg -y -loglevel quiet -ss 0 -t {self.video_length_limit}' + \
            f' -i {video_name} -r {frame_rate} -f image2 {tmp_dir}/%06d.jpg'
        cmd = cmd.split(' ')
        subprocess.call(cmd)
        frame_names = [
            osp.join(tmp_dir, name) for name in sorted(os.listdir(tmp_dir))
            if name.endswith('.jpg')
        ]
        # group frames into 2-second clips, sliding by temporal_stride seconds
        frame_names = [
            frame_names[i:i + frame_rate * 2]
            for i in range(0,
                           len(frame_names) - frame_rate * 2 + 1,
                           frame_rate * self.temporal_stride)
        ]
        # one timestamp (in seconds) per clip
        timestamp = [
            1 + i * self.temporal_stride for i in range(len(frame_names))
        ]
        batch_imgs = [self.parse_frames(names) for names in frame_names]
        N, _, T, H, W = batch_imgs[0].shape
        # network input size: scale the short side to min_size, cap at max_size,
        # then round to a multiple of size_divisibility
        scale_min = min_size / min(H, W)
        h, w = min(int(scale_min * H),
                   max_size), min(int(scale_min * W), max_size)
        h = round(h / self.size_divisibility) * self.size_divisibility
        w = round(w / self.size_divisibility) * self.size_divisibility
        scale_h, scale_w = H / h, W / w
        results = []
        for imgs in batch_imgs:
            det = self.forward_img(imgs, h, w)
            det = self.resize_box(det[0], H, W, scale_h, scale_w)
            results.append(det)
        results = [{
            'timestamp': t,
            'actions': res
        } for t, res in zip(timestamp, results)]
        shutil.rmtree(tmp_dir)
        return results
    def forward(self, video_name):
        return self.forward_video(video_name, scale=1)

    def post_nms(self, pred, score_threshold, nms_threshold=0.3):
        pred_bboxes, pred_scores = pred
        N = len(pred_bboxes)
        dets = []
        for i in range(N):
            bboxes, scores = pred_bboxes[i], pred_scores[i]
            candidate_inds = scores > score_threshold
            scores = scores[candidate_inds]
            candidate_nonzeros = candidate_inds.nonzero()
            bboxes = bboxes[candidate_nonzeros[0]]
            labels = candidate_nonzeros[1]
            keep = self._nms(bboxes, scores, labels, nms_threshold)
            bbox = bboxes[keep]
            score = scores[keep]
            label = labels[keep]
            dets.append((bbox, score, label))
        return dets
    def _nms(self, boxes, scores, idxs, nms_threshold):
        if len(boxes) == 0:
            return []
        # offset boxes by class index so NMS is effectively done per class, then
        # convert (x1, y1, x2, y2) to the (x, y, w, h) layout cv2.dnn.NMSBoxes expects
        max_coordinate = boxes.max()
        offsets = idxs * (max_coordinate + 1)
        boxes_for_nms = boxes + offsets[:, None].astype('float32')
        boxes_for_nms[:, 2] = boxes_for_nms[:, 2] - boxes_for_nms[:, 0]
        boxes_for_nms[:, 3] = boxes_for_nms[:, 3] - boxes_for_nms[:, 1]
        keep = cv2.dnn.NMSBoxes(
            boxes_for_nms.tolist(),
            scores.tolist(),
            score_threshold=0,
            nms_threshold=nms_threshold)
        if len(keep.shape) == 2:
            keep = np.squeeze(keep, 1)
        return keep
    def _get_sizes(self, scale):
        if scale == 1:
            min_size, max_size = 512, 896
        elif scale == 2:
            min_size, max_size = 768, 1280
        else:
            min_size, max_size = 1024, 1792
        return min_size, max_size
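Reviewer note: the constructor above reads all of its settings from the MODEL section of the model's configuration.json (model_file is filled in by the pipeline before the model is built). A minimal sketch of the keys it expects follows; the values are illustrative assumptions, not the shipped configuration.

# Hypothetical MODEL config consumed by ActionDetONNX.__init__; key names match
# the code above, values are placeholders only.
MODEL = {
    'model_file': 'model.onnx',          # overwritten by the pipeline with the full ONNX path
    'fpn_strides': [8, 16, 32],          # one entry per FPN level (only the count is used)
    'pre_nms_thresh': [0.45],            # per-class score threshold applied before NMS
    'size_divisibility': 32,             # input height/width rounded to a multiple of this
    'nms_thresh': 0.5,                   # IoU threshold for class-wise NMS
    'tmp_dir': '/tmp/action_detection',  # extracted frames are written here per video
    'step': 2,                           # temporal stride between clips, in seconds
    'input_type': 'float32',             # dtype of the clip tensor fed to onnxruntime
    'action_names': ['吸烟', '打电话'],    # label names indexed by the predicted class id
    'video_length_limit': 30,            # seconds passed to ffmpeg -t
}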
@@ -35,6 +35,7 @@ class OutputKeys(object):
    UUID = 'uuid'
    WORD = 'word'
    KWS_LIST = 'kws_list'
    TIMESTAMPS = 'timestamps'
    SPLIT_VIDEO_NUM = 'split_video_num'
    SPLIT_META_DICT = 'split_meta_dict'
@@ -541,6 +542,19 @@ TASK_OUTPUTS = {
    # }
    Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

    # action detection result for a single video
    # {
    #   'labels': ['吸烟', '打电话', '吸烟'],  # i.e. smoking, making a phone call, smoking
    #   'scores': [0.7527753114700317, 0.753358006477356, 0.6880350708961487],
    #   'boxes': [[547, 2, 1225, 719], [529, 8, 1255, 719], [584, 0, 1269, 719]],
    #   'timestamps': [1, 3, 5]
    # }
    Tasks.action_detection: [
        OutputKeys.TIMESTAMPS,
        OutputKeys.LABELS,
        OutputKeys.SCORES,
        OutputKeys.BOXES,
    ],
    # {
    #   'output': [
    #     [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
@@ -551,6 +565,7 @@ TASK_OUTPUTS = {
    #     {'label': '13421097', 'score': 2.75914817393641e-06}]]
    # }
    Tasks.faq_question_answering: [OutputKeys.OUTPUT],

    # image person reid result for single sample
    # {
    #   "img_embedding": np.array with shape [1, D],
@@ -71,6 +71,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
    Tasks.action_detection: (Pipelines.action_detection,
                             'damo/cv_ResNetC3D_action-detection_detection2d'),
    Tasks.live_category: (Pipelines.live_category,
                          'damo/cv_resnet50_live-category'),
    Tasks.video_category: (Pipelines.video_category,
@@ -5,6 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .action_detection_pipeline import ActionDetectionPipeline
    from .animal_recognition_pipeline import AnimalRecognitionPipeline
    from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
    from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
        'action_detection_pipeline': ['ActionDetectionPipeline'],
        'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
        'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
        'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
@@ -0,0 +1,63 @@
import math
import os.path as osp
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.action_detection import ActionDetONNX
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.action_detection, module_name=Pipelines.action_detection)
class ActionDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create an action detection pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.ONNX_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        self.cfg.MODEL.model_file = model_path
        self.model = ActionDetONNX(self.model, self.cfg.MODEL,
                                   self.device_name)
        logger.info('load model done')
    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            video_name = input
        else:
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
        result = {'video_name': video_name}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        preds = self.model.forward(input['video_name'])
        labels = sum([pred['actions']['labels'] for pred in preds], [])
        scores = sum([pred['actions']['scores'] for pred in preds], [])
        boxes = sum([pred['actions']['boxes'] for pred in preds], [])
        timestamps = sum([[pred['timestamp']] * len(pred['actions']['labels'])
                          for pred in preds], [])
        out = {
            OutputKeys.TIMESTAMPS: timestamps,
            OutputKeys.LABELS: labels,
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: boxes
        }
        return out

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
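For context: forward() above flattens the per-clip results into four index-aligned lists, so downstream code can simply zip them back together. A small usage sketch follows; the video path is a placeholder, and the model id is the default registered for the task in the table above.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# default model registered for Tasks.action_detection; the video path is a placeholder
detector = pipeline(
    Tasks.action_detection,
    model='damo/cv_ResNetC3D_action-detection_detection2d')
result = detector('path/to/some_video.mp4')

# the four lists are index-aligned: one entry per detected action instance
for t, label, score, box in zip(result[OutputKeys.TIMESTAMPS],
                                result[OutputKeys.LABELS],
                                result[OutputKeys.SCORES],
                                result[OutputKeys.BOXES]):
    print(f'{t}s: {label} ({score:.2f}) at {box}')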
@@ -58,6 +58,7 @@ class CVTasks(object):
    # video recognition
    live_category = 'live-category'
    action_recognition = 'action-recognition'
    action_detection = 'action-detection'
    video_category = 'video-category'
    video_embedding = 'video-embedding'
    virtual_try_on = 'virtual-try-on'
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class ActionDetectionTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        action_detection_pipeline = pipeline(
            Tasks.action_detection,
            model='damo/cv_ResNetC3D_action-detection_detection2d')
        result = action_detection_pipeline(
            'data/test/videos/action_detection_test_video.mp4')
        print('action detection results:', result)


if __name__ == '__main__':
    unittest.main()