From 291f8fe68c3462abc6462c5e408e7f349203f630 Mon Sep 17 00:00:00 2001
From: "lllcho.lc"
Date: Thu, 1 Sep 2022 18:14:37 +0800
Subject: [PATCH] [to #42322933] Add action-detection model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add a new action-detection task

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9898947
---
 .../videos/action_detection_test_video.mp4    |   3 +
 modelscope/metainfo.py                        |   1 +
 .../models/cv/action_detection/__init__.py    |  21 +++
 .../action_detection/action_detection_onnx.py | 177 ++++++++++++++++++
 modelscope/outputs.py                         |  15 ++
 modelscope/pipelines/builder.py               |   2 +
 modelscope/pipelines/cv/__init__.py           |   2 +
 .../pipelines/cv/action_detection_pipeline.py |  63 +++++++
 modelscope/utils/constant.py                  |   1 +
 tests/pipelines/test_action_detection.py      |  22 +++
 10 files changed, 307 insertions(+)
 create mode 100644 data/test/videos/action_detection_test_video.mp4
 create mode 100644 modelscope/models/cv/action_detection/__init__.py
 create mode 100644 modelscope/models/cv/action_detection/action_detection_onnx.py
 create mode 100644 modelscope/pipelines/cv/action_detection_pipeline.py
 create mode 100644 tests/pipelines/test_action_detection.py

diff --git a/data/test/videos/action_detection_test_video.mp4 b/data/test/videos/action_detection_test_video.mp4
new file mode 100644
index 00000000..e2ea1d80
--- /dev/null
+++ b/data/test/videos/action_detection_test_video.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b7c3bc7c82ea5fee9d83130041df01046d89143ff77058b04577455ff6fdc92
+size 3191059
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 6f34b1a3..7c5afe80 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -133,6 +133,7 @@ class Pipelines(object):
     skin_retouching = 'unet-skin-retouching'
     tinynas_classification = 'tinynas-classification'
     crowd_counting = 'hrnet-crowd-counting'
+    action_detection = 'ResNetC3D-action-detection'
     video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
     image_panoptic_segmentation = 'image-panoptic-segmentation'
     video_summarization = 'googlenet_pgl_video_summarization'
diff --git a/modelscope/models/cv/action_detection/__init__.py b/modelscope/models/cv/action_detection/__init__.py
new file mode 100644
index 00000000..fedbe19c
--- /dev/null
+++ b/modelscope/models/cv/action_detection/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .action_detection_onnx import ActionDetONNX + +else: + _import_structure = {'action_detection_onnx': ['ActionDetONNX']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/action_detection/action_detection_onnx.py b/modelscope/models/cv/action_detection/action_detection_onnx.py new file mode 100644 index 00000000..3c171473 --- /dev/null +++ b/modelscope/models/cv/action_detection/action_detection_onnx.py @@ -0,0 +1,177 @@ +import os +import os.path as osp +import shutil +import subprocess + +import cv2 +import numpy as np +import onnxruntime as rt + +from modelscope.models import Model +from modelscope.utils.constant import Devices +from modelscope.utils.device import verify_device + + +class ActionDetONNX(Model): + + def __init__(self, model_dir, config, *args, **kwargs): + super().__init__(self, model_dir, *args, **kwargs) + model_file = osp.join(config['model_file']) + device_type, device_id = verify_device(self._device_name) + options = rt.SessionOptions() + options.intra_op_num_threads = 1 + options.inter_op_num_threads = 1 + if device_type == Devices.gpu: + sess = rt.InferenceSession( + model_file, + providers=['CUDAExecutionProvider'], + sess_options=options, + provider_options=[{ + 'device_id': device_id + }]) + else: + sess = rt.InferenceSession( + model_file, + providers=['CPUExecutionProvider'], + sess_options=options) + self.input_name = sess.get_inputs()[0].name + self.sess = sess + self.num_stride = len(config['fpn_strides']) + self.score_thresh = np.asarray( + config['pre_nms_thresh'], dtype='float32').reshape((1, -1)) + self.size_divisibility = config['size_divisibility'] + self.nms_threshold = config['nms_thresh'] + self.tmp_dir = config['tmp_dir'] + self.temporal_stride = config['step'] + self.input_data_type = config['input_type'] + self.action_names = config['action_names'] + self.video_length_limit = config['video_length_limit'] + + def resize_box(self, det, height, width, scale_h, scale_w): + bboxs = det[0] + bboxs[:, [0, 2]] *= scale_w + bboxs[:, [1, 3]] *= scale_h + bboxs[:, [0, 2]] = bboxs[:, [0, 2]].clip(0, width - 1) + bboxs[:, [1, 3]] = bboxs[:, [1, 3]].clip(0, height - 1) + result = { + 'boxes': bboxs.round().astype('int32').tolist(), + 'scores': det[1].tolist(), + 'labels': [self.action_names[i] for i in det[2].tolist()] + } + return result + + def parse_frames(self, frame_names): + imgs = [cv2.imread(name)[:, :, ::-1] for name in frame_names] + imgs = np.stack(imgs).astype(self.input_data_type).transpose( + (3, 0, 1, 2)) # c,t,h,w + imgs = imgs[None] + return imgs + + def forward_img(self, imgs, h, w): + pred = self.sess.run(None, { + self.input_name: imgs, + 'height': np.asarray(h), + 'width': np.asarray(w) + }) + dets = self.post_nms( + pred, + score_threshold=self.score_thresh, + nms_threshold=self.nms_threshold) + return dets + + def forward_video(self, video_name, scale): + min_size, max_size = self._get_sizes(scale) + + tmp_dir = osp.join(self.tmp_dir, osp.basename(video_name)[:-4]) + if osp.exists(tmp_dir): + shutil.rmtree(tmp_dir) + os.makedirs(tmp_dir) + frame_rate = 2 + cmd = f'ffmpeg -y -loglevel quiet -ss 0 -t {self.video_length_limit}' + \ + f' -i {video_name} -r {frame_rate} -f image2 {tmp_dir}/%06d.jpg' + + cmd = cmd.split(' ') + subprocess.call(cmd) + + frame_names = [ + 
osp.join(tmp_dir, name) for name in sorted(os.listdir(tmp_dir)) + if name.endswith('.jpg') + ] + frame_names = [ + frame_names[i:i + frame_rate * 2] + for i in range(0, + len(frame_names) - frame_rate * 2 + 1, frame_rate + * self.temporal_stride) + ] + timestamp = list( + range(1, + len(frame_names) * self.temporal_stride, + self.temporal_stride)) + batch_imgs = [self.parse_frames(names) for names in frame_names] + + N, _, T, H, W = batch_imgs[0].shape + scale_min = min_size / min(H, W) + h, w = min(int(scale_min * H), + max_size), min(int(scale_min * W), max_size) + h = round(h / self.size_divisibility) * self.size_divisibility + w = round(w / self.size_divisibility) * self.size_divisibility + scale_h, scale_w = H / h, W / w + + results = [] + for imgs in batch_imgs: + det = self.forward_img(imgs, h, w) + det = self.resize_box(det[0], H, W, scale_h, scale_w) + results.append(det) + results = [{ + 'timestamp': t, + 'actions': res + } for t, res in zip(timestamp, results)] + shutil.rmtree(tmp_dir) + return results + + def forward(self, video_name): + return self.forward_video(video_name, scale=1) + + def post_nms(self, pred, score_threshold, nms_threshold=0.3): + pred_bboxes, pred_scores = pred + N = len(pred_bboxes) + dets = [] + for i in range(N): + bboxes, scores = pred_bboxes[i], pred_scores[i] + candidate_inds = scores > score_threshold + scores = scores[candidate_inds] + candidate_nonzeros = candidate_inds.nonzero() + bboxes = bboxes[candidate_nonzeros[0]] + labels = candidate_nonzeros[1] + keep = self._nms(bboxes, scores, labels, nms_threshold) + bbox = bboxes[keep] + score = scores[keep] + label = labels[keep] + dets.append((bbox, score, label)) + return dets + + def _nms(self, boxes, scores, idxs, nms_threshold): + if len(boxes) == 0: + return [] + max_coordinate = boxes.max() + offsets = idxs * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None].astype('float32') + boxes_for_nms[:, 2] = boxes_for_nms[:, 2] - boxes_for_nms[:, 0] + boxes_for_nms[:, 3] = boxes_for_nms[:, 3] - boxes_for_nms[:, 1] + keep = cv2.dnn.NMSBoxes( + boxes_for_nms.tolist(), + scores.tolist(), + score_threshold=0, + nms_threshold=nms_threshold) + if len(keep.shape) == 2: + keep = np.squeeze(keep, 1) + return keep + + def _get_sizes(self, scale): + if scale == 1: + min_size, max_size = 512, 896 + elif scale == 2: + min_size, max_size = 768, 1280 + else: + min_size, max_size = 1024, 1792 + return min_size, max_size diff --git a/modelscope/outputs.py b/modelscope/outputs.py index aebb9138..7d6cdb59 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -35,6 +35,7 @@ class OutputKeys(object): UUID = 'uuid' WORD = 'word' KWS_LIST = 'kws_list' + TIMESTAMPS = 'timestamps' SPLIT_VIDEO_NUM = 'split_video_num' SPLIT_META_DICT = 'split_meta_dict' @@ -541,6 +542,19 @@ TASK_OUTPUTS = { # } Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS], + # { + # 'labels': ['吸烟', '打电话', '吸烟'], + # 'scores': [0.7527753114700317, 0.753358006477356, 0.6880350708961487], + # 'boxes': [[547, 2, 1225, 719], [529, 8, 1255, 719], [584, 0, 1269, 719]], + # 'timestamps': [1, 3, 5] + # } + Tasks.action_detection: [ + OutputKeys.TIMESTAMPS, + OutputKeys.LABELS, + OutputKeys.SCORES, + OutputKeys.BOXES, + ], + # { # 'output': [ # [{'label': '6527856', 'score': 0.9942756295204163}, {'label': '1000012000', 'score': 0.0379515215754509}, @@ -551,6 +565,7 @@ TASK_OUTPUTS = { # {'label': '13421097', 'score': 2.75914817393641e-06}]] # } Tasks.faq_question_answering: [OutputKeys.OUTPUT], + # image person reid result for 
single sample
    # {
    #   "img_embedding": np.array with shape [1, D],
    # }
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 8a1a3646..c9f0c252 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -71,6 +71,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
     Tasks.action_recognition: (Pipelines.action_recognition,
                                'damo/cv_TAdaConv_action-recognition'),
+    Tasks.action_detection: (Pipelines.action_detection,
+                             'damo/cv_ResNetC3D_action-detection_detection2d'),
     Tasks.live_category: (Pipelines.live_category,
                           'damo/cv_resnet50_live-category'),
     Tasks.video_category: (Pipelines.video_category,
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 01c69758..f4e6792b 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -5,6 +5,7 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .action_recognition_pipeline import ActionRecognitionPipeline
+    from .action_detection_pipeline import ActionDetectionPipeline
     from .animal_recognition_pipeline import AnimalRecognitionPipeline
     from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
     from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
 else:
     _import_structure = {
         'action_recognition_pipeline': ['ActionRecognitionPipeline'],
+        'action_detection_pipeline': ['ActionDetectionPipeline'],
         'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
         'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
         'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
diff --git a/modelscope/pipelines/cv/action_detection_pipeline.py b/modelscope/pipelines/cv/action_detection_pipeline.py
new file mode 100644
index 00000000..72335d5b
--- /dev/null
+++ b/modelscope/pipelines/cv/action_detection_pipeline.py
@@ -0,0 +1,63 @@
+import math
+import os.path as osp
+from typing import Any, Dict
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.action_detection import ActionDetONNX
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.action_detection, module_name=Pipelines.action_detection)
+class ActionDetectionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Use `model` to create an action detection pipeline for prediction.
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        model_path = osp.join(self.model, ModelFile.ONNX_MODEL_FILE)
+        logger.info(f'loading model from {model_path}')
+        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
+        logger.info(f'loading config from {config_path}')
+        self.cfg = Config.from_file(config_path)
+        self.cfg.MODEL.model_file = model_path
+        self.model = ActionDetONNX(self.model, self.cfg.MODEL,
+                                   self.device_name)
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            video_name = input
+        else:
+            raise TypeError(f'input should be a str,'
+                            f' but got {type(input)}')
+        result = {'video_name': video_name}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        preds = self.model.forward(input['video_name'])
+        labels = sum([pred['actions']['labels'] for pred in preds], [])
+        scores = sum([pred['actions']['scores'] for pred in preds], [])
+        boxes = sum([pred['actions']['boxes'] for pred in preds], [])
+        timestamps = sum([[pred['timestamp']] * len(pred['actions']['labels'])
+                          for pred in preds], [])
+        out = {
+            OutputKeys.TIMESTAMPS: timestamps,
+            OutputKeys.LABELS: labels,
+            OutputKeys.SCORES: scores,
+            OutputKeys.BOXES: boxes
+        }
+        return out
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 960e9600..2265ef5a 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -58,6 +58,7 @@ class CVTasks(object):
     # video recognition
     live_category = 'live-category'
     action_recognition = 'action-recognition'
+    action_detection = 'action-detection'
     video_category = 'video-category'
     video_embedding = 'video-embedding'
     virtual_try_on = 'virtual-try-on'
diff --git a/tests/pipelines/test_action_detection.py b/tests/pipelines/test_action_detection.py
new file mode 100644
index 00000000..c752dc78
--- /dev/null
+++ b/tests/pipelines/test_action_detection.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class ActionDetectionTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run(self):
+        action_detection_pipeline = pipeline(
+            Tasks.action_detection,
+            model='damo/cv_ResNetC3D_action-detection_detection2d')
+        result = action_detection_pipeline(
+            'data/test/videos/action_detection_test_video.mp4')
+        print('action detection results:', result)
+
+
+if __name__ == '__main__':
+    unittest.main()
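
ActionDetONNX builds its onnxruntime session once in __init__, picking CUDAExecutionProvider with an explicit device_id when a GPU is requested and falling back to CPUExecutionProvider otherwise, with intra- and inter-op threads both pinned to 1. A condensed sketch of that provider selection; the helper name and the 'model.onnx' path are placeholders, not part of the patch:

import onnxruntime as rt


def build_session(model_file: str, use_gpu: bool, device_id: int = 0):
    """Create an onnxruntime session mirroring ActionDetONNX.__init__."""
    options = rt.SessionOptions()
    options.intra_op_num_threads = 1  # keep the session single threaded
    options.inter_op_num_threads = 1
    if use_gpu:
        return rt.InferenceSession(
            model_file,
            sess_options=options,
            providers=['CUDAExecutionProvider'],
            provider_options=[{'device_id': device_id}])
    return rt.InferenceSession(
        model_file,
        sess_options=options,
        providers=['CPUExecutionProvider'])

# sess = build_session('model.onnx', use_gpu=False)
# print(sess.get_inputs()[0].name)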
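
ActionDetONNX.forward_video decodes the input with ffmpeg at 2 frames per second, groups the decoded JPEGs into overlapping clips of frame_rate * 2 frames (2 seconds of video) with a hop of frame_rate * step frames, and attaches a timestamp in seconds to each clip. A minimal sketch of that windowing arithmetic, with placeholder frame names standing in for the JPEGs ffmpeg writes to the temp directory; the timestamp bookkeeping below is a simplified reading of the code above, not a verbatim copy:

# Sketch of the clip windowing in ActionDetONNX.forward_video.
frame_rate = 2       # frames extracted per second of video
temporal_stride = 1  # corresponds to config['step']

# Placeholder names for the JPEGs ffmpeg writes at 2 fps (6 s of video).
frame_names = ['%06d.jpg' % i for i in range(1, 13)]

window = frame_rate * 2             # 4 frames = 2 seconds per clip
hop = frame_rate * temporal_stride  # frames between clip starts

clips = [
    frame_names[i:i + window]
    for i in range(0, len(frame_names) - window + 1, hop)
]
# One timestamp (in seconds) per clip, advancing by temporal_stride.
timestamps = [1 + i * temporal_stride for i in range(len(clips))]

for t, clip in zip(timestamps, clips):
    print(t, clip[0], '..', clip[-1])
# 1 000001.jpg .. 000004.jpg
# 2 000003.jpg .. 000006.jpg
# ...
# 5 000009.jpg .. 000012.jpg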
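
The _nms helper above suppresses overlapping boxes of the same label with a single cv2.dnn.NMSBoxes call: every box is shifted by label * (max_coordinate + 1), so boxes with different labels can never overlap and therefore never suppress one another. A standalone sketch of that trick; the function name and the toy boxes are illustrative, not part of the patch:

import cv2
import numpy as np


def class_aware_nms(boxes_xyxy, scores, labels, nms_threshold=0.3):
    """Run one NMS pass over all classes by offsetting boxes per label.

    boxes_xyxy: float32 array of shape (N, 4) in x1, y1, x2, y2 order.
    scores:     float32 array of shape (N,).
    labels:     integer array of shape (N,).
    """
    if len(boxes_xyxy) == 0:
        return np.empty((0, ), dtype='int64')
    # Shift every box by label * (max_coordinate + 1) so boxes with
    # different labels never overlap and are never suppressed together.
    max_coordinate = boxes_xyxy.max()
    offsets = labels * (max_coordinate + 1)
    boxes_for_nms = boxes_xyxy + offsets[:, None].astype('float32')
    # cv2.dnn.NMSBoxes expects x, y, w, h boxes.
    boxes_for_nms[:, 2] -= boxes_for_nms[:, 0]
    boxes_for_nms[:, 3] -= boxes_for_nms[:, 1]
    keep = cv2.dnn.NMSBoxes(
        boxes_for_nms.tolist(),
        scores.tolist(),
        score_threshold=0,
        nms_threshold=nms_threshold)
    # Older OpenCV returns shape (K, 1); flatten either way.
    return np.asarray(keep).reshape(-1)


boxes = np.asarray([[10, 10, 100, 100], [12, 12, 98, 99], [11, 9, 99, 101]],
                   dtype='float32')
scores = np.asarray([0.9, 0.8, 0.7], dtype='float32')
labels = np.asarray([0, 0, 1])  # the third box carries a different label
print(class_aware_nms(boxes, scores, labels))  # keeps box 0 and box 2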
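
As documented in the new Tasks.action_detection entry of TASK_OUTPUTS, the pipeline returns four parallel lists with one entry per detection. A short usage sketch showing how a caller might zip them back into per-detection records; it assumes the 'damo/cv_ResNetC3D_action-detection_detection2d' model is reachable and reuses the test video added by this patch:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(
    Tasks.action_detection,
    model='damo/cv_ResNetC3D_action-detection_detection2d')
result = detector('data/test/videos/action_detection_test_video.mp4')

# The four output lists are parallel: entry i of each list describes the
# same detection. Zip them back into one record per detection.
detections = [
    {'timestamp': t, 'label': label, 'score': score, 'box': box}
    for t, label, score, box in zip(
        result[OutputKeys.TIMESTAMPS], result[OutputKeys.LABELS],
        result[OutputKeys.SCORES], result[OutputKeys.BOXES])
]
for det in detections:
    print(det)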