|
@@ -7,7 +7,8 @@ from typing import Any, Dict
 import torch
 
 from modelscope.metainfo import Pipelines
-from modelscope.models.cv.action_recognition import BaseVideoModel
+from modelscope.models.cv.action_recognition import (BaseVideoModel,
+                                                     PatchShiftTransformer)
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
@@ -69,3 +70,56 @@ class ActionRecognitionPipeline(Pipeline):
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         return inputs
+
+
+@PIPELINES.register_module(
+    Tasks.action_recognition, module_name=Pipelines.pst_action_recognition)
+class PSTActionRecognitionPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        Use `model` to create a PST action recognition pipeline for prediction.
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
+        logger.info(f'loading model from {model_path}')
+        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
+        logger.info(f'loading config from {config_path}')
+        self.cfg = Config.from_file(config_path)
+        self.infer_model = PatchShiftTransformer(model).to(self.device)
+        self.infer_model.eval()
+        self.infer_model.load_state_dict(
+            torch.load(model_path, map_location=self.device)['state_dict'])
+        self.label_mapping = self.cfg.label_mapping
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if isinstance(input, str):
+            video_input_data = ReadVideoData(self.cfg, input).to(self.device)
+        else:
+            raise TypeError(f'input should be a str,'
+                            f' but got {type(input)}')
+        result = {'video_data': video_input_data}
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        pred = self.perform_inference(input['video_data'])
+        output_label = self.label_mapping[str(pred)]
+        return {OutputKeys.LABELS: output_label}
+
+    @torch.no_grad()
+    def perform_inference(self, data, max_bsz=4):
+        # Run clips through the model in mini-batches of at most max_bsz.
+        iter_num = math.ceil(data.size(0) / max_bsz)
+        preds_list = []
+        for i in range(iter_num):
+            preds_list.append(
+                self.infer_model(data[i * max_bsz:(i + 1) * max_bsz]))
+        pred = torch.cat(preds_list, dim=0)
+        # Average per-clip predictions and return the top class index.
+        return pred.mean(dim=0).argmax().item()
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
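
For reference, a minimal usage sketch of the new pipeline follows. It is an illustration under stated assumptions, not part of the diff: the model id 'damo/cv_pathshift_action-recognition' and the video path are placeholders for whatever PST action-recognition model actually lives on the ModelScope hub.

# Usage sketch (assumed model id and video path; not taken from this diff).
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# pipeline() reads the model's configuration and, via the PIPELINES
# registry, instantiates the PSTActionRecognitionPipeline added above.
recognizer = pipeline(
    Tasks.action_recognition,
    model='damo/cv_pathshift_action-recognition')  # placeholder hub id

# Runs preprocess -> forward -> postprocess on the given video source.
result = recognizer('data/videos/test.mp4')  # placeholder path
print(result[OutputKeys.LABELS])

Note that preprocess only accepts a str video source; perform_inference pushes the sampled clips through the model at most max_bsz=4 at a time and averages their predictions before the argmax, so longer videos do not increase peak GPU memory.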