siyang.ssy yingda.chen 2 years ago
parent
commit
6d51f44dc7
2 changed files with 34 additions and 11 deletions
  1. +32
    -9
      modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py
  2. +2
    -2
      tests/pipelines/test_video_multi_modal_embedding.py

+ 32
- 9
modelscope/models/multi_modal/mmr/models/clip_for_mm_video_embedding.py View File

@@ -1,9 +1,13 @@
# The implementation is adopted from the CLIP4Clip implementation,
# made publicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip

import os
import random
import uuid
from os.path import exists
from tempfile import TemporaryDirectory
from typing import Any, Dict
from urllib.parse import urlparse

import json
import numpy as np
@@ -11,6 +15,7 @@ import torch
from decord import VideoReader, cpu
from PIL import Image

from modelscope.hub.file_download import http_get_file
from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
@@ -68,12 +73,16 @@ class VideoCLIPForMultiModalEmbedding(TorchModel):
self.model.to(self.device)

def _get_text(self, caption, tokenizer, enable_zh=False):
if len(caption) == 3:
_caption_text, s, e = caption
elif len(caption) == 4:
_caption_text, s, e, pos = caption
else:
NotImplementedError

if type(caption) is str:
_caption_text, s, e = caption, None, None
elif type(caption) is tuple:
if len(caption) == 3:
_caption_text, s, e = caption
elif len(caption) == 4:
_caption_text, s, e, pos = caption
else:
NotImplementedError

if isinstance(_caption_text, list):
caption_text = random.choice(_caption_text)
@@ -137,11 +146,25 @@ class VideoCLIPForMultiModalEmbedding(TorchModel):
elif start_time == end_time:
end_time = end_time + 1

if exists(video_path):
url_parsed = urlparse(video_path)
if url_parsed.scheme in ('file', '') and exists(
url_parsed.path): # Possibly a local file
vreader = VideoReader(video_path, ctx=cpu(0))
else:
logger.error('non video input, output is wrong!!!')
return video, video_mask
try:
with TemporaryDirectory() as temporary_cache_dir:
random_str = uuid.uuid4().hex
http_get_file(
url=video_path,
local_dir=temporary_cache_dir,
file_name=random_str,
cookies=None)
temp_file_path = os.path.join(temporary_cache_dir,
random_str)
vreader = VideoReader(temp_file_path, ctx=cpu(0))
except Exception as ex:
logger.error('non video input, output is {}!!!'.format(ex))
return video, video_mask

fps = vreader.get_avg_fps()
f_start = 0 if start_time is None else int(start_time * fps)


+ 2
- 2
tests/pipelines/test_video_multi_modal_embedding.py View File

@@ -17,8 +17,8 @@ class VideoMultiModalEmbeddingTest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.video_multi_modal_embedding
self.model_id = 'damo/multi_modal_clip_vtretrival_msrvtt_53'

video_path = 'data/test/videos/multi_modal_test_video_9770.mp4'
caption = ('a person is connecting something to system', None, None)
video_path = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/multi_modal_test_video_9770.mp4'
caption = 'a person is connecting something to system'
_input = {'video': video_path, 'text': caption}

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')


Loading…
Cancel
Save