
[to #42322933] Add cv-action-recognition-pipeline to maas lib

Merge the DAMO action recognition model into the MaaS lib
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9134444
master
yongfei.zyf 3 years ago
commit ace8af9246
13 changed files with 930 additions and 0 deletions
  1. BIN      data/test/videos/action_recognition_test_video.mp4
  2. +1 -0    modelscope/metainfo.py
  3. +0 -0    modelscope/models/cv/action_recognition/__init__.py
  4. +91 -0   modelscope/models/cv/action_recognition/models.py
  5. +472 -0  modelscope/models/cv/action_recognition/tada_convnext.py
  6. +2 -0    modelscope/pipelines/builder.py
  7. +1 -0    modelscope/pipelines/cv/__init__.py
  8. +65 -0   modelscope/pipelines/cv/action_recognition_pipeline.py
  9. +6 -0    modelscope/pipelines/outputs.py
  10. +232 -0 modelscope/preprocessors/video.py
  11. +1 -0   modelscope/utils/constant.py
  12. +1 -0   requirements/cv.txt
  13. +58 -0  tests/pipelines/test_action_recognition.py

BIN  data/test/videos/action_recognition_test_video.mp4


+1 -0  modelscope/metainfo.py

@@ -39,6 +39,7 @@ class Pipelines(object):
    image_matting = 'unet-image-matting'
    person_image_cartoon = 'unet-person-image-cartoon'
    ocr_detection = 'resnet18-ocr-detection'
    action_recognition = 'TAdaConv_action-recognition'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'


+0 -0  modelscope/models/cv/action_recognition/__init__.py


+91 -0  modelscope/models/cv/action_recognition/models.py

@@ -0,0 +1,91 @@
import torch
import torch.nn as nn

from .tada_convnext import TadaConvNeXt


class BaseVideoModel(nn.Module):
"""
    Standard video model.
    The model is divided into a backbone and a head, where the backbone
    extracts features and the head performs classification.

    In this implementation the backbone is fixed to TadaConvNeXt
    (defined in tada_convnext.py) and the head to BaseHead below.
"""

def __init__(self, cfg):
"""
Args:
cfg (Config): global config object.
"""
super(BaseVideoModel, self).__init__()
        # the backbone extracts spatio-temporal features from the input clip
        self.backbone = TadaConvNeXt(cfg)

        # the head maps the pooled backbone features to class predictions
        self.head = BaseHead(cfg)

def forward(self, x):
x = self.backbone(x)
x = self.head(x)
return x


class BaseHead(nn.Module):
"""
Constructs base head.
"""

def __init__(
self,
cfg,
):
"""
Args:
cfg (Config): global config object.
"""
super(BaseHead, self).__init__()
self.cfg = cfg
dim = cfg.VIDEO.BACKBONE.NUM_OUT_FEATURES
num_classes = cfg.VIDEO.HEAD.NUM_CLASSES
dropout_rate = cfg.VIDEO.HEAD.DROPOUT_RATE
activation_func = cfg.VIDEO.HEAD.ACTIVATION
self._construct_head(dim, num_classes, dropout_rate, activation_func)

def _construct_head(self, dim, num_classes, dropout_rate, activation_func):
self.global_avg_pool = nn.AdaptiveAvgPool3d(1)

if dropout_rate > 0.0:
self.dropout = nn.Dropout(dropout_rate)

self.out = nn.Linear(dim, num_classes, bias=True)

if activation_func == 'softmax':
self.activation = nn.Softmax(dim=-1)
elif activation_func == 'sigmoid':
self.activation = nn.Sigmoid()
else:
            raise NotImplementedError('{} is not supported as an '
                                      'activation function.'.format(activation_func))

def forward(self, x):
if len(x.shape) == 5:
x = self.global_avg_pool(x)
# (N, C, T, H, W) -> (N, T, H, W, C).
x = x.permute((0, 2, 3, 4, 1))
if hasattr(self, 'dropout'):
out = self.dropout(x)
else:
out = x
out = self.out(out)
out = self.activation(out)
out = out.view(out.shape[0], -1)
return out, x.view(x.shape[0], -1)
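Note: BaseHead only reads four config fields (VIDEO.BACKBONE.NUM_OUT_FEATURES, VIDEO.HEAD.NUM_CLASSES, VIDEO.HEAD.DROPOUT_RATE, VIDEO.HEAD.ACTIVATION). A minimal sketch for exercising it in isolation, assuming EasyDict-style attribute access like the Config object used by the pipeline; the concrete values (768 features, 400 classes, 0.5 dropout) are illustrative assumptions, not taken from the shipped configuration.

import torch
from easydict import EasyDict

from modelscope.models.cv.action_recognition.models import BaseHead

# illustrative config; the real values come from the model's configuration.json
cfg = EasyDict({
    'VIDEO': {
        'BACKBONE': {'NUM_OUT_FEATURES': 768},
        'HEAD': {'NUM_CLASSES': 400, 'DROPOUT_RATE': 0.5, 'ACTIVATION': 'softmax'},
    }
})
head = BaseHead(cfg)
feats = torch.randn(2, 768, 8, 7, 7)  # (N, C, T, H, W) backbone features
probs, pooled = head(feats)           # class probabilities and pooled features
print(probs.shape, pooled.shape)      # torch.Size([2, 400]) torch.Size([2, 768])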

+472 -0  modelscope/models/cv/action_recognition/tada_convnext.py

@@ -0,0 +1,472 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _triple


def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py.
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output


class DropPath(nn.Module):
"""
From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py.
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""

def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

def forward(self, x):
return drop_path(x, self.drop_prob, self.training)


class TadaConvNeXt(nn.Module):
r""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` -
https://arxiv.org/pdf/2201.03545.pdf

Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""

def __init__(
self, cfg
# in_chans=3, num_classes=1000,
# depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
# layer_scale_init_value=1e-6, head_init_scale=1.,
):
super().__init__()
in_chans = cfg.VIDEO.BACKBONE.NUM_INPUT_CHANNELS
dims = cfg.VIDEO.BACKBONE.NUM_FILTERS
drop_path_rate = cfg.VIDEO.BACKBONE.DROP_PATH
depths = cfg.VIDEO.BACKBONE.DEPTH
layer_scale_init_value = cfg.VIDEO.BACKBONE.LARGE_SCALE_INIT_VALUE
stem_t_kernel_size = cfg.VIDEO.BACKBONE.STEM.T_KERNEL_SIZE if hasattr(
cfg.VIDEO.BACKBONE.STEM, 'T_KERNEL_SIZE') else 2
t_stride = cfg.VIDEO.BACKBONE.STEM.T_STRIDE if hasattr(
cfg.VIDEO.BACKBONE.STEM, 'T_STRIDE') else 2

self.downsample_layers = nn.ModuleList(
) # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv3d(
in_chans,
dims[0],
kernel_size=(stem_t_kernel_size, 4, 4),
stride=(t_stride, 4, 4),
padding=((stem_t_kernel_size - 1) // 2, 0, 0)),
LayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
nn.Conv3d(
dims[i],
dims[i + 1],
kernel_size=(1, 2, 2),
stride=(1, 2, 2)),
)
self.downsample_layers.append(downsample_layer)

self.stages = nn.ModuleList(
) # 4 feature resolution stages, each consisting of multiple residual blocks
dp_rates = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
cur = 0
for i in range(4):
stage = nn.Sequential(*[
TAdaConvNeXtBlock(
cfg,
dim=dims[i],
drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_value)
for j in range(depths[i])
])
self.stages.append(stage)
cur += depths[i]

self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer

def forward_features(self, x):
for i in range(4):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
        return self.norm(x.mean(
            [-3, -2, -1]))  # global average pooling, (N, C, T, H, W) -> (N, C)

def forward(self, x):
if isinstance(x, dict):
x = x['video']
x = self.forward_features(x)
return x

def get_num_layers(self):
return 12, 0


class ConvNeXtBlock(nn.Module):
r""" ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch

Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""

def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6):
super().__init__()
self.dwconv = nn.Conv3d(
dim, dim, kernel_size=(1, 7, 7), padding=(0, 3, 3),
groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim,
4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma = nn.Parameter(
layer_scale_init_value * torch.ones((dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W)

x = input + self.drop_path(x)
return x


class LayerNorm(nn.Module):
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""

def __init__(self,
normalized_shape,
eps=1e-6,
data_format='channels_last'):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ['channels_last', 'channels_first']:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )

def forward(self, x):
if self.data_format == 'channels_last':
return F.layer_norm(x, self.normalized_shape, self.weight,
self.bias, self.eps)
elif self.data_format == 'channels_first':
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None, None] * x + self.bias[:, None, None,
None]
return x


class TAdaConvNeXtBlock(nn.Module):
r""" ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_fi rst) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch

Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""

def __init__(self, cfg, dim, drop_path=0., layer_scale_init_value=1e-6):
super().__init__()
layer_scale_init_value = float(layer_scale_init_value)
self.dwconv = TAdaConv2d(
dim,
dim,
kernel_size=(1, 7, 7),
padding=(0, 3, 3),
groups=dim,
cal_dim='cout')
route_func_type = cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_TYPE
if route_func_type == 'normal':
self.dwconv_rf = RouteFuncMLP(
c_in=dim,
ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R,
kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K,
with_bias_cal=self.dwconv.bias is not None)
elif route_func_type == 'normal_lngelu':
self.dwconv_rf = RouteFuncMLPLnGelu(
c_in=dim,
ratio=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_R,
kernels=cfg.VIDEO.BACKBONE.BRANCH.ROUTE_FUNC_K,
with_bias_cal=self.dwconv.bias is not None)
else:
raise ValueError(
'Unknown route_func_type: {}'.format(route_func_type))
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim,
4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
self.gamma = nn.Parameter(
layer_scale_init_value * torch.ones((dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x):
input = x
x = self.dwconv(x, self.dwconv_rf(x))
x = x.permute(0, 2, 3, 4, 1) # (N, C, T, H, W) -> (N, T, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 4, 1, 2, 3) # (N, T, H, W, C) -> (N, C, T, H, W)

x = input + self.drop_path(x)
return x


class RouteFuncMLPLnGelu(nn.Module):
"""
The routing function for generating the calibration weights.
"""

def __init__(self,
c_in,
ratio,
kernels,
with_bias_cal=False,
bn_eps=1e-5,
bn_mmt=0.1):
"""
Args:
c_in (int): number of input channels.
ratio (int): reduction ratio for the routing function.
kernels (list): temporal kernel size of the stacked 1D convolutions
"""
super(RouteFuncMLPLnGelu, self).__init__()
self.c_in = c_in
self.with_bias_cal = with_bias_cal
self.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1))
self.globalpool = nn.AdaptiveAvgPool3d(1)
self.g = nn.Conv3d(
in_channels=c_in,
out_channels=c_in,
kernel_size=1,
padding=0,
)
self.a = nn.Conv3d(
in_channels=c_in,
out_channels=int(c_in // ratio),
kernel_size=[kernels[0], 1, 1],
padding=[kernels[0] // 2, 0, 0],
)
# self.bn = nn.BatchNorm3d(int(c_in//ratio), eps=bn_eps, momentum=bn_mmt)
self.ln = LayerNorm(
int(c_in // ratio), eps=1e-6, data_format='channels_first')
self.gelu = nn.GELU()
# self.relu = nn.ReLU(inplace=True)
self.b = nn.Conv3d(
in_channels=int(c_in // ratio),
out_channels=c_in,
kernel_size=[kernels[1], 1, 1],
padding=[kernels[1] // 2, 0, 0],
bias=False)
self.b.skip_init = True
        self.b.weight.data.zero_()  # zero-init so that the initial calibration output is 1
if with_bias_cal:
self.b_bias = nn.Conv3d(
in_channels=int(c_in // ratio),
out_channels=c_in,
kernel_size=[kernels[1], 1, 1],
padding=[kernels[1] // 2, 0, 0],
bias=False)
self.b_bias.skip_init = True
            self.b_bias.weight.data.zero_()  # zero-init so that the initial calibration output is 1

def forward(self, x):
g = self.globalpool(x)
x = self.avgpool(x)
x = self.a(x + self.g(g))
# x = self.bn(x)
# x = self.relu(x)
x = self.ln(x)
x = self.gelu(x)
if self.with_bias_cal:
return [self.b(x) + 1, self.b_bias(x) + 1]
else:
return self.b(x) + 1


class TAdaConv2d(nn.Module):
"""
Performs temporally adaptive 2D convolution.
Currently, only application on 5D tensors is supported, which makes TAdaConv2d
essentially a 3D convolution with temporal kernel size of 1.
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
cal_dim='cin'):
super(TAdaConv2d, self).__init__()
"""
Args:
in_channels (int): number of input channels.
out_channels (int): number of output channels.
kernel_size (list): kernel size of TAdaConv2d.
stride (list): stride for the convolution in TAdaConv2d.
padding (list): padding for the convolution in TAdaConv2d.
dilation (list): dilation of the convolution in TAdaConv2d.
groups (int): number of groups for TAdaConv2d.
bias (bool): whether to use bias in TAdaConv2d.
            cal_dim (str): calibrated dimension in TAdaConv2d.
                Supported values: "cin", "cout".
"""

kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
dilation = _triple(dilation)

assert kernel_size[0] == 1
assert stride[0] == 1
assert padding[0] == 0
assert dilation[0] == 1
assert cal_dim in ['cin', 'cout']

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.cal_dim = cal_dim

# base weights (W_b)
self.weight = nn.Parameter(
torch.Tensor(1, 1, out_channels, in_channels // groups,
kernel_size[1], kernel_size[2]))
if bias:
self.bias = nn.Parameter(torch.Tensor(1, 1, out_channels))
else:
self.register_parameter('bias', None)

nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)

def forward(self, x, alpha):
"""
Args:
x (tensor): feature to perform convolution on.
alpha (tensor): calibration weight for the base weights.
W_t = alpha_t * W_b
"""
if isinstance(alpha, list):
w_alpha, b_alpha = alpha[0], alpha[1]
else:
w_alpha = alpha
b_alpha = None
_, _, c_out, c_in, kh, kw = self.weight.size()
b, c_in, t, h, w = x.size()
x = x.permute(0, 2, 1, 3, 4).reshape(1, -1, h, w)

if self.cal_dim == 'cin':
# w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, 1, C, H(1), W(1)
# corresponding to calibrating the input channel
weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(2)
* self.weight).reshape(-1, c_in // self.groups, kh, kw)
elif self.cal_dim == 'cout':
            # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C, 1, H(1), W(1)
            # corresponding to calibrating the output channel
weight = (w_alpha.permute(0, 2, 1, 3, 4).unsqueeze(3)
* self.weight).reshape(-1, c_in // self.groups, kh, kw)

bias = None
if self.bias is not None:
if b_alpha is not None:
# b_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C
bias = (b_alpha.permute(0, 2, 1, 3, 4).squeeze()
* self.bias).reshape(-1)
else:
bias = self.bias.repeat(b, t, 1).reshape(-1)
output = F.conv2d(
x,
weight=weight,
bias=bias,
stride=self.stride[1:],
padding=self.padding[1:],
dilation=self.dilation[1:],
groups=self.groups * b * t)

output = output.view(b, t, c_out, output.size(-2),
output.size(-1)).permute(0, 2, 1, 3, 4)

return output

def __repr__(self):
return f'TAdaConv2d({self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, ' +\
f"stride={self.stride}, padding={self.padding}, bias={self.bias is not None}, cal_dim=\"{self.cal_dim}\")"

+2 -0  modelscope/pipelines/builder.py

@@ -37,6 +37,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                          'damo/cv_unet_person-image-cartoon_compound-models'),
    Tasks.ocr_detection: (Pipelines.ocr_detection,
                          'damo/cv_resnet18_ocr-detection-line-level_damo'),
    Tasks.action_recognition: (Pipelines.action_recognition,
                               'damo/cv_TAdaConv_action-recognition'),
}






+1 -0  modelscope/pipelines/cv/__init__.py

@@ -1,3 +1,4 @@
from .action_recognition_pipeline import ActionRecognitionPipeline
from .image_cartoon_pipeline import ImageCartoonPipeline
from .image_matting_pipeline import ImageMattingPipeline
from .ocr_detection_pipeline import OCRDetectionPipeline

+65 -0  modelscope/pipelines/cv/action_recognition_pipeline.py

@@ -0,0 +1,65 @@
import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.action_recognition.models import BaseVideoModel
from modelscope.pipelines.base import Input
from modelscope.preprocessors.video import ReadVideoData
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
Tasks.action_recognition, module_name=Pipelines.action_recognition)
class ActionRecognitionPipeline(Pipeline):

def __init__(self, model: str):
super().__init__(model=model)
model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
logger.info(f'loading model from {model_path}')
config_path = osp.join(self.model, ModelFile.CONFIGURATION)
logger.info(f'loading config from {config_path}')
self.cfg = Config.from_file(config_path)
self.infer_model = BaseVideoModel(cfg=self.cfg).cuda()
self.infer_model.eval()
self.infer_model.load_state_dict(torch.load(model_path)['model_state'])
self.label_mapping = self.cfg.label_mapping
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
video_input_data = ReadVideoData(self.cfg, input).cuda()
else:
raise TypeError(f'input should be a str,'
f' but got {type(input)}')
result = {'video_data': video_input_data}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
pred = self.perform_inference(input['video_data'])
output_label = self.label_mapping[str(pred)]
return {'output_label': output_label}

@torch.no_grad()
def perform_inference(self, data, max_bsz=4):
iter_num = math.ceil(data.size(0) / max_bsz)
preds_list = []
for i in range(iter_num):
preds_list.append(
self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0])
pred = torch.cat(preds_list, dim=0)
return pred.mean(dim=0).argmax().item()

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
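Note: end-to-end usage mirrors the unit test added below; the pipeline resolves the default model registered in builder.py and needs a CUDA device, since both the model and the decoded clip are moved to GPU.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

recognizer = pipeline(Tasks.action_recognition,
                      model='damo/cv_TAdaConv_action-recognition')
result = recognizer('data/test/videos/action_recognition_test_video.mp4')
print(result)  # e.g. {'output_label': 'abseiling'}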

+6 -0  modelscope/pipelines/outputs.py

@@ -45,6 +45,12 @@ TASK_OUTPUTS = {
    Tasks.image_matting: ['output_png'],
    Tasks.image_generation: ['output_png'],


# action recognition result for single video
# {
# "output_label": "abseiling"
# }
Tasks.action_recognition: ['output_label'],

    # pose estimation result for single sample
    # {
    # "poses": np.array with shape [num_pose, num_keypoint, 3],


+232 -0  modelscope/preprocessors/video.py

@@ -0,0 +1,232 @@
import math
import os
import random

import decord
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
import torch.utils.dlpack as dlpack
import torchvision.transforms._transforms_video as transforms
from decord import VideoReader
from torchvision.transforms import Compose


def ReadVideoData(cfg, video_path):
""" simple interface to load video frames from file

Args:
cfg (Config): The global config object.
video_path (str): video file path
"""
data = _decode_video(cfg, video_path)
transform = kinetics400_tranform(cfg)
data_list = []
for i in range(data.size(0)):
for j in range(cfg.TEST.NUM_SPATIAL_CROPS):
transform.transforms[1].set_spatial_index(j)
data_list.append(transform(data[i]))
return torch.stack(data_list, dim=0)


def kinetics400_tranform(cfg):
"""
    Configures the test-time transform for the Kinetics-400 dataset:
    controlled spatial cropping followed by normalization.
Args:
cfg (Config): The global config object.
"""
resize_video = KineticsResizedCrop(
short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE],
crop_size=cfg.DATA.TEST_CROP_SIZE,
num_spatial_crops=cfg.TEST.NUM_SPATIAL_CROPS)
std_transform_list = [
transforms.ToTensorVideo(), resize_video,
transforms.NormalizeVideo(
mean=cfg.DATA.MEAN, std=cfg.DATA.STD, inplace=True)
]
return Compose(std_transform_list)


def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx,
num_clips, num_frames, interval, minus_interval):
"""
Generates the frame index list using interval based sampling.
Args:
vid_length (int): the length of the whole video (valid selection range).
vid_fps (int): the original video fps
target_fps (int): the normalized video fps
clip_idx (int): -1 for random temporal sampling, and positive values for
sampling specific clip from the video
        num_clips (int): the total number of clips to be sampled from each video.
            Combined with clip_idx, the sampled clip is the "clip_idx-th"
            clip of the "num_clips" clips.
        num_frames (int): number of frames in each sampled clip.
        interval (int): the interval between sampled frames.
        minus_interval (bool): whether to subtract one interval when computing the end index.
Returns:
index (tensor): the sampled frame indexes
"""
if num_frames == 1:
index = [random.randint(0, vid_length - 1)]
else:
# transform FPS
clip_length = num_frames * interval * vid_fps / target_fps

max_idx = max(vid_length - clip_length, 0)
start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
if minus_interval:
end_idx = start_idx + clip_length - interval
else:
end_idx = start_idx + clip_length - 1

index = torch.linspace(start_idx, end_idx, num_frames)
index = torch.clamp(index, 0, vid_length - 1).long()

return index


def _decode_video_frames_list(cfg, frames_list, vid_fps):
"""
Decodes the video given the numpy frames.
Args:
cfg (Config): The global config object.
frames_list (list): all frames for a video, the frames should be numpy array.
vid_fps (int): the fps of this video.
Returns:
frames (Tensor): video tensor data
"""
assert isinstance(frames_list, list)
num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS

frame_list = []
for clip_idx in range(num_clips_per_video):
# for each clip in the video,
# a list is generated before decoding the specified frames from the video
list_ = _interval_based_sampling(
len(frames_list), vid_fps, cfg.DATA.TARGET_FPS, clip_idx,
num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
frames = None
frames = torch.from_numpy(
np.stack([frames_list[l_index] for l_index in list_.tolist()],
axis=0))
frame_list.append(frames)
frames = torch.stack(frame_list)
if num_clips_per_video == 1:
frames = frames.squeeze(0)

return frames


def _decode_video(cfg, path):
"""
    Decodes the frames of a video from the given file path.
    Args:
        cfg (Config): The global config object.
        path (str): video file path.
Returns:
frames (Tensor): video tensor data
"""
vr = VideoReader(path)

num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS

frame_list = []
for clip_idx in range(num_clips_per_video):
# for each clip in the video,
# a list is generated before decoding the specified frames from the video
list_ = _interval_based_sampling(
len(vr), vr.get_avg_fps(), cfg.DATA.TARGET_FPS, clip_idx,
num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
frames = None
if path.endswith('.avi'):
append_list = torch.arange(0, list_[0], 4)
frames = dlpack.from_dlpack(
vr.get_batch(torch.cat([append_list,
list_])).to_dlpack()).clone()
frames = frames[append_list.shape[0]:]
else:
frames = dlpack.from_dlpack(
vr.get_batch(list_).to_dlpack()).clone()
frame_list.append(frames)
frames = torch.stack(frame_list)
if num_clips_per_video == 1:
frames = frames.squeeze(0)
del vr
return frames


class KineticsResizedCrop(object):
"""Perform resize and crop for kinetics-400 dataset
Args:
        short_side_range (list): The range of the short-side length. In inference, this should be [256, 256].
crop_size (int): The cropped size for frames.
num_spatial_crops (int): The number of the cropped spatial regions in each video.
"""

def __init__(
self,
short_side_range,
crop_size,
num_spatial_crops=1,
):
self.idx = -1
self.short_side_range = short_side_range
self.crop_size = int(crop_size)
self.num_spatial_crops = num_spatial_crops

def _get_controlled_crop(self, clip):
"""Perform controlled crop for video tensor.
Args:
            clip (Tensor): the video data, the shape is [C, T, H, W]
"""
_, _, clip_height, clip_width = clip.shape

length = self.short_side_range[0]

if clip_height < clip_width:
new_clip_height = int(length)
new_clip_width = int(clip_width / clip_height * new_clip_height)
new_clip = torch.nn.functional.interpolate(
clip, size=(new_clip_height, new_clip_width), mode='bilinear')
else:
new_clip_width = int(length)
new_clip_height = int(clip_height / clip_width * new_clip_width)
new_clip = torch.nn.functional.interpolate(
clip, size=(new_clip_height, new_clip_width), mode='bilinear')
x_max = int(new_clip_width - self.crop_size)
y_max = int(new_clip_height - self.crop_size)
if self.num_spatial_crops == 1:
x = x_max // 2
y = y_max // 2
elif self.num_spatial_crops == 3:
if self.idx == 0:
if new_clip_width == length:
x = x_max // 2
y = 0
elif new_clip_height == length:
x = 0
y = y_max // 2
elif self.idx == 1:
x = x_max // 2
y = y_max // 2
elif self.idx == 2:
if new_clip_width == length:
x = x_max // 2
y = y_max
elif new_clip_height == length:
x = x_max
y = y_max // 2
return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]

def set_spatial_index(self, idx):
"""Set the spatial cropping index for controlled cropping..
Args:
idx (int): the spatial index. The value should be in [0, 1, 2], means [left, center, right], respectively.
"""
self.idx = idx

def __call__(self, clip):
return self._get_controlled_crop(clip)
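Note: ReadVideoData expects the DATA/TEST config sections shown below. A minimal sketch using EasyDict; every value here is an illustrative assumption (the real values ship in the model's configuration.json), and the call relies on decord plus the test video added in this commit.

from easydict import EasyDict

from modelscope.preprocessors.video import ReadVideoData

cfg = EasyDict({
    'DATA': {
        'TEST_SCALE': 256, 'TEST_CROP_SIZE': 224,  # assumed sizes
        'MEAN': [0.45, 0.45, 0.45], 'STD': [0.225, 0.225, 0.225],
        'TARGET_FPS': 30, 'NUM_INPUT_FRAMES': 16,
        'SAMPLING_RATE': 4, 'MINUS_INTERVAL': False,
    },
    'TEST': {'NUM_ENSEMBLE_VIEWS': 2, 'NUM_SPATIAL_CROPS': 3},
})

clips = ReadVideoData(cfg, 'data/test/videos/action_recognition_test_video.mp4')
# one clip per (temporal view, spatial crop):
# (NUM_ENSEMBLE_VIEWS * NUM_SPATIAL_CROPS, C, NUM_INPUT_FRAMES, crop, crop)
print(clips.shape)  # torch.Size([6, 3, 16, 224, 224]) with these assumed values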

+1 -0  modelscope/utils/constant.py

@@ -29,6 +29,7 @@ class Tasks(object):
    image_generation = 'image-generation'
    image_matting = 'image-matting'
    ocr_detection = 'ocr-detection'
action_recognition = 'action-recognition'


    # nlp tasks
    word_segmentation = 'word-segmentation'


+1 -0  requirements/cv.txt

@@ -1,2 +1,3 @@
decord>=0.6.0
easydict
tf_slim

+58 -0  tests/pipelines/test_action_recognition.py

@@ -0,0 +1,58 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# !/usr/bin/env python
import os.path as osp
import shutil
import tempfile
import unittest

import cv2

from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.test_utils import test_level


class ActionRecognitionTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_TAdaConv_action-recognition'

@unittest.skip('deprecated, download model from model hub instead')
def test_run_with_direct_file_download(self):
model_path = 'https://aquila2-online-models.oss-cn-shanghai.aliyuncs.com/maas_test/pytorch_model.pt'
config_path = 'https://aquila2-online-models.oss-cn-shanghai.aliyuncs.com/maas_test/configuration.json'
with tempfile.TemporaryDirectory() as tmp_dir:
model_file = osp.join(tmp_dir, ModelFile.TORCH_MODEL_FILE)
with open(model_file, 'wb') as ofile1:
ofile1.write(File.read(model_path))
config_file = osp.join(tmp_dir, ModelFile.CONFIGURATION)
with open(config_file, 'wb') as ofile2:
ofile2.write(File.read(config_path))
recognition_pipeline = pipeline(
Tasks.action_recognition, model=tmp_dir)
result = recognition_pipeline(
'data/test/videos/action_recognition_test_video.mp4')
print(f'recognition output: {result}.')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
recognition_pipeline = pipeline(
Tasks.action_recognition, model=self.model_id)
result = recognition_pipeline(
'data/test/videos/action_recognition_test_video.mp4')

print(f'recognition output: {result}.')

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub_default_model(self):
recognition_pipeline = pipeline(Tasks.action_recognition)
result = recognition_pipeline(
'data/test/videos/action_recognition_test_video.mp4')

print(f'recognition output: {result}.')


if __name__ == '__main__':
unittest.main()
