Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9969472master
@@ -99,6 +99,7 @@ class Pipelines(object): | |||
animal_recognition = 'resnet101-animal-recognition' | |||
general_recognition = 'resnet101-general-recognition' | |||
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | |||
hicossl_video_embedding = 'hicossl-s3dg-video_embedding' | |||
body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' | |||
body_3d_keypoints = 'canonical_body-3d-keypoints_video' | |||
human_detection = 'resnet18-human-detection' | |||
@@ -1,5 +1,6 @@ | |||
import torch.nn as nn | |||
from .s3dg import Inception3D | |||
from .tada_convnext import TadaConvNeXt | |||
@@ -26,11 +27,25 @@ class BaseVideoModel(nn.Module): | |||
super(BaseVideoModel, self).__init__() | |||
# the backbone is created according to meta-architectures | |||
# defined in models/base/backbone.py | |||
self.backbone = TadaConvNeXt(cfg) | |||
if cfg.MODEL.NAME == 'ConvNeXt_tiny': | |||
self.backbone = TadaConvNeXt(cfg) | |||
elif cfg.MODEL.NAME == 'S3DG': | |||
self.backbone = Inception3D(cfg) | |||
else: | |||
error_str = 'backbone {} is not supported; supported backbones are ConvNeXt_tiny and S3DG'.format( | |||
cfg.MODEL.NAME) | |||
raise NotImplementedError(error_str) | |||
# the head is created according to the heads | |||
# defined in models/module_zoo/heads | |||
self.head = BaseHead(cfg) | |||
if cfg.VIDEO.HEAD.NAME == 'BaseHead': | |||
self.head = BaseHead(cfg) | |||
elif cfg.VIDEO.HEAD.NAME == 'AvgHead': | |||
self.head = AvgHead(cfg) | |||
else: | |||
error_str = 'head {} is not supported; supported heads are BaseHead and AvgHead'.format( | |||
cfg.VIDEO.HEAD.NAME) | |||
raise NotImplementedError(error_str) | |||
def forward(self, x): | |||
x = self.backbone(x) | |||
@@ -88,3 +103,29 @@ class BaseHead(nn.Module): | |||
out = self.activation(out) | |||
out = out.view(out.shape[0], -1) | |||
return out, x.view(x.shape[0], -1) | |||
class AvgHead(nn.Module): | |||
""" | |||
Constructs the average pooling head. | |||
""" | |||
def __init__( | |||
self, | |||
cfg, | |||
): | |||
""" | |||
Args: | |||
cfg (Config): global config object. | |||
""" | |||
super(AvgHead, self).__init__() | |||
self.cfg = cfg | |||
self.global_avg_pool = nn.AdaptiveAvgPool3d(1) | |||
def forward(self, x): | |||
if len(x.shape) == 5: | |||
x = self.global_avg_pool(x) | |||
# (N, C, T, H, W) -> (N, T, H, W, C). | |||
x = x.permute((0, 2, 3, 4, 1)) | |||
out = x.view(x.shape[0], -1) | |||
return out, x.view(x.shape[0], -1) |
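For readability, a minimal standalone sketch of the tensor flow in AvgHead.forward. The (2, 1024, 8, 7, 7) input shape is only an assumption, chosen to match the S3DG block5 comments later in this review, not a value from the real configuration:

    import torch
    import torch.nn as nn

    features = torch.randn(2, 1024, 8, 7, 7)      # (N, C, T, H, W) backbone feature map
    pool = nn.AdaptiveAvgPool3d(1)                 # global average over T, H and W
    pooled = pool(features)                        # (N, C, 1, 1, 1)
    pooled = pooled.permute((0, 2, 3, 4, 1))       # (N, 1, 1, 1, C), as in AvgHead.forward
    embedding = pooled.view(pooled.shape[0], -1)   # (N, C) flattened video embedding
    print(embedding.shape)                         # torch.Size([2, 1024])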
@@ -0,0 +1,301 @@ | |||
import torch | |||
import torch.nn as nn | |||
class InceptionBaseConv3D(nn.Module): | |||
""" | |||
Constructs basic inception 3D conv. | |||
Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. | |||
""" | |||
def __init__(self, | |||
cfg, | |||
in_planes, | |||
out_planes, | |||
kernel_size, | |||
stride, | |||
padding=0): | |||
super(InceptionBaseConv3D, self).__init__() | |||
self.conv = nn.Conv3d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=kernel_size, | |||
stride=stride, | |||
padding=padding, | |||
bias=False) | |||
self.bn = nn.BatchNorm3d(out_planes) | |||
self.relu = nn.ReLU(inplace=True) | |||
# init | |||
self.conv.weight.data.normal_( | |||
mean=0, std=0.01) # original s3d is truncated normal within 2 std | |||
self.bn.weight.data.fill_(1) | |||
self.bn.bias.data.zero_() | |||
def forward(self, x): | |||
x = self.conv(x) | |||
x = self.bn(x) | |||
x = self.relu(x) | |||
return x | |||
class InceptionBlock3D(nn.Module): | |||
""" | |||
Building block of the S3D/S3DG architecture. | |||
See models/base/backbone.py L99-186. | |||
Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. | |||
""" | |||
def __init__(self, cfg, in_planes, out_planes): | |||
super(InceptionBlock3D, self).__init__() | |||
_gating = cfg.VIDEO.BACKBONE.BRANCH.GATING | |||
assert len(out_planes) == 6 | |||
assert isinstance(out_planes, list) | |||
[ | |||
num_out_0_0a, num_out_1_0a, num_out_1_0b, num_out_2_0a, | |||
num_out_2_0b, num_out_3_0b | |||
] = out_planes | |||
self.branch0 = nn.Sequential( | |||
InceptionBaseConv3D( | |||
cfg, in_planes, num_out_0_0a, kernel_size=1, stride=1), ) | |||
self.branch1 = nn.Sequential( | |||
InceptionBaseConv3D( | |||
cfg, in_planes, num_out_1_0a, kernel_size=1, stride=1), | |||
STConv3d( | |||
cfg, | |||
num_out_1_0a, | |||
num_out_1_0b, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1), | |||
) | |||
self.branch2 = nn.Sequential( | |||
InceptionBaseConv3D( | |||
cfg, in_planes, num_out_2_0a, kernel_size=1, stride=1), | |||
STConv3d( | |||
cfg, | |||
num_out_2_0a, | |||
num_out_2_0b, | |||
kernel_size=3, | |||
stride=1, | |||
padding=1), | |||
) | |||
self.branch3 = nn.Sequential( | |||
nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1), | |||
InceptionBaseConv3D( | |||
cfg, in_planes, num_out_3_0b, kernel_size=1, stride=1), | |||
) | |||
self.out_channels = sum( | |||
[num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b]) | |||
self.gating = _gating | |||
if _gating: | |||
self.gating_b0 = SelfGating(num_out_0_0a) | |||
self.gating_b1 = SelfGating(num_out_1_0b) | |||
self.gating_b2 = SelfGating(num_out_2_0b) | |||
self.gating_b3 = SelfGating(num_out_3_0b) | |||
def forward(self, x): | |||
x0 = self.branch0(x) | |||
x1 = self.branch1(x) | |||
x2 = self.branch2(x) | |||
x3 = self.branch3(x) | |||
if self.gating: | |||
x0 = self.gating_b0(x0) | |||
x1 = self.gating_b1(x1) | |||
x2 = self.gating_b2(x2) | |||
x3 = self.gating_b3(x3) | |||
out = torch.cat((x0, x1, x2, x3), 1) | |||
return out | |||
class SelfGating(nn.Module): | |||
def __init__(self, input_dim): | |||
super(SelfGating, self).__init__() | |||
self.fc = nn.Linear(input_dim, input_dim) | |||
def forward(self, input_tensor): | |||
"""Feature gating as used in S3D-G""" | |||
spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4]) | |||
weights = self.fc(spatiotemporal_average) | |||
weights = torch.sigmoid(weights) | |||
return weights[:, :, None, None, None] * input_tensor | |||
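The gating above is a squeeze-and-excitation style reweighting: per-channel means over (T, H, W) pass through a linear layer and a sigmoid, and the resulting gates rescale the feature map. A quick standalone check of the broadcasting, with arbitrary sizes rather than values from this review:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 64, 8, 14, 14)                  # (N, C, T, H, W)
    fc = nn.Linear(64, 64)
    gates = torch.sigmoid(fc(x.mean(dim=[2, 3, 4])))   # (N, C) gate values in (0, 1)
    gated = gates[:, :, None, None, None] * x          # broadcast back over T, H, W
    print(gated.shape)                                 # torch.Size([2, 64, 8, 14, 14])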
class STConv3d(nn.Module): | |||
""" | |||
Building block of the S3D/S3DG architecture. | |||
See models/base/backbone.py L99-186. | |||
Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. | |||
""" | |||
def __init__(self, | |||
cfg, | |||
in_planes, | |||
out_planes, | |||
kernel_size, | |||
stride, | |||
padding=0): | |||
super(STConv3d, self).__init__() | |||
if isinstance(stride, tuple): | |||
t_stride = stride[0] | |||
stride = stride[-1] | |||
else: # int | |||
t_stride = stride | |||
self.bn_mmt = cfg.BN.MOMENTUM | |||
self.bn_eps = float(cfg.BN.EPS) | |||
self._construct_branch(cfg, in_planes, out_planes, kernel_size, stride, | |||
t_stride, padding) | |||
def _construct_branch(self, | |||
cfg, | |||
in_planes, | |||
out_planes, | |||
kernel_size, | |||
stride, | |||
t_stride, | |||
padding=0): | |||
self.conv1 = nn.Conv3d( | |||
in_planes, | |||
out_planes, | |||
kernel_size=(1, kernel_size, kernel_size), | |||
stride=(1, stride, stride), | |||
padding=(0, padding, padding), | |||
bias=False) | |||
self.conv2 = nn.Conv3d( | |||
out_planes, | |||
out_planes, | |||
kernel_size=(kernel_size, 1, 1), | |||
stride=(t_stride, 1, 1), | |||
padding=(padding, 0, 0), | |||
bias=False) | |||
self.bn1 = nn.BatchNorm3d( | |||
out_planes, eps=self.bn_eps, momentum=self.bn_mmt) | |||
self.bn2 = nn.BatchNorm3d( | |||
out_planes, eps=self.bn_eps, momentum=self.bn_mmt) | |||
self.relu = nn.ReLU(inplace=True) | |||
# init | |||
self.conv1.weight.data.normal_( | |||
mean=0, std=0.01) # original s3d is truncated normal within 2 std | |||
self.conv2.weight.data.normal_( | |||
mean=0, std=0.01) # original s3d is truncated normal within 2 std | |||
self.bn1.weight.data.fill_(1) | |||
self.bn1.bias.data.zero_() | |||
self.bn2.weight.data.fill_(1) | |||
self.bn2.bias.data.zero_() | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
x = self.conv2(x) | |||
x = self.bn2(x) | |||
x = self.relu(x) | |||
return x | |||
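STConv3d factorizes a full k x k x k convolution into a (1, k, k) spatial convolution followed by a (k, 1, 1) temporal one, which is what makes S3D cheaper than a full 3D inception. An illustrative parameter count for a 3 x 3 x 3 kernel with 192 channels (the channel count is arbitrary, not taken from this review):

    import torch.nn as nn

    full = nn.Conv3d(192, 192, kernel_size=3, padding=1, bias=False)
    spatial = nn.Conv3d(192, 192, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False)
    temporal = nn.Conv3d(192, 192, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False)

    def n_params(m):
        return sum(p.numel() for p in m.parameters())

    print(n_params(full))                          # 995328 = 192 * 192 * 27
    print(n_params(spatial) + n_params(temporal))  # 442368 = 192 * 192 * (9 + 3)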
class Inception3D(nn.Module): | |||
""" | |||
Backbone architecture for I3D/S3DG. | |||
Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py. | |||
""" | |||
def __init__(self, cfg): | |||
""" | |||
Args: | |||
cfg (Config): global config object. | |||
""" | |||
super(Inception3D, self).__init__() | |||
_input_channel = cfg.DATA.NUM_INPUT_CHANNELS | |||
self._construct_backbone(cfg, _input_channel) | |||
def _construct_backbone(self, cfg, input_channel): | |||
# ------------------- Block 1 ------------------- | |||
self.Conv_1a = STConv3d( | |||
cfg, input_channel, 64, kernel_size=7, stride=2, padding=3) | |||
self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112) | |||
# ------------------- Block 2 ------------------- | |||
self.MaxPool_2a = nn.MaxPool3d( | |||
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) | |||
self.Conv_2b = InceptionBaseConv3D( | |||
cfg, 64, 64, kernel_size=1, stride=1) | |||
self.Conv_2c = STConv3d( | |||
cfg, 64, 192, kernel_size=3, stride=1, padding=1) | |||
self.block2 = nn.Sequential( | |||
self.MaxPool_2a, # (64, 32, 56, 56) | |||
self.Conv_2b, # (64, 32, 56, 56) | |||
self.Conv_2c) # (192, 32, 56, 56) | |||
# ------------------- Block 3 ------------------- | |||
self.MaxPool_3a = nn.MaxPool3d( | |||
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) | |||
self.Mixed_3b = InceptionBlock3D( | |||
cfg, in_planes=192, out_planes=[64, 96, 128, 16, 32, 32]) | |||
self.Mixed_3c = InceptionBlock3D( | |||
cfg, in_planes=256, out_planes=[128, 128, 192, 32, 96, 64]) | |||
self.block3 = nn.Sequential( | |||
self.MaxPool_3a, # (192, 32, 28, 28) | |||
self.Mixed_3b, # (256, 32, 28, 28) | |||
self.Mixed_3c) # (480, 32, 28, 28) | |||
# ------------------- Block 4 ------------------- | |||
self.MaxPool_4a = nn.MaxPool3d( | |||
kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)) | |||
self.Mixed_4b = InceptionBlock3D( | |||
cfg, in_planes=480, out_planes=[192, 96, 208, 16, 48, 64]) | |||
self.Mixed_4c = InceptionBlock3D( | |||
cfg, in_planes=512, out_planes=[160, 112, 224, 24, 64, 64]) | |||
self.Mixed_4d = InceptionBlock3D( | |||
cfg, in_planes=512, out_planes=[128, 128, 256, 24, 64, 64]) | |||
self.Mixed_4e = InceptionBlock3D( | |||
cfg, in_planes=512, out_planes=[112, 144, 288, 32, 64, 64]) | |||
self.Mixed_4f = InceptionBlock3D( | |||
cfg, in_planes=528, out_planes=[256, 160, 320, 32, 128, 128]) | |||
self.block4 = nn.Sequential( | |||
self.MaxPool_4a, # (480, 16, 14, 14) | |||
self.Mixed_4b, # (512, 16, 14, 14) | |||
self.Mixed_4c, # (512, 16, 14, 14) | |||
self.Mixed_4d, # (512, 16, 14, 14) | |||
self.Mixed_4e, # (528, 16, 14, 14) | |||
self.Mixed_4f) # (832, 16, 14, 14) | |||
# ------------------- Block 5 ------------------- | |||
self.MaxPool_5a = nn.MaxPool3d( | |||
kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)) | |||
self.Mixed_5b = InceptionBlock3D( | |||
cfg, in_planes=832, out_planes=[256, 160, 320, 32, 128, 128]) | |||
self.Mixed_5c = InceptionBlock3D( | |||
cfg, in_planes=832, out_planes=[384, 192, 384, 48, 128, 128]) | |||
self.block5 = nn.Sequential( | |||
self.MaxPool_5a, # (832, 8, 7, 7) | |||
self.Mixed_5b, # (832, 8, 7, 7) | |||
self.Mixed_5c) # (1024, 8, 7, 7) | |||
def forward(self, x): | |||
if isinstance(x, dict): | |||
x = x['video'] | |||
x = self.block1(x) | |||
x = self.block2(x) | |||
x = self.block3(x) | |||
x = self.block4(x) | |||
x = self.block5(x) | |||
return x |
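A hedged sketch of running the backbone standalone. The config stub below only fills the fields this file actually reads (DATA.NUM_INPUT_CHANNELS, VIDEO.BACKBONE.BRANCH.GATING, BN.MOMENTUM, BN.EPS); it is an assumption for illustration, not the real modelscope Config, and the 16 x 224 x 224 clip size is likewise arbitrary. Inception3D refers to the class defined above.

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        DATA=SimpleNamespace(NUM_INPUT_CHANNELS=3),
        VIDEO=SimpleNamespace(BACKBONE=SimpleNamespace(BRANCH=SimpleNamespace(GATING=True))),
        BN=SimpleNamespace(MOMENTUM=0.1, EPS=1e-5),
    )
    backbone = Inception3D(cfg).eval()
    clip = torch.randn(1, 3, 16, 224, 224)   # (N, C, T, H, W)
    with torch.no_grad():
        feat = backbone(clip)
    print(feat.shape)                        # torch.Size([1, 1024, 2, 7, 7]): T/8, H/32, W/32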
@@ -9,6 +9,7 @@ if TYPE_CHECKING: | |||
from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline | |||
from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline | |||
from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline | |||
from .hicossl_video_embedding_pipeline import HICOSSLVideoEmbeddingPipeline | |||
from .crowd_counting_pipeline import CrowdCountingPipeline | |||
from .image_detection_pipeline import ImageDetectionPipeline | |||
from .image_salient_detection_pipeline import ImageSalientDetectionPipeline | |||
@@ -51,6 +52,7 @@ else: | |||
'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'], | |||
'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'], | |||
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], | |||
'hicossl_video_embedding_pipeline': ['HICOSSLVideoEmbeddingPipeline'], | |||
'crowd_counting_pipeline': ['CrowdCountingPipeline'], | |||
'image_detection_pipeline': ['ImageDetectionPipeline'], | |||
'image_salient_detection_pipeline': ['ImageSalientDetectionPipeline'], | |||
@@ -33,6 +33,7 @@ class ActionRecognitionPipeline(Pipeline): | |||
config_path = osp.join(self.model, ModelFile.CONFIGURATION) | |||
logger.info(f'loading config from {config_path}') | |||
self.cfg = Config.from_file(config_path) | |||
self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) | |||
self.infer_model.eval() | |||
self.infer_model.load_state_dict( | |||
@@ -0,0 +1,75 @@ | |||
import math | |||
import os.path as osp | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.cv.action_recognition import BaseVideoModel | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Input, Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import ReadVideoData | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
@PIPELINES.register_module( | |||
Tasks.video_embedding, module_name=Pipelines.hicossl_video_embedding) | |||
class HICOSSLVideoEmbeddingPipeline(Pipeline): | |||
def __init__(self, model: str, **kwargs): | |||
""" | |||
Use `model` to create a HICOSSL video embedding pipeline for prediction. | |||
Args: | |||
model: model id on modelscope hub. | |||
""" | |||
super().__init__(model=model, **kwargs) | |||
model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | |||
logger.info(f'loading model from {model_path}') | |||
config_path = osp.join(self.model, ModelFile.CONFIGURATION) | |||
logger.info(f'loading config from {config_path}') | |||
self.cfg = Config.from_file(config_path) | |||
self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) | |||
self.infer_model.eval() | |||
self.infer_model.load_state_dict( | |||
torch.load(model_path, map_location=self.device)['model_state'], | |||
strict=False) | |||
logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]: | |||
if isinstance(input, str): | |||
video_input_data = ReadVideoData( | |||
self.cfg, input, num_temporal_views_override=1).to(self.device) | |||
else: | |||
raise TypeError(f'input should be a str,' | |||
f' but got {type(input)}') | |||
result = {'video_data': video_input_data} | |||
return result | |||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
feature = self.perform_inference(input['video_data']) | |||
return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()} | |||
@torch.no_grad() | |||
def perform_inference(self, data, max_bsz=4): | |||
""" Perform feature extracting for a given video | |||
Args: | |||
model (BaseVideoModel): video model with loadded state dict. | |||
max_bsz (int): the maximum batch size, limited by GPU memory. | |||
Returns: | |||
pred (Tensor): the extracted features for input video clips. | |||
""" | |||
iter_num = math.ceil(data.size(0) / max_bsz) | |||
preds_list = [] | |||
for i in range(iter_num): | |||
preds_list.append( | |||
self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0]) | |||
pred = torch.cat(preds_list, dim=0) | |||
return pred | |||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
return inputs |
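For reference, a usage sketch of the new pipeline once registered; the model id and test video path are the ones used by the unit test added later in this review, and the result is read back through OutputKeys.VIDEO_EMBEDDING:

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    p = pipeline(Tasks.video_embedding, model='damo/cv_s3dg_video-embedding')
    result = p('data/test/videos/action_recognition_test_video.mp4')
    embedding = result[OutputKeys.VIDEO_EMBEDDING]   # numpy array of per-clip features
    print(embedding.shape)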
@@ -16,34 +16,49 @@ from .base import Preprocessor | |||
from .builder import PREPROCESSORS | |||
def ReadVideoData(cfg, video_path): | |||
def ReadVideoData(cfg, | |||
video_path, | |||
num_spatial_crops_override=None, | |||
num_temporal_views_override=None): | |||
""" simple interface to load video frames from file | |||
Args: | |||
cfg (Config): The global config object. | |||
video_path (str): video file path | |||
num_spatial_crops_override (int): the spatial crops per clip | |||
num_temporal_views_override (int): the temporal clips per video | |||
Returns: | |||
data (Tensor): the normalized video clips for model inputs | |||
""" | |||
data = _decode_video(cfg, video_path) | |||
transform = kinetics400_tranform(cfg) | |||
data = _decode_video(cfg, video_path, num_temporal_views_override) | |||
if num_spatial_crops_override is not None: | |||
num_spatial_crops = num_spatial_crops_override | |||
transform = kinetics400_tranform(cfg, num_spatial_crops_override) | |||
else: | |||
num_spatial_crops = cfg.TEST.NUM_SPATIAL_CROPS | |||
transform = kinetics400_tranform(cfg, cfg.TEST.NUM_SPATIAL_CROPS) | |||
data_list = [] | |||
for i in range(data.size(0)): | |||
for j in range(cfg.TEST.NUM_SPATIAL_CROPS): | |||
for j in range(num_spatial_crops): | |||
transform.transforms[1].set_spatial_index(j) | |||
data_list.append(transform(data[i])) | |||
return torch.stack(data_list, dim=0) | |||
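ReadVideoData stacks one clip per (temporal view, spatial crop) pair, so the leading dimension of the returned tensor is the product of the two counts. A shape-only sketch with made-up counts, not values from any shipped configuration:

    import torch

    num_temporal_views = 4    # clips along time (cfg.TEST.NUM_ENSEMBLE_VIEWS or the override)
    num_spatial_crops = 3     # crops per clip (cfg.TEST.NUM_SPATIAL_CROPS or the override)
    num_frames, crop_size = 16, 224

    # each transformed clip is (C, T, crop, crop); ReadVideoData stacks them on dim 0
    batch = torch.randn(num_temporal_views * num_spatial_crops, 3,
                        num_frames, crop_size, crop_size)
    print(batch.shape)        # torch.Size([12, 3, 16, 224, 224]), fed to the model in chunks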
def kinetics400_tranform(cfg): | |||
def kinetics400_tranform(cfg, num_spatial_crops): | |||
""" | |||
Configures the transform for the Kinetics-400 dataset. | |||
We apply controlled spatial cropping and normalization. | |||
Args: | |||
cfg (Config): The global config object. | |||
num_spatial_crops (int): the spatial crops per clip | |||
Returns: | |||
transform_function (Compose): the transform function for input clips | |||
""" | |||
resize_video = KineticsResizedCrop( | |||
short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE], | |||
crop_size=cfg.DATA.TEST_CROP_SIZE, | |||
num_spatial_crops=cfg.TEST.NUM_SPATIAL_CROPS) | |||
num_spatial_crops=num_spatial_crops) | |||
std_transform_list = [ | |||
transforms.ToTensorVideo(), resize_video, | |||
transforms.NormalizeVideo( | |||
@@ -60,17 +75,17 @@ def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx, | |||
vid_length (int): the length of the whole video (valid selection range). | |||
vid_fps (int): the original video fps | |||
target_fps (int): the normalized video fps | |||
clip_idx (int): -1 for random temporal sampling, and positive values for | |||
sampling specific clip from the video | |||
clip_idx (int): -1 for random temporal sampling, and positive values for sampling specific | |||
clip from the video | |||
num_clips (int): the total clips to be sampled from each video. | |||
combined with clip_idx, the sampled video is the "clip_idx-th" | |||
video from "num_clips" videos. | |||
combined with clip_idx, the sampled video is the "clip_idx-th" video from | |||
"num_clips" videos. | |||
num_frames (int): number of frames in each sampled clips. | |||
interval (int): the interval to sample each frame. | |||
minus_interval (bool): control the end index | |||
Returns: | |||
index (tensor): the sampled frame indexes | |||
""" | |||
""" | |||
if num_frames == 1: | |||
index = [random.randint(0, vid_length - 1)] | |||
else: | |||
@@ -78,7 +93,10 @@ def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx, | |||
clip_length = num_frames * interval * vid_fps / target_fps | |||
max_idx = max(vid_length - clip_length, 0) | |||
start_idx = clip_idx * math.floor(max_idx / (num_clips - 1)) | |||
if num_clips == 1: | |||
start_idx = max_idx / 2 | |||
else: | |||
start_idx = clip_idx * math.floor(max_idx / (num_clips - 1)) | |||
if minus_interval: | |||
end_idx = start_idx + clip_length - interval | |||
else: | |||
@@ -90,59 +108,79 @@ def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx, | |||
return index | |||
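A small worked example of the start-index arithmetic, including the new num_clips == 1 branch that avoids the previous division by zero; all numbers are made up for illustration:

    import math

    vid_length, vid_fps, target_fps = 300, 30, 30
    num_frames, interval = 16, 2

    clip_length = num_frames * interval * vid_fps / target_fps  # 32.0 frames per clip
    max_idx = max(vid_length - clip_length, 0)                   # 268.0 valid start positions
    start_single = max_idx / 2                                   # 134.0: single clip is centered
    start_multi = 2 * math.floor(max_idx / (4 - 1))              # 178: clip_idx=2 of num_clips=4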
def _decode_video_frames_list(cfg, frames_list, vid_fps): | |||
def _decode_video_frames_list(cfg, | |||
frames_list, | |||
vid_fps, | |||
num_temporal_views_override=None): | |||
""" | |||
Decodes the video given the numpy frames. | |||
Args: | |||
cfg (Config): The global config object. | |||
frames_list (list): all frames for a video, the frames should be numpy array. | |||
vid_fps (int): the fps of this video. | |||
num_temporal_views_override (int): the temporal clips per video | |||
Returns: | |||
frames (Tensor): video tensor data | |||
""" | |||
assert isinstance(frames_list, list) | |||
num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS | |||
if num_temporal_views_override is not None: | |||
num_clips_per_video = num_temporal_views_override | |||
else: | |||
num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS | |||
frame_list = [] | |||
for clip_idx in range(num_clips_per_video): | |||
# for each clip in the video, | |||
# a list is generated before decoding the specified frames from the video | |||
list_ = _interval_based_sampling( | |||
len(frames_list), vid_fps, cfg.DATA.TARGET_FPS, clip_idx, | |||
num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES, | |||
cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL) | |||
len(frames_list), | |||
vid_fps, | |||
cfg.DATA.TARGET_FPS, | |||
clip_idx, | |||
num_clips_per_video, | |||
cfg.DATA.NUM_INPUT_FRAMES, | |||
cfg.DATA.SAMPLING_RATE, | |||
cfg.DATA.MINUS_INTERVAL, | |||
) | |||
frames = None | |||
frames = torch.from_numpy( | |||
np.stack([frames_list[l_index] for l_index in list_.tolist()], | |||
axis=0)) | |||
np.stack([frames_list[index] for index in list_.tolist()], axis=0)) | |||
frame_list.append(frames) | |||
frames = torch.stack(frame_list) | |||
if num_clips_per_video == 1: | |||
frames = frames.squeeze(0) | |||
return frames | |||
def _decode_video(cfg, path): | |||
def _decode_video(cfg, path, num_temporal_views_override=None): | |||
""" | |||
Decodes the video from the given file path. | |||
Args: | |||
cfg (Config): The global config object. | |||
path (str): video file path. | |||
num_temporal_views_override (int): the temporal clips per video | |||
Returns: | |||
frames (Tensor): video tensor data | |||
""" | |||
vr = VideoReader(path) | |||
num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS | |||
if num_temporal_views_override is not None: | |||
num_clips_per_video = num_temporal_views_override | |||
else: | |||
num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS | |||
frame_list = [] | |||
for clip_idx in range(num_clips_per_video): | |||
# for each clip in the video, | |||
# a list is generated before decoding the specified frames from the video | |||
list_ = _interval_based_sampling( | |||
len(vr), vr.get_avg_fps(), cfg.DATA.TARGET_FPS, clip_idx, | |||
num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES, | |||
cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL) | |||
len(vr), | |||
vr.get_avg_fps(), | |||
cfg.DATA.TARGET_FPS, | |||
clip_idx, | |||
num_clips_per_video, | |||
cfg.DATA.NUM_INPUT_FRAMES, | |||
cfg.DATA.SAMPLING_RATE, | |||
cfg.DATA.MINUS_INTERVAL, | |||
) | |||
frames = None | |||
if path.endswith('.avi'): | |||
append_list = torch.arange(0, list_[0], 4) | |||
@@ -155,8 +193,6 @@ def _decode_video(cfg, path): | |||
vr.get_batch(list_).to_dlpack()).clone() | |||
frame_list.append(frames) | |||
frames = torch.stack(frame_list) | |||
if num_clips_per_video == 1: | |||
frames = frames.squeeze(0) | |||
del vr | |||
return frames | |||
@@ -224,6 +260,29 @@ class KineticsResizedCrop(object): | |||
y = y_max // 2 | |||
return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size] | |||
def _get_random_crop(self, clip): | |||
_, _, clip_height, clip_width = clip.shape | |||
short_side = min(clip_height, clip_width) | |||
long_side = max(clip_height, clip_width) | |||
new_short_side = int(random.uniform(*self.short_side_range)) | |||
new_long_side = int(long_side / short_side * new_short_side) | |||
if clip_height < clip_width: | |||
new_clip_height = new_short_side | |||
new_clip_width = new_long_side | |||
else: | |||
new_clip_height = new_long_side | |||
new_clip_width = new_short_side | |||
new_clip = torch.nn.functional.interpolate( | |||
clip, size=(new_clip_height, new_clip_width), mode='bilinear') | |||
x_max = int(new_clip_width - self.crop_size) | |||
y_max = int(new_clip_height - self.crop_size) | |||
x = int(random.uniform(0, x_max)) | |||
y = int(random.uniform(0, y_max)) | |||
return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size] | |||
def set_spatial_index(self, idx): | |||
"""Set the spatial cropping index for controlled cropping.. | |||
Args: | |||
@@ -0,0 +1,26 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
#!/usr/bin/env python | |||
import unittest | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class HICOSSLVideoEmbeddingTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/cv_s3dg_video-embedding' | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
def test_run_modelhub(self): | |||
videossl_pipeline = pipeline( | |||
Tasks.video_embedding, model=self.model_id) | |||
result = videossl_pipeline( | |||
'data/test/videos/action_recognition_test_video.mp4') | |||
print(f'video embedding output: {result}.') | |||
if __name__ == '__main__': | |||
unittest.main() |