
[to #42322933] Add hicossl_video_embedding_pipeline to maas lib

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9969472
Branch: master
yongfei.zyf, yingda.chen · 3 years ago
Parent commit: c8b6030b8e
8 changed files with 538 additions and 32 deletions
  1. modelscope/metainfo.py  (+1, -0)
  2. modelscope/models/cv/action_recognition/models.py  (+43, -2)
  3. modelscope/models/cv/action_recognition/s3dg.py  (+301, -0)
  4. modelscope/pipelines/cv/__init__.py  (+2, -0)
  5. modelscope/pipelines/cv/action_recognition_pipeline.py  (+1, -0)
  6. modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py  (+75, -0)
  7. modelscope/preprocessors/video.py  (+89, -30)
  8. tests/pipelines/test_hicossl_video_embedding.py  (+26, -0)
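For context, a minimal usage sketch of the pipeline this commit adds; the model id and the test video path are taken from the new test file at the bottom of this diff:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

video_embedding = pipeline(Tasks.video_embedding, model='damo/cv_s3dg_video-embedding')
result = video_embedding('data/test/videos/action_recognition_test_video.mp4')
# result is a dict holding the clip embeddings under OutputKeys.VIDEO_EMBEDDING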

modelscope/metainfo.py  (+1, -0)

@@ -99,6 +99,7 @@ class Pipelines(object):
animal_recognition = 'resnet101-animal-recognition'
general_recognition = 'resnet101-general-recognition'
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
hicossl_video_embedding = 'hicossl-s3dg-video_embedding'
body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
body_3d_keypoints = 'canonical_body-3d-keypoints_video'
human_detection = 'resnet18-human-detection'


modelscope/models/cv/action_recognition/models.py  (+43, -2)

@@ -1,5 +1,6 @@
import torch.nn as nn

from .s3dg import Inception3D
from .tada_convnext import TadaConvNeXt


@@ -26,11 +27,25 @@ class BaseVideoModel(nn.Module):
super(BaseVideoModel, self).__init__()
# the backbone is created according to meta-architectures
# defined in models/base/backbone.py
self.backbone = TadaConvNeXt(cfg)
if cfg.MODEL.NAME == 'ConvNeXt_tiny':
self.backbone = TadaConvNeXt(cfg)
elif cfg.MODEL.NAME == 'S3DG':
self.backbone = Inception3D(cfg)
else:
error_str = 'backbone {} is not supported, ConvNeXt_tiny or S3DG is supported'.format(
cfg.MODEL.NAME)
raise NotImplementedError(error_str)

# the head is created according to the heads
# defined in models/module_zoo/heads
self.head = BaseHead(cfg)
if cfg.VIDEO.HEAD.NAME == 'BaseHead':
self.head = BaseHead(cfg)
elif cfg.VIDEO.HEAD.NAME == 'AvgHead':
self.head = AvgHead(cfg)
else:
error_str = 'head {} is not supported, BaseHead or AvgHead is supported'.format(
cfg.VIDEO.HEAD.NAME)
raise NotImplementedError(error_str)

def forward(self, x):
x = self.backbone(x)
@@ -88,3 +103,29 @@ class BaseHead(nn.Module):
out = self.activation(out)
out = out.view(out.shape[0], -1)
return out, x.view(x.shape[0], -1)


class AvgHead(nn.Module):
"""
Constructs the average pooling head.
"""

def __init__(
self,
cfg,
):
"""
Args:
cfg (Config): global config object.
"""
super(AvgHead, self).__init__()
self.cfg = cfg
self.global_avg_pool = nn.AdaptiveAvgPool3d(1)

def forward(self, x):
if len(x.shape) == 5:
x = self.global_avg_pool(x)
# (N, C, T, H, W) -> (N, T, H, W, C).
x = x.permute((0, 2, 3, 4, 1))
out = x.view(x.shape[0], -1)
return out, x.view(x.shape[0], -1)
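For reference, a self-contained sketch of the shape flow through AvgHead above; the (N, C, T, H, W) = (2, 1024, 8, 7, 7) feature map is an assumption taken from the S3DG block5 comments in the next file:

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool3d(1)
x = torch.randn(2, 1024, 8, 7, 7)    # assumed backbone output (N, C, T, H, W)
x = pool(x)                          # (2, 1024, 1, 1, 1)
x = x.permute((0, 2, 3, 4, 1))       # (2, 1, 1, 1, 1024)
out = x.view(x.shape[0], -1)         # (2, 1024): one embedding vector per clip
print(out.shape)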

modelscope/models/cv/action_recognition/s3dg.py  (+301, -0)

@@ -0,0 +1,301 @@
import torch
import torch.nn as nn


class InceptionBaseConv3D(nn.Module):
"""
Constructs basic inception 3D conv.
Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
"""

def __init__(self,
cfg,
in_planes,
out_planes,
kernel_size,
stride,
padding=0):
super(InceptionBaseConv3D, self).__init__()
self.conv = nn.Conv3d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False)
self.bn = nn.BatchNorm3d(out_planes)
self.relu = nn.ReLU(inplace=True)

# init
self.conv.weight.data.normal_(
mean=0, std=0.01) # original s3d is truncated normal within 2 std
self.bn.weight.data.fill_(1)
self.bn.bias.data.zero_()

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x


class InceptionBlock3D(nn.Module):
"""
Element constructing the S3D/S3DG.
See models/base/backbone.py L99-186.

Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
"""

def __init__(self, cfg, in_planes, out_planes):
super(InceptionBlock3D, self).__init__()

_gating = cfg.VIDEO.BACKBONE.BRANCH.GATING

assert len(out_planes) == 6
assert isinstance(out_planes, list)

[
num_out_0_0a, num_out_1_0a, num_out_1_0b, num_out_2_0a,
num_out_2_0b, num_out_3_0b
] = out_planes

self.branch0 = nn.Sequential(
InceptionBaseConv3D(
cfg, in_planes, num_out_0_0a, kernel_size=1, stride=1), )
self.branch1 = nn.Sequential(
InceptionBaseConv3D(
cfg, in_planes, num_out_1_0a, kernel_size=1, stride=1),
STConv3d(
cfg,
num_out_1_0a,
num_out_1_0b,
kernel_size=3,
stride=1,
padding=1),
)
self.branch2 = nn.Sequential(
InceptionBaseConv3D(
cfg, in_planes, num_out_2_0a, kernel_size=1, stride=1),
STConv3d(
cfg,
num_out_2_0a,
num_out_2_0b,
kernel_size=3,
stride=1,
padding=1),
)
self.branch3 = nn.Sequential(
nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
InceptionBaseConv3D(
cfg, in_planes, num_out_3_0b, kernel_size=1, stride=1),
)

self.out_channels = sum(
[num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b])

self.gating = _gating
if _gating:
self.gating_b0 = SelfGating(num_out_0_0a)
self.gating_b1 = SelfGating(num_out_1_0b)
self.gating_b2 = SelfGating(num_out_2_0b)
self.gating_b3 = SelfGating(num_out_3_0b)

def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
if self.gating:
x0 = self.gating_b0(x0)
x1 = self.gating_b1(x1)
x2 = self.gating_b2(x2)
x3 = self.gating_b3(x3)

out = torch.cat((x0, x1, x2, x3), 1)

return out


class SelfGating(nn.Module):

def __init__(self, input_dim):
super(SelfGating, self).__init__()
self.fc = nn.Linear(input_dim, input_dim)

def forward(self, input_tensor):
"""Feature gating as used in S3D-G"""
spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4])
weights = self.fc(spatiotemporal_average)
weights = torch.sigmoid(weights)
return weights[:, :, None, None, None] * input_tensor
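A self-contained sketch of the gating math above (the channel count is hypothetical): average over T/H/W, a linear layer, a sigmoid, then a per-channel rescale that leaves the feature shape unchanged.

import torch
import torch.nn as nn

fc = nn.Linear(480, 480)                               # hypothetical channel count
feat = torch.randn(2, 480, 16, 14, 14)                 # (N, C, T, H, W)
weights = torch.sigmoid(fc(feat.mean(dim=[2, 3, 4])))  # one weight per channel
gated = weights[:, :, None, None, None] * feat
print(gated.shape)                                     # torch.Size([2, 480, 16, 14, 14])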


class STConv3d(nn.Module):
"""
Element constructing the S3D/S3DG.
See models/base/backbone.py L99-186.

Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
"""

def __init__(self,
cfg,
in_planes,
out_planes,
kernel_size,
stride,
padding=0):
super(STConv3d, self).__init__()
if isinstance(stride, tuple):
t_stride = stride[0]
stride = stride[-1]
else: # int
t_stride = stride

self.bn_mmt = cfg.BN.MOMENTUM
self.bn_eps = float(cfg.BN.EPS)
self._construct_branch(cfg, in_planes, out_planes, kernel_size, stride,
t_stride, padding)

def _construct_branch(self,
cfg,
in_planes,
out_planes,
kernel_size,
stride,
t_stride,
padding=0):
self.conv1 = nn.Conv3d(
in_planes,
out_planes,
kernel_size=(1, kernel_size, kernel_size),
stride=(1, stride, stride),
padding=(0, padding, padding),
bias=False)
self.conv2 = nn.Conv3d(
out_planes,
out_planes,
kernel_size=(kernel_size, 1, 1),
stride=(t_stride, 1, 1),
padding=(padding, 0, 0),
bias=False)

self.bn1 = nn.BatchNorm3d(
out_planes, eps=self.bn_eps, momentum=self.bn_mmt)
self.bn2 = nn.BatchNorm3d(
out_planes, eps=self.bn_eps, momentum=self.bn_mmt)
self.relu = nn.ReLU(inplace=True)

# init
self.conv1.weight.data.normal_(
mean=0, std=0.01) # original s3d is truncated normal within 2 std
self.conv2.weight.data.normal_(
mean=0, std=0.01) # original s3d is truncated normal within 2 std
self.bn1.weight.data.fill_(1)
self.bn1.bias.data.zero_()
self.bn2.weight.data.fill_(1)
self.bn2.bias.data.zero_()

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
return x


class Inception3D(nn.Module):
"""
Backbone architecture for I3D/S3DG.
Modified from https://github.com/TengdaHan/CoCLR/blob/main/backbone/s3dg.py.
"""

def __init__(self, cfg):
"""
Args:
cfg (Config): global config object.
"""
super(Inception3D, self).__init__()
_input_channel = cfg.DATA.NUM_INPUT_CHANNELS
self._construct_backbone(cfg, _input_channel)

def _construct_backbone(self, cfg, input_channel):
# ------------------- Block 1 -------------------
self.Conv_1a = STConv3d(
cfg, input_channel, 64, kernel_size=7, stride=2, padding=3)

self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112)

# ------------------- Block 2 -------------------
self.MaxPool_2a = nn.MaxPool3d(
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
self.Conv_2b = InceptionBaseConv3D(
cfg, 64, 64, kernel_size=1, stride=1)
self.Conv_2c = STConv3d(
cfg, 64, 192, kernel_size=3, stride=1, padding=1)

self.block2 = nn.Sequential(
self.MaxPool_2a, # (64, 32, 56, 56)
self.Conv_2b, # (64, 32, 56, 56)
self.Conv_2c) # (192, 32, 56, 56)

# ------------------- Block 3 -------------------
self.MaxPool_3a = nn.MaxPool3d(
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
self.Mixed_3b = InceptionBlock3D(
cfg, in_planes=192, out_planes=[64, 96, 128, 16, 32, 32])
self.Mixed_3c = InceptionBlock3D(
cfg, in_planes=256, out_planes=[128, 128, 192, 32, 96, 64])

self.block3 = nn.Sequential(
self.MaxPool_3a, # (192, 32, 28, 28)
self.Mixed_3b, # (256, 32, 28, 28)
self.Mixed_3c) # (480, 32, 28, 28)

# ------------------- Block 4 -------------------
self.MaxPool_4a = nn.MaxPool3d(
kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
self.Mixed_4b = InceptionBlock3D(
cfg, in_planes=480, out_planes=[192, 96, 208, 16, 48, 64])
self.Mixed_4c = InceptionBlock3D(
cfg, in_planes=512, out_planes=[160, 112, 224, 24, 64, 64])
self.Mixed_4d = InceptionBlock3D(
cfg, in_planes=512, out_planes=[128, 128, 256, 24, 64, 64])
self.Mixed_4e = InceptionBlock3D(
cfg, in_planes=512, out_planes=[112, 144, 288, 32, 64, 64])
self.Mixed_4f = InceptionBlock3D(
cfg, in_planes=528, out_planes=[256, 160, 320, 32, 128, 128])

self.block4 = nn.Sequential(
self.MaxPool_4a, # (480, 16, 14, 14)
self.Mixed_4b, # (512, 16, 14, 14)
self.Mixed_4c, # (512, 16, 14, 14)
self.Mixed_4d, # (512, 16, 14, 14)
self.Mixed_4e, # (528, 16, 14, 14)
self.Mixed_4f) # (832, 16, 14, 14)

# ------------------- Block 5 -------------------
self.MaxPool_5a = nn.MaxPool3d(
kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
self.Mixed_5b = InceptionBlock3D(
cfg, in_planes=832, out_planes=[256, 160, 320, 32, 128, 128])
self.Mixed_5c = InceptionBlock3D(
cfg, in_planes=832, out_planes=[384, 192, 384, 48, 128, 128])

self.block5 = nn.Sequential(
self.MaxPool_5a, # (832, 8, 7, 7)
self.Mixed_5b, # (832, 8, 7, 7)
self.Mixed_5c) # (1024, 8, 7, 7)

def forward(self, x):
if isinstance(x, dict):
x = x['video']
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(x)
return x
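A quick shape sanity check against the per-block comments above. This is a sketch only: the SimpleNamespace stub merely fills the config attributes this file reads (DATA.NUM_INPUT_CHANNELS, BN.MOMENTUM, BN.EPS, VIDEO.BACKBONE.BRANCH.GATING) and is not the real ModelScope Config; it assumes the classes in this file are in scope.

from types import SimpleNamespace

import torch

cfg = SimpleNamespace(
    DATA=SimpleNamespace(NUM_INPUT_CHANNELS=3),
    BN=SimpleNamespace(MOMENTUM=0.1, EPS=1e-5),
    VIDEO=SimpleNamespace(BACKBONE=SimpleNamespace(BRANCH=SimpleNamespace(GATING=True))),
)
backbone = Inception3D(cfg)
clip = torch.randn(1, 3, 64, 224, 224)   # (N, C, T, H, W): 64 frames at 224x224
feat = backbone(clip)
print(feat.shape)                        # torch.Size([1, 1024, 8, 7, 7]) per the block comments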

modelscope/pipelines/cv/__init__.py  (+2, -0)

@@ -9,6 +9,7 @@ if TYPE_CHECKING:
from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
from .hicossl_video_embedding_pipeline import HICOSSLVideoEmbeddingPipeline
from .crowd_counting_pipeline import CrowdCountingPipeline
from .image_detection_pipeline import ImageDetectionPipeline
from .image_salient_detection_pipeline import ImageSalientDetectionPipeline
@@ -51,6 +52,7 @@ else:
'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
'hicossl_video_embedding_pipeline': ['HICOSSLVideoEmbeddingPipeline'],
'crowd_counting_pipeline': ['CrowdCountingPipeline'],
'image_detection_pipeline': ['ImageDetectionPipeline'],
'image_salient_detection_pipeline': ['ImageSalientDetectionPipeline'],


modelscope/pipelines/cv/action_recognition_pipeline.py  (+1, -0)

@@ -33,6 +33,7 @@ class ActionRecognitionPipeline(Pipeline):
config_path = osp.join(self.model, ModelFile.CONFIGURATION)
logger.info(f'loading config from {config_path}')
self.cfg = Config.from_file(config_path)

self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
self.infer_model.eval()
self.infer_model.load_state_dict(


modelscope/pipelines/cv/hicossl_video_embedding_pipeline.py  (+75, -0)

@@ -0,0 +1,75 @@
import math
import os.path as osp
from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.action_recognition import BaseVideoModel
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import ReadVideoData
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
Tasks.video_embedding, module_name=Pipelines.hicossl_video_embedding)
class HICOSSLVideoEmbeddingPipeline(Pipeline):

def __init__(self, model: str, **kwargs):
"""
use `model` to create a hicossl video embedding pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model, **kwargs)
model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
logger.info(f'loading model from {model_path}')
config_path = osp.join(self.model, ModelFile.CONFIGURATION)
logger.info(f'loading config from {config_path}')
self.cfg = Config.from_file(config_path)
self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
self.infer_model.eval()
self.infer_model.load_state_dict(
torch.load(model_path, map_location=self.device)['model_state'],
strict=False)
logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
video_input_data = ReadVideoData(
self.cfg, input, num_temporal_views_override=1).to(self.device)
else:
raise TypeError(f'input should be a str,'
f' but got {type(input)}')
result = {'video_data': video_input_data}
return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
feature = self.perform_inference(input['video_data'])
return {OutputKeys.VIDEO_EMBEDDING: feature.data.cpu().numpy()}

@torch.no_grad()
def perform_inference(self, data, max_bsz=4):
""" Perform feature extracting for a given video
Args:
model (BaseVideoModel): video model with loadded state dict.
max_bsz (int): the maximum batch size, limited by GPU memory.
Returns:
pred (Tensor): the extracted features for input video clips.
"""
iter_num = math.ceil(data.size(0) / max_bsz)
preds_list = []
for i in range(iter_num):
preds_list.append(
self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0])
pred = torch.cat(preds_list, dim=0)
return pred

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
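The batching in perform_inference is plain chunked iteration over the leading dimension; a standalone sketch of the index arithmetic (the tensor shape is hypothetical):

import math

import torch

data = torch.randn(10, 3, 16, 32, 32)                  # 10 hypothetical clips
max_bsz = 4
iter_num = math.ceil(data.size(0) / max_bsz)           # ceil(10 / 4) = 3 forward passes
chunks = [data[i * max_bsz:(i + 1) * max_bsz] for i in range(iter_num)]
print([c.size(0) for c in chunks])                     # [4, 4, 2], concatenated back along dim 0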

modelscope/preprocessors/video.py  (+89, -30)

@@ -16,34 +16,49 @@ from .base import Preprocessor
from .builder import PREPROCESSORS


def ReadVideoData(cfg, video_path):
def ReadVideoData(cfg,
video_path,
num_spatial_crops_override=None,
num_temporal_views_override=None):
""" simple interface to load video frames from file

Args:
cfg (Config): The global config object.
video_path (str): video file path
num_spatial_crops_override (int): the spatial crops per clip
num_temporal_views_override (int): the temporal clips per video
Returns:
data (Tensor): the normalized video clips for model inputs
"""
data = _decode_video(cfg, video_path)
transform = kinetics400_tranform(cfg)
data = _decode_video(cfg, video_path, num_temporal_views_override)
if num_spatial_crops_override is not None:
num_spatial_crops = num_spatial_crops_override
transform = kinetics400_tranform(cfg, num_spatial_crops_override)
else:
num_spatial_crops = cfg.TEST.NUM_SPATIAL_CROPS
transform = kinetics400_tranform(cfg, cfg.TEST.NUM_SPATIAL_CROPS)
data_list = []
for i in range(data.size(0)):
for j in range(cfg.TEST.NUM_SPATIAL_CROPS):
for j in range(num_spatial_crops):
transform.transforms[1].set_spatial_index(j)
data_list.append(transform(data[i]))
return torch.stack(data_list, dim=0)
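An illustration of the stacking at the end of ReadVideoData: one transformed clip per (temporal view, spatial crop) pair, so the leading dimension is num_temporal_views * num_spatial_crops. The clip shape below (16 frames, 224x224 crops) is an assumption, not read from the config:

import torch

num_temporal_views, num_spatial_crops = 1, 3   # e.g. the override the new pipeline passes, 3 crops
clips = [torch.randn(3, 16, 224, 224)          # (C, T, H, W) after kinetics400_tranform
         for _ in range(num_temporal_views * num_spatial_crops)]
batch = torch.stack(clips, dim=0)
print(batch.shape)                             # torch.Size([3, 3, 16, 224, 224])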


def kinetics400_tranform(cfg):
def kinetics400_tranform(cfg, num_spatial_crops):
"""
Configures the transform for the Kinetics-400 dataset.
We apply controlled spatial cropping and normalization.
Args:
cfg (Config): The global config object.
num_spatial_crops (int): the spatial crops per clip
Returns:
transform_function (Compose): the transform function for input clips
"""
resize_video = KineticsResizedCrop(
short_side_range=[cfg.DATA.TEST_SCALE, cfg.DATA.TEST_SCALE],
crop_size=cfg.DATA.TEST_CROP_SIZE,
num_spatial_crops=cfg.TEST.NUM_SPATIAL_CROPS)
num_spatial_crops=num_spatial_crops)
std_transform_list = [
transforms.ToTensorVideo(), resize_video,
transforms.NormalizeVideo(
@@ -60,17 +75,17 @@ def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx,
vid_length (int): the length of the whole video (valid selection range).
vid_fps (int): the original video fps
target_fps (int): the normalized video fps
clip_idx (int): -1 for random temporal sampling, and positive values for
sampling specific clip from the video
clip_idx (int): -1 for random temporal sampling, and positive values for sampling specific
clip from the video
num_clips (int): the total clips to be sampled from each video.
combined with clip_idx, the sampled video is the "clip_idx-th"
video from "num_clips" videos.
combined with clip_idx, the sampled video is the "clip_idx-th" video from
"num_clips" videos.
num_frames (int): number of frames in each sampled clips.
interval (int): the interval to sample each frame.
minus_interval (bool): control the end index
Returns:
index (tensor): the sampled frame indexes
"""
"""
if num_frames == 1:
index = [random.randint(0, vid_length - 1)]
else:
@@ -78,7 +93,10 @@ def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx,
clip_length = num_frames * interval * vid_fps / target_fps

max_idx = max(vid_length - clip_length, 0)
start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
if num_clips == 1:
start_idx = max_idx / 2
else:
start_idx = clip_idx * math.floor(max_idx / (num_clips - 1))
if minus_interval:
end_idx = start_idx + clip_length - interval
else:
@@ -90,59 +108,79 @@ def _interval_based_sampling(vid_length, vid_fps, target_fps, clip_idx,
return index
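A worked instance of the sampling arithmetic above, with hypothetical numbers (a 300-frame video at 30 fps, target_fps 30, 16 frames per clip, interval 4):

import math

vid_length, vid_fps, target_fps = 300, 30, 30
num_frames, interval = 16, 4
clip_length = num_frames * interval * vid_fps / target_fps   # 64.0 frames covered per clip
max_idx = max(vid_length - clip_length, 0)                   # 236.0
start_single = max_idx / 2                                   # 118.0: a single clip starts mid-video
start_last_of_3 = 2 * math.floor(max_idx / (3 - 1))          # 236: clip_idx=2 when num_clips=3
print(clip_length, max_idx, start_single, start_last_of_3)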


def _decode_video_frames_list(cfg, frames_list, vid_fps):
def _decode_video_frames_list(cfg,
frames_list,
vid_fps,
num_temporal_views_override=None):
"""
Decodes the video given the numpy frames.
Args:
cfg (Config): The global config object.
frames_list (list): all frames for a video, the frames should be numpy array.
vid_fps (int): the fps of this video.
num_temporal_views_override (int): the temporal clips per video
Returns:
frames (Tensor): video tensor data
"""
assert isinstance(frames_list, list)
num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
if num_temporal_views_override is not None:
num_clips_per_video = num_temporal_views_override
else:
num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS

frame_list = []
for clip_idx in range(num_clips_per_video):
# for each clip in the video,
# a list is generated before decoding the specified frames from the video
list_ = _interval_based_sampling(
len(frames_list), vid_fps, cfg.DATA.TARGET_FPS, clip_idx,
num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
len(frames_list),
vid_fps,
cfg.DATA.TARGET_FPS,
clip_idx,
num_clips_per_video,
cfg.DATA.NUM_INPUT_FRAMES,
cfg.DATA.SAMPLING_RATE,
cfg.DATA.MINUS_INTERVAL,
)
frames = None
frames = torch.from_numpy(
np.stack([frames_list[l_index] for l_index in list_.tolist()],
axis=0))
np.stack([frames_list[index] for index in list_.tolist()], axis=0))
frame_list.append(frames)
frames = torch.stack(frame_list)
if num_clips_per_video == 1:
frames = frames.squeeze(0)

return frames


def _decode_video(cfg, path):
def _decode_video(cfg, path, num_temporal_views_override=None):
"""
Decodes the video from the given file path.
Args:
cfg (Config): The global config object.
path (str): video file path.
num_temporal_views_override (int): the temporal clips per video
Returns:
frames (Tensor): video tensor data
"""
vr = VideoReader(path)

num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS
if num_temporal_views_override is not None:
num_clips_per_video = num_temporal_views_override
else:
num_clips_per_video = cfg.TEST.NUM_ENSEMBLE_VIEWS

frame_list = []
for clip_idx in range(num_clips_per_video):
# for each clip in the video,
# a list is generated before decoding the specified frames from the video
list_ = _interval_based_sampling(
len(vr), vr.get_avg_fps(), cfg.DATA.TARGET_FPS, clip_idx,
num_clips_per_video, cfg.DATA.NUM_INPUT_FRAMES,
cfg.DATA.SAMPLING_RATE, cfg.DATA.MINUS_INTERVAL)
len(vr),
vr.get_avg_fps(),
cfg.DATA.TARGET_FPS,
clip_idx,
num_clips_per_video,
cfg.DATA.NUM_INPUT_FRAMES,
cfg.DATA.SAMPLING_RATE,
cfg.DATA.MINUS_INTERVAL,
)
frames = None
if path.endswith('.avi'):
append_list = torch.arange(0, list_[0], 4)
@@ -155,8 +193,6 @@ def _decode_video(cfg, path):
vr.get_batch(list_).to_dlpack()).clone()
frame_list.append(frames)
frames = torch.stack(frame_list)
if num_clips_per_video == 1:
frames = frames.squeeze(0)
del vr
return frames

@@ -224,6 +260,29 @@ class KineticsResizedCrop(object):
y = y_max // 2
return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]

def _get_random_crop(self, clip):
_, _, clip_height, clip_width = clip.shape

short_side = min(clip_height, clip_width)
long_side = max(clip_height, clip_width)
new_short_side = int(random.uniform(*self.short_side_range))
new_long_side = int(long_side / short_side * new_short_side)
if clip_height < clip_width:
new_clip_height = new_short_side
new_clip_width = new_long_side
else:
new_clip_height = new_long_side
new_clip_width = new_short_side

new_clip = torch.nn.functional.interpolate(
clip, size=(new_clip_height, new_clip_width), mode='bilinear')

x_max = int(new_clip_width - self.crop_size)
y_max = int(new_clip_height - self.crop_size)
x = int(random.uniform(0, x_max))
y = int(random.uniform(0, y_max))
return new_clip[:, :, y:y + self.crop_size, x:x + self.crop_size]
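The resize in _get_random_crop keeps the aspect ratio: the short side is drawn from short_side_range and the long side scales proportionally. A tiny numeric sketch with hypothetical sizes:

short_side, long_side = 240, 320        # hypothetical clip height/width
new_short_side = 256                    # a draw from short_side_range
new_long_side = int(long_side / short_side * new_short_side)
print(new_short_side, new_long_side)    # 256 341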

def set_spatial_index(self, idx):
"""Set the spatial cropping index for controlled cropping..
Args:


tests/pipelines/test_hicossl_video_embedding.py  (+26, -0)

@@ -0,0 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
#!/usr/bin/env python
import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class HICOSSLVideoEmbeddingTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_s3dg_video-embedding'

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
videossl_pipeline = pipeline(
Tasks.video_embedding, model=self.model_id)
result = videossl_pipeline(
'data/test/videos/action_recognition_test_video.mp4')

print(f'video embedding output: {result}.')


if __name__ == '__main__':
unittest.main()
