
[to #42322933] add PST action recognition model

Add a patch shift transformer (PST) model for the action recognition task.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10282964
master
lllcho.lc yingda.chen · 3 years ago
commit 9fa761d7a6

5 changed files with 1262 additions and 1 deletion
  1. +1    -0  modelscope/metainfo.py
  2. +2    -0  modelscope/models/cv/action_recognition/__init__.py
  3. +1198 -0  modelscope/models/cv/action_recognition/temporal_patch_shift_transformer.py
  4. +53   -1  modelscope/pipelines/cv/action_recognition_pipeline.py
  5. +8    -0  tests/pipelines/test_action_recognition.py

+1 -0  modelscope/metainfo.py

@@ -179,6 +179,7 @@ class Pipelines(object):
     movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
     shop_segmentation = 'shop-segmentation'
     video_inpainting = 'video-inpainting'
+    pst_action_recognition = 'patchshift-action-recognition'
     hand_static = 'hand-static'
 
     # nlp tasks


+2 -0  modelscope/models/cv/action_recognition/__init__.py

@@ -7,11 +7,13 @@ if TYPE_CHECKING:
     from .models import BaseVideoModel
     from .tada_convnext import TadaConvNeXt
+    from .temporal_patch_shift_transformer import PatchShiftTransformer
 
 else:
     _import_structure = {
         'models': ['BaseVideoModel'],
         'tada_convnext': ['TadaConvNeXt'],
+        'temporal_patch_shift_transformer': ['PatchShiftTransformer']
     }
 
     import sys
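The TYPE_CHECKING branch gives type checkers and IDEs real imports, while the else branch only records where each name lives so the actual import can be deferred. A minimal sketch of the idea, assuming a simplified stand-in for ModelScope's real lazy-module helper (this is not the library's actual class):

import importlib
from types import ModuleType


class LazyModule(ModuleType):
    """Simplified stand-in: resolve submodules on first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {'submodule': ['ExportedName', ...]} into
        # {'ExportedName': 'submodule'} for direct lookups.
        self._name_to_module = {
            attr: mod
            for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        if attr in self._name_to_module:
            submodule = importlib.import_module(
                f'{self.__name__}.{self._name_to_module[attr]}')
            return getattr(submodule, attr)
        raise AttributeError(
            f'module {self.__name__!r} has no attribute {attr!r}')

An instance of such a class is installed into sys.modules for the package, so PatchShiftTransformer above is only actually imported the first time something asks for it.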


+1198 -0  modelscope/models/cv/action_recognition/temporal_patch_shift_transformer.py
File diff suppressed because it is too large.
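Since the model file itself is collapsed, for orientation: temporal patch shift (the technique behind PST, from "Spatiotemporal Self-attention Modeling with Temporal Patch Shift for Action Recognition", ECCV 2022) swaps a sparse pattern of patch tokens between neighboring frames before spatial self-attention, so a 2D transformer block captures temporal dependencies with no extra parameters or FLOPs. Below is a minimal sketch of the shift step only, assuming a (batch, time, patches, channels) token layout; the shift pattern is illustrative, not the committed 1198-line implementation:

import torch


def temporal_patch_shift(x: torch.Tensor) -> torch.Tensor:
    """Shift a sparse subset of patch tokens across neighboring frames.

    x: (B, T, N, C) patch tokens for B clips of T frames with N patches each.
    """
    out = x.clone()
    # Every 4th patch position takes its token from the previous frame...
    out[:, 1:, 0::4] = x[:, :-1, 0::4]
    # ...and an interleaved set takes its token from the next frame, so the
    # plain spatial attention that follows sees a mix of three time steps.
    out[:, :-1, 2::4] = x[:, 1:, 2::4]
    return out


tokens = torch.randn(2, 8, 49, 96)  # 2 clips, 8 frames, 7x7 patches, C=96
assert temporal_patch_shift(tokens).shape == tokens.shape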


+53 -1  modelscope/pipelines/cv/action_recognition_pipeline.py

@@ -7,7 +7,8 @@ from typing import Any, Dict
 import torch
 
 from modelscope.metainfo import Pipelines
-from modelscope.models.cv.action_recognition import BaseVideoModel
+from modelscope.models.cv.action_recognition import (BaseVideoModel,
+                                                     PatchShiftTransformer)
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
@@ -69,3 +70,54 @@ class ActionRecognitionPipeline(Pipeline):
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         return inputs
+
+
+@PIPELINES.register_module(
+    Tasks.action_recognition, module_name=Pipelines.pst_action_recognition)
+class PSTActionRecognitionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """Use `model` to create a PST action recognition pipeline.
+
+        Args:
+            model: The model id on the ModelScope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
+        logger.info(f'loading model from {model_path}')
+        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
+        logger.info(f'loading config from {config_path}')
+        self.cfg = Config.from_file(config_path)
+        self.infer_model = PatchShiftTransformer(model).to(self.device)
+        self.infer_model.eval()
+        self.infer_model.load_state_dict(
+            torch.load(model_path, map_location=self.device)['state_dict'])
+        self.label_mapping = self.cfg.label_mapping
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            video_input_data = ReadVideoData(self.cfg, input).to(self.device)
+        else:
+            raise TypeError(f'input should be a str,'
+                            f' but got {type(input)}')
+        result = {'video_data': video_input_data}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        pred = self.perform_inference(input['video_data'])
+        output_label = self.label_mapping[str(pred)]
+        return {OutputKeys.LABELS: output_label}
+
+    @torch.no_grad()
+    def perform_inference(self, data, max_bsz=4):
+        iter_num = math.ceil(data.size(0) / max_bsz)
+        preds_list = []
+        for i in range(iter_num):
+            preds_list.append(
+                self.infer_model(data[i * max_bsz:(i + 1) * max_bsz]))
+        pred = torch.cat(preds_list, dim=0)
+        return pred.mean(dim=0).argmax().item()
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
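With the pst_action_recognition constant in metainfo.py and the register_module decorator above, the new pipeline is reachable through the standard builder. perform_inference runs the sampled clips in chunks of at most max_bsz to bound memory, then averages the logits over all clips before the argmax. A usage sketch, assuming the model repo's configuration.json selects this pipeline (the model id is the one exercised by the test below):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

recognizer = pipeline(
    Tasks.action_recognition, model='damo/cv_pathshift_action-recognition')
result = recognizer('data/test/videos/action_recognition_test_video.mp4')
print(result)  # {'labels': '<predicted action label>'}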

+8 -0  tests/pipelines/test_action_recognition.py

@@ -29,6 +29,14 @@ class ActionRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
         print(f'recognition output: {result}.')
 
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_pst(self):
+        pst_recognition_pipeline = pipeline(
+            self.task, model='damo/cv_pathshift_action-recognition')
+        result = pst_recognition_pipeline(
+            'data/test/videos/action_recognition_test_video.mp4')
+        print('pst recognition results:', result)
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_demo_compatibility(self):
         self.compatibility_check()

