Add a new action-detection task. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9898947 (target branch: master)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b7c3bc7c82ea5fee9d83130041df01046d89143ff77058b04577455ff6fdc92
size 3191059
@@ -133,6 +133,7 @@ class Pipelines(object):
    skin_retouching = 'unet-skin-retouching'
    tinynas_classification = 'tinynas-classification'
    crowd_counting = 'hrnet-crowd-counting'
    action_detection = 'ResNetC3D-action-detection'
    video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
    image_panoptic_segmentation = 'image-panoptic-segmentation'
    video_summarization = 'googlenet_pgl_video_summarization'
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .action_detection_onnx import ActionDetONNX
else:
    _import_structure = {'action_detection_onnx': ['ActionDetONNX']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
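Side note: LazyImportModule defers loading action_detection_onnx (and its onnxruntime dependency) until the symbol is actually requested. A minimal sketch of what that means for callers, assuming the usual lazy-module behaviour; the alias name is illustrative:

# Importing the package is cheap; the heavy submodule is not loaded yet.
import modelscope.models.cv.action_detection as action_detection

# The attribute lookup is what triggers the real import of
# action_detection_onnx, and with it onnxruntime.
ActionDetONNX = action_detection.ActionDetONNX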
@@ -0,0 +1,177 @@
import os
import os.path as osp
import shutil
import subprocess

import cv2
import numpy as np
import onnxruntime as rt

from modelscope.models import Model
from modelscope.utils.constant import Devices
from modelscope.utils.device import verify_device
class ActionDetONNX(Model):

    def __init__(self, model_dir, config, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        model_file = config['model_file']
        device_type, device_id = verify_device(self._device_name)
        # single-threaded ONNX Runtime session; pick the CUDA or CPU provider
        # depending on the configured device
        options = rt.SessionOptions()
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        if device_type == Devices.gpu:
            sess = rt.InferenceSession(
                model_file,
                providers=['CUDAExecutionProvider'],
                sess_options=options,
                provider_options=[{
                    'device_id': device_id
                }])
        else:
            sess = rt.InferenceSession(
                model_file,
                providers=['CPUExecutionProvider'],
                sess_options=options)
        self.input_name = sess.get_inputs()[0].name
        self.sess = sess
        self.num_stride = len(config['fpn_strides'])
        self.score_thresh = np.asarray(
            config['pre_nms_thresh'], dtype='float32').reshape((1, -1))
        self.size_divisibility = config['size_divisibility']
        self.nms_threshold = config['nms_thresh']
        self.tmp_dir = config['tmp_dir']
        self.temporal_stride = config['step']
        self.input_data_type = config['input_type']
        self.action_names = config['action_names']
        self.video_length_limit = config['video_length_limit']
    def resize_box(self, det, height, width, scale_h, scale_w):
        # map boxes from the resized network input back to the original frame size
        bboxs = det[0]
        bboxs[:, [0, 2]] *= scale_w
        bboxs[:, [1, 3]] *= scale_h
        bboxs[:, [0, 2]] = bboxs[:, [0, 2]].clip(0, width - 1)
        bboxs[:, [1, 3]] = bboxs[:, [1, 3]].clip(0, height - 1)
        result = {
            'boxes': bboxs.round().astype('int32').tolist(),
            'scores': det[1].tolist(),
            'labels': [self.action_names[i] for i in det[2].tolist()]
        }
        return result

    def parse_frames(self, frame_names):
        # read frames as RGB and stack them into a (1, c, t, h, w) clip tensor
        imgs = [cv2.imread(name)[:, :, ::-1] for name in frame_names]
        imgs = np.stack(imgs).astype(self.input_data_type).transpose(
            (3, 0, 1, 2))  # c,t,h,w
        imgs = imgs[None]
        return imgs
    def forward_img(self, imgs, h, w):
        pred = self.sess.run(None, {
            self.input_name: imgs,
            'height': np.asarray(h),
            'width': np.asarray(w)
        })
        dets = self.post_nms(
            pred,
            score_threshold=self.score_thresh,
            nms_threshold=self.nms_threshold)
        return dets
    def forward_video(self, video_name, scale):
        min_size, max_size = self._get_sizes(scale)
        tmp_dir = osp.join(self.tmp_dir,
                           osp.splitext(osp.basename(video_name))[0])
        if osp.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        os.makedirs(tmp_dir)
        # extract frames with ffmpeg at a fixed rate, capped at
        # video_length_limit seconds of input
        frame_rate = 2
        cmd = f'ffmpeg -y -loglevel quiet -ss 0 -t {self.video_length_limit}' + \
              f' -i {video_name} -r {frame_rate} -f image2 {tmp_dir}/%06d.jpg'
        cmd = cmd.split(' ')
        subprocess.call(cmd)
        frame_names = [
            osp.join(tmp_dir, name) for name in sorted(os.listdir(tmp_dir))
            if name.endswith('.jpg')
        ]
        # group frames into overlapping clips of frame_rate * 2 frames,
        # advancing frame_rate * temporal_stride frames per clip
        frame_names = [
            frame_names[i:i + frame_rate * 2]
            for i in range(0,
                           len(frame_names) - frame_rate * 2 + 1, frame_rate
                           * self.temporal_stride)
        ]
        timestamp = list(
            range(1,
                  len(frame_names) * self.temporal_stride,
                  self.temporal_stride))
        batch_imgs = [self.parse_frames(names) for names in frame_names]
        N, _, T, H, W = batch_imgs[0].shape
        # pick the network input size: scale the short side to min_size, cap the
        # long side at max_size, then round to a multiple of size_divisibility
        scale_min = min_size / min(H, W)
        h, w = min(int(scale_min * H),
                   max_size), min(int(scale_min * W), max_size)
        h = round(h / self.size_divisibility) * self.size_divisibility
        w = round(w / self.size_divisibility) * self.size_divisibility
        scale_h, scale_w = H / h, W / w
        results = []
        for imgs in batch_imgs:
            det = self.forward_img(imgs, h, w)
            det = self.resize_box(det[0], H, W, scale_h, scale_w)
            results.append(det)
        results = [{
            'timestamp': t,
            'actions': res
        } for t, res in zip(timestamp, results)]
        shutil.rmtree(tmp_dir)
        return results

    def forward(self, video_name):
        return self.forward_video(video_name, scale=1)
    def post_nms(self, pred, score_threshold, nms_threshold=0.3):
        pred_bboxes, pred_scores = pred
        N = len(pred_bboxes)
        dets = []
        for i in range(N):
            bboxes, scores = pred_bboxes[i], pred_scores[i]
            # keep only (box, class) pairs whose score exceeds the per-class threshold
            candidate_inds = scores > score_threshold
            scores = scores[candidate_inds]
            candidate_nonzeros = candidate_inds.nonzero()
            bboxes = bboxes[candidate_nonzeros[0]]
            labels = candidate_nonzeros[1]
            keep = self._nms(bboxes, scores, labels, nms_threshold)
            bbox = bboxes[keep]
            score = scores[keep]
            label = labels[keep]
            dets.append((bbox, score, label))
        return dets
    def _nms(self, boxes, scores, idxs, nms_threshold):
        if len(boxes) == 0:
            return []
        # class-aware NMS: shift each class's boxes by a class-dependent offset so
        # that boxes of different classes never suppress each other
        max_coordinate = boxes.max()
        offsets = idxs * (max_coordinate + 1)
        boxes_for_nms = boxes + offsets[:, None].astype('float32')
        # cv2.dnn.NMSBoxes expects (x, y, w, h) boxes, so convert from (x1, y1, x2, y2)
        boxes_for_nms[:, 2] = boxes_for_nms[:, 2] - boxes_for_nms[:, 0]
        boxes_for_nms[:, 3] = boxes_for_nms[:, 3] - boxes_for_nms[:, 1]
        keep = cv2.dnn.NMSBoxes(
            boxes_for_nms.tolist(),
            scores.tolist(),
            score_threshold=0,
            nms_threshold=nms_threshold)
        if len(keep.shape) == 2:
            # some OpenCV versions return an (N, 1) array; flatten it
            keep = np.squeeze(keep, 1)
        return keep

    def _get_sizes(self, scale):
        # (short-side, long-side) input resolution presets
        if scale == 1:
            min_size, max_size = 512, 896
        elif scale == 2:
            min_size, max_size = 768, 1280
        else:
            min_size, max_size = 1024, 1792
        return min_size, max_size
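For reference, a minimal sketch of the MODEL config section that ActionDetONNX reads; the key names come from the config lookups in __init__ above, while the values are illustrative assumptions rather than the shipped configuration:

# Illustrative only: keys mirror the config reads in ActionDetONNX.__init__;
# values are placeholder assumptions, not the released model's settings.
model_config = {
    'model_file': '/path/to/model.onnx',     # ONNX graph loaded by onnxruntime
    'fpn_strides': [8, 16, 32],              # one entry per FPN level (only the count is used)
    'pre_nms_thresh': [0.4, 0.4],            # per-class score threshold applied before NMS
    'nms_thresh': 0.5,                       # IoU threshold for class-aware NMS
    'size_divisibility': 32,                 # input h/w are rounded to a multiple of this
    'tmp_dir': '/tmp/action_det',            # scratch dir for the extracted frames
    'step': 1,                               # temporal stride between consecutive clips
    'input_type': 'float32',                 # dtype the frame tensor is cast to
    'action_names': ['smoking', 'phoning'],  # index-to-label mapping (placeholders)
    'video_length_limit': 30,                # seconds of video passed to ffmpeg
}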
@@ -35,6 +35,7 @@ class OutputKeys(object):
    UUID = 'uuid'
    WORD = 'word'
    KWS_LIST = 'kws_list'
    TIMESTAMPS = 'timestamps'
    SPLIT_VIDEO_NUM = 'split_video_num'
    SPLIT_META_DICT = 'split_meta_dict'
@@ -541,6 +542,19 @@ TASK_OUTPUTS = {
    # }
    Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

    # action detection result for a single video
    # {
    #   'labels': ['吸烟', '打电话', '吸烟'],  # i.e. smoking, making a phone call, smoking
    #   'scores': [0.7527753114700317, 0.753358006477356, 0.6880350708961487],
    #   'boxes': [[547, 2, 1225, 719], [529, 8, 1255, 719], [584, 0, 1269, 719]],
    #   'timestamps': [1, 3, 5]
    # }
    Tasks.action_detection: [
        OutputKeys.TIMESTAMPS,
        OutputKeys.LABELS,
        OutputKeys.SCORES,
        OutputKeys.BOXES,
    ],
    # {
    #   'output': [
    #       [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
@@ -551,6 +565,7 @@ TASK_OUTPUTS = {
    #        {'label': '13421097', 'score': 2.75914817393641e-06}]]
    # }
    Tasks.faq_question_answering: [OutputKeys.OUTPUT],

    # image person reid result for single sample
    # {
    #   "img_embedding": np.array with shape [1, D],
@@ -71,6 +71,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
    Tasks.action_detection: (Pipelines.action_detection,
                             'damo/cv_ResNetC3D_action-detection_detection2d'),
    Tasks.live_category: (Pipelines.live_category,
                          'damo/cv_resnet50_live-category'),
    Tasks.video_category: (Pipelines.video_category,
@@ -5,6 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .action_detection_pipeline import ActionDetectionPipeline
    from .animal_recognition_pipeline import AnimalRecognitionPipeline
    from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
    from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
        'action_detection_pipeline': ['ActionDetectionPipeline'],
        'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
        'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
        'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
@@ -0,0 +1,63 @@
import math
import os.path as osp
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.action_detection import ActionDetONNX
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()
@PIPELINES.register_module(
    Tasks.action_detection, module_name=Pipelines.action_detection)
class ActionDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create an action detection pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.ONNX_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        self.cfg.MODEL.model_file = model_path
        self.model = ActionDetONNX(self.model, self.cfg.MODEL,
                                   self.device_name)
        logger.info('load model done')
    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            video_name = input
        else:
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
        result = {'video_name': video_name}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        preds = self.model.forward(input['video_name'])
        # flatten the per-clip results into parallel lists; each detection keeps
        # the timestamp of the clip it came from
        labels = sum([pred['actions']['labels'] for pred in preds], [])
        scores = sum([pred['actions']['scores'] for pred in preds], [])
        boxes = sum([pred['actions']['boxes'] for pred in preds], [])
        timestamps = sum([[pred['timestamp']] * len(pred['actions']['labels'])
                          for pred in preds], [])
        out = {
            OutputKeys.TIMESTAMPS: timestamps,
            OutputKeys.LABELS: labels,
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: boxes
        }
        return out

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
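A short usage sketch for the new pipeline, using the model id and test video added in this change; reading the output as parallel lists follows from the forward() method above, while the loop itself is illustrative:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Model id and video path are the ones added in this PR; the loop is illustrative.
action_detection = pipeline(
    Tasks.action_detection,
    model='damo/cv_ResNetC3D_action-detection_detection2d')
result = action_detection('data/test/videos/action_detection_test_video.mp4')

# timestamps, labels, scores and boxes are parallel lists, one entry per detection
for t, label, score, box in zip(result['timestamps'], result['labels'],
                                result['scores'], result['boxes']):
    print(f'{t}s: {label} ({score:.2f}) at {box}')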
@@ -58,6 +58,7 @@ class CVTasks(object):
    # video recognition
    live_category = 'live-category'
    action_recognition = 'action-recognition'
    action_detection = 'action-detection'
    video_category = 'video-category'
    video_embedding = 'video-embedding'
    virtual_try_on = 'virtual-try-on'
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class ActionDetectionTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        action_detection_pipeline = pipeline(
            Tasks.action_detection,
            model='damo/cv_ResNetC3D_action-detection_detection2d')
        result = action_detection_pipeline(
            'data/test/videos/action_detection_test_video.mp4')
        print('action detection results:', result)


if __name__ == '__main__':
    unittest.main()