
[to #42322933] Add action-detection model

Add a new action-detection task.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9898947
Branch: master
lllcho.lc, yingda.chen · 3 years ago
Parent commit: 291f8fe68c
10 changed files with 307 additions and 0 deletions
  1. data/test/videos/action_detection_test_video.mp4 (+3, -0)
  2. modelscope/metainfo.py (+1, -0)
  3. modelscope/models/cv/action_detection/__init__.py (+21, -0)
  4. modelscope/models/cv/action_detection/action_detection_onnx.py (+177, -0)
  5. modelscope/outputs.py (+15, -0)
  6. modelscope/pipelines/builder.py (+2, -0)
  7. modelscope/pipelines/cv/__init__.py (+2, -0)
  8. modelscope/pipelines/cv/action_detection_pipeline.py (+63, -0)
  9. modelscope/utils/constant.py (+1, -0)
  10. tests/pipelines/test_action_detection.py (+22, -0)
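
For reference, a minimal sketch of how the new pipeline is invoked; the task name, default model id, test video path and output keys are all taken from the files in this diff:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Default model id registered for Tasks.action_detection in builder.py below.
detector = pipeline(
    Tasks.action_detection,
    model='damo/cv_ResNetC3D_action-detection_detection2d')
result = detector('data/test/videos/action_detection_test_video.mp4')
# `result` holds parallel lists under 'timestamps', 'labels', 'scores', 'boxes'
# (see the Tasks.action_detection entry added to outputs.py).
print(result)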

data/test/videos/action_detection_test_video.mp4 (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b7c3bc7c82ea5fee9d83130041df01046d89143ff77058b04577455ff6fdc92
size 3191059

modelscope/metainfo.py (+1, -0)

@@ -133,6 +133,7 @@ class Pipelines(object):
    skin_retouching = 'unet-skin-retouching'
    tinynas_classification = 'tinynas-classification'
    crowd_counting = 'hrnet-crowd-counting'
    action_detection = 'ResNetC3D-action-detection'
    video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
    image_panoptic_segmentation = 'image-panoptic-segmentation'
    video_summarization = 'googlenet_pgl_video_summarization'


modelscope/models/cv/action_detection/__init__.py (+21, -0)

@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

    from .action_detection_onnx import ActionDetONNX

else:
    _import_structure = {'action_detection_onnx': ['ActionDetONNX']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

modelscope/models/cv/action_detection/action_detection_onnx.py (+177, -0)

@@ -0,0 +1,177 @@
import os
import os.path as osp
import shutil
import subprocess

import cv2
import numpy as np
import onnxruntime as rt

from modelscope.models import Model
from modelscope.utils.constant import Devices
from modelscope.utils.device import verify_device


class ActionDetONNX(Model):

    def __init__(self, model_dir, config, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        model_file = osp.join(config['model_file'])
        device_type, device_id = verify_device(self._device_name)
        # Single-threaded ONNX Runtime session on the requested device.
        options = rt.SessionOptions()
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        if device_type == Devices.gpu:
            sess = rt.InferenceSession(
                model_file,
                providers=['CUDAExecutionProvider'],
                sess_options=options,
                provider_options=[{
                    'device_id': device_id
                }])
        else:
            sess = rt.InferenceSession(
                model_file,
                providers=['CPUExecutionProvider'],
                sess_options=options)
        self.input_name = sess.get_inputs()[0].name
        self.sess = sess
        self.num_stride = len(config['fpn_strides'])
        self.score_thresh = np.asarray(
            config['pre_nms_thresh'], dtype='float32').reshape((1, -1))
        self.size_divisibility = config['size_divisibility']
        self.nms_threshold = config['nms_thresh']
        self.tmp_dir = config['tmp_dir']
        self.temporal_stride = config['step']
        self.input_data_type = config['input_type']
        self.action_names = config['action_names']
        self.video_length_limit = config['video_length_limit']

    def resize_box(self, det, height, width, scale_h, scale_w):
        bboxs = det[0]
        bboxs[:, [0, 2]] *= scale_w
        bboxs[:, [1, 3]] *= scale_h
        bboxs[:, [0, 2]] = bboxs[:, [0, 2]].clip(0, width - 1)
        bboxs[:, [1, 3]] = bboxs[:, [1, 3]].clip(0, height - 1)
        result = {
            'boxes': bboxs.round().astype('int32').tolist(),
            'scores': det[1].tolist(),
            'labels': [self.action_names[i] for i in det[2].tolist()]
        }
        return result

    def parse_frames(self, frame_names):
        imgs = [cv2.imread(name)[:, :, ::-1] for name in frame_names]  # BGR -> RGB
        imgs = np.stack(imgs).astype(self.input_data_type).transpose(
            (3, 0, 1, 2))  # c,t,h,w
        imgs = imgs[None]
        return imgs

    def forward_img(self, imgs, h, w):
        pred = self.sess.run(None, {
            self.input_name: imgs,
            'height': np.asarray(h),
            'width': np.asarray(w)
        })
        dets = self.post_nms(
            pred,
            score_threshold=self.score_thresh,
            nms_threshold=self.nms_threshold)
        return dets

    def forward_video(self, video_name, scale):
        min_size, max_size = self._get_sizes(scale)

        tmp_dir = osp.join(self.tmp_dir, osp.basename(video_name)[:-4])
        if osp.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        os.makedirs(tmp_dir)
        frame_rate = 2
        # Decode at most `video_length_limit` seconds of video into JPEG frames at 2 fps.
        cmd = f'ffmpeg -y -loglevel quiet -ss 0 -t {self.video_length_limit}' + \
            f' -i {video_name} -r {frame_rate} -f image2 {tmp_dir}/%06d.jpg'

        cmd = cmd.split(' ')
        subprocess.call(cmd)

        frame_names = [
            osp.join(tmp_dir, name) for name in sorted(os.listdir(tmp_dir))
            if name.endswith('.jpg')
        ]
        # Group frames into clips of frame_rate * 2 frames, advancing by
        # `temporal_stride` seconds between clip starts.
        frame_names = [
            frame_names[i:i + frame_rate * 2]
            for i in range(0,
                           len(frame_names) - frame_rate * 2 + 1, frame_rate
                           * self.temporal_stride)
        ]
        timestamp = list(
            range(1,
                  len(frame_names) * self.temporal_stride,
                  self.temporal_stride))
        batch_imgs = [self.parse_frames(names) for names in frame_names]

        N, _, T, H, W = batch_imgs[0].shape
        scale_min = min_size / min(H, W)
        h, w = min(int(scale_min * H),
                   max_size), min(int(scale_min * W), max_size)
        h = round(h / self.size_divisibility) * self.size_divisibility
        w = round(w / self.size_divisibility) * self.size_divisibility
        scale_h, scale_w = H / h, W / w

        results = []
        for imgs in batch_imgs:
            det = self.forward_img(imgs, h, w)
            det = self.resize_box(det[0], H, W, scale_h, scale_w)
            results.append(det)
        results = [{
            'timestamp': t,
            'actions': res
        } for t, res in zip(timestamp, results)]
        shutil.rmtree(tmp_dir)
        return results

    def forward(self, video_name):
        return self.forward_video(video_name, scale=1)

    def post_nms(self, pred, score_threshold, nms_threshold=0.3):
        pred_bboxes, pred_scores = pred
        N = len(pred_bboxes)
        dets = []
        for i in range(N):
            bboxes, scores = pred_bboxes[i], pred_scores[i]
            # Per-class score thresholding; nonzero() yields (box index, class index) pairs.
            candidate_inds = scores > score_threshold
            scores = scores[candidate_inds]
            candidate_nonzeros = candidate_inds.nonzero()
            bboxes = bboxes[candidate_nonzeros[0]]
            labels = candidate_nonzeros[1]
            keep = self._nms(bboxes, scores, labels, nms_threshold)
            bbox = bboxes[keep]
            score = scores[keep]
            label = labels[keep]
            dets.append((bbox, score, label))
        return dets

    def _nms(self, boxes, scores, idxs, nms_threshold):
        if len(boxes) == 0:
            return []
        # Offset boxes by class index so NMS is effectively applied per class.
        max_coordinate = boxes.max()
        offsets = idxs * (max_coordinate + 1)
        boxes_for_nms = boxes + offsets[:, None].astype('float32')
        # cv2.dnn.NMSBoxes expects (x, y, w, h) boxes.
        boxes_for_nms[:, 2] = boxes_for_nms[:, 2] - boxes_for_nms[:, 0]
        boxes_for_nms[:, 3] = boxes_for_nms[:, 3] - boxes_for_nms[:, 1]
        keep = cv2.dnn.NMSBoxes(
            boxes_for_nms.tolist(),
            scores.tolist(),
            score_threshold=0,
            nms_threshold=nms_threshold)
        if len(keep.shape) == 2:
            keep = np.squeeze(keep, 1)
        return keep

    def _get_sizes(self, scale):
        if scale == 1:
            min_size, max_size = 512, 896
        elif scale == 2:
            min_size, max_size = 768, 1280
        else:
            min_size, max_size = 1024, 1792
        return min_size, max_size
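
For orientation, these are the keys ActionDetONNX.__init__ reads from `config`; the values below are illustrative placeholders only (the real values ship in the model's configuration file, which the pipeline loads and passes in as cfg.MODEL):

example_model_config = {
    'model_file': 'model.onnx',            # resolved to a full path by the pipeline
    'fpn_strides': [8, 16, 32],            # placeholder; only its length is used
    'pre_nms_thresh': [0.45, 0.45, 0.45],  # per-class score thresholds, reshaped to (1, -1)
    'size_divisibility': 32,               # input height/width rounded to a multiple of this
    'nms_thresh': 0.3,
    'tmp_dir': '/tmp/action_detection',    # scratch dir for frames extracted by ffmpeg
    'step': 2,                             # temporal stride (seconds) between clips
    'input_type': 'float32',
    'action_names': ['吸烟', '打电话'],      # label strings returned in the results
    'video_length_limit': 30,              # seconds decoded by ffmpeg (-t)
}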

modelscope/outputs.py (+15, -0)

@@ -35,6 +35,7 @@ class OutputKeys(object):
    UUID = 'uuid'
    WORD = 'word'
    KWS_LIST = 'kws_list'
    TIMESTAMPS = 'timestamps'
    SPLIT_VIDEO_NUM = 'split_video_num'
    SPLIT_META_DICT = 'split_meta_dict'

@@ -541,6 +542,19 @@ TASK_OUTPUTS = {
    # }
    Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS],

    # {
    #   'labels': ['吸烟', '打电话', '吸烟'],
    #   'scores': [0.7527753114700317, 0.753358006477356, 0.6880350708961487],
    #   'boxes': [[547, 2, 1225, 719], [529, 8, 1255, 719], [584, 0, 1269, 719]],
    #   'timestamps': [1, 3, 5]
    # }
    Tasks.action_detection: [
        OutputKeys.TIMESTAMPS,
        OutputKeys.LABELS,
        OutputKeys.SCORES,
        OutputKeys.BOXES,
    ],

    # {
    #   'output': [
    #       [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509},
@@ -551,6 +565,7 @@ TASK_OUTPUTS = {
    #        {'label': '13421097', 'score': 2.75914817393641e-06}]]
    # }
    Tasks.faq_question_answering: [OutputKeys.OUTPUT],

    # image person reid result for single sample
    # {
    #     "img_embedding": np.array with shape [1, D],

modelscope/pipelines/builder.py (+2, -0)

@@ -71,6 +71,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
    Tasks.action_detection: (Pipelines.action_detection,
                             'damo/cv_ResNetC3D_action-detection_detection2d'),
    Tasks.live_category: (Pipelines.live_category,
                          'damo/cv_resnet50_live-category'),
    Tasks.video_category: (Pipelines.video_category,


modelscope/pipelines/cv/__init__.py (+2, -0)

@@ -5,6 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .action_detection_pipeline import ActionDetectionPipeline
    from .animal_recognition_pipeline import AnimalRecognitionPipeline
    from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
    from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
        'action_detection_pipeline': ['ActionDetectionPipeline'],
        'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
        'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
        'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],


modelscope/pipelines/cv/action_detection_pipeline.py (+63, -0)

@@ -0,0 +1,63 @@
import math
import os.path as osp
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.action_detection import ActionDetONNX
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.action_detection, module_name=Pipelines.action_detection)
class ActionDetectionPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create an action detection pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.ONNX_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')
        self.cfg = Config.from_file(config_path)
        self.cfg.MODEL.model_file = model_path
        self.model = ActionDetONNX(self.model, self.cfg.MODEL,
                                   self.device_name)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        if isinstance(input, str):
            video_name = input
        else:
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
        result = {'video_name': video_name}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        preds = self.model.forward(input['video_name'])
        # Concatenate the per-clip detections into flat, parallel lists.
        labels = sum([pred['actions']['labels'] for pred in preds], [])
        scores = sum([pred['actions']['scores'] for pred in preds], [])
        boxes = sum([pred['actions']['boxes'] for pred in preds], [])
        timestamps = sum([[pred['timestamp']] * len(pred['actions']['labels'])
                          for pred in preds], [])
        out = {
            OutputKeys.TIMESTAMPS: timestamps,
            OutputKeys.LABELS: labels,
            OutputKeys.SCORES: scores,
            OutputKeys.BOXES: boxes
        }
        return out

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
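
To make the flattening in forward() concrete, here is a hypothetical two-clip result from ActionDetONNX.forward and the dict the pipeline would return (label and box values mirror the example added to outputs.py):

preds = [
    {'timestamp': 1, 'actions': {'labels': ['吸烟'],
                                 'scores': [0.75],
                                 'boxes': [[547, 2, 1225, 719]]}},
    {'timestamp': 3, 'actions': {'labels': ['打电话'],
                                 'scores': [0.75],
                                 'boxes': [[529, 8, 1255, 719]]}},
]
# After the sum(..., []) concatenation, the pipeline output is:
# {'timestamps': [1, 3],
#  'labels': ['吸烟', '打电话'],
#  'scores': [0.75, 0.75],
#  'boxes': [[547, 2, 1225, 719], [529, 8, 1255, 719]]}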

modelscope/utils/constant.py (+1, -0)

@@ -58,6 +58,7 @@ class CVTasks(object):
    # video recognition
    live_category = 'live-category'
    action_recognition = 'action-recognition'
    action_detection = 'action-detection'
    video_category = 'video-category'
    video_embedding = 'video-embedding'
    virtual_try_on = 'virtual-try-on'


tests/pipelines/test_action_detection.py (+22, -0)

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class ActionDetectionTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run(self):
        action_detection_pipeline = pipeline(
            Tasks.action_detection,
            model='damo/cv_ResNetC3D_action-detection_detection2d')
        result = action_detection_pipeline(
            'data/test/videos/action_detection_test_video.mp4')
        print('action detection results:', result)


if __name__ == '__main__':
    unittest.main()
