|
|
@@ -1,9 +1,13 @@ |
|
|
|
# The implementation is adopted from the CLIP4Clip implementation, |
|
|
|
# made publicly available under Apache License, Version 2.0 at https://github.com/ArrowLuo/CLIP4Clip |
|
|
|
|
|
|
|
import os |
|
|
|
import random |
|
|
|
import uuid |
|
|
|
from os.path import exists |
|
|
|
from tempfile import TemporaryDirectory |
|
|
|
from typing import Any, Dict |
|
|
|
from urllib.parse import urlparse |
|
|
|
|
|
|
|
import json |
|
|
|
import numpy as np |
|
|
@@ -11,6 +15,7 @@ import torch |
|
|
|
from decord import VideoReader, cpu |
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
from modelscope.hub.file_download import http_get_file |
|
|
|
from modelscope.metainfo import Models |
|
|
|
from modelscope.models import TorchModel |
|
|
|
from modelscope.models.builder import MODELS |
|
|
@@ -68,12 +73,16 @@ class VideoCLIPForMultiModalEmbedding(TorchModel): |
|
|
|
self.model.to(self.device) |
|
|
|
|
|
|
|
def _get_text(self, caption, tokenizer, enable_zh=False): |
|
|
|
if len(caption) == 3: |
|
|
|
_caption_text, s, e = caption |
|
|
|
elif len(caption) == 4: |
|
|
|
_caption_text, s, e, pos = caption |
|
|
|
else: |
|
|
|
NotImplementedError |
|
|
|
|
|
|
|
if type(caption) is str: |
|
|
|
_caption_text, s, e = caption, None, None |
|
|
|
elif type(caption) is tuple: |
|
|
|
if len(caption) == 3: |
|
|
|
_caption_text, s, e = caption |
|
|
|
elif len(caption) == 4: |
|
|
|
_caption_text, s, e, pos = caption |
|
|
|
else: |
|
|
|
NotImplementedError |
|
|
|
|
|
|
|
if isinstance(_caption_text, list): |
|
|
|
caption_text = random.choice(_caption_text) |
|
|
@@ -137,11 +146,25 @@ class VideoCLIPForMultiModalEmbedding(TorchModel): |
|
|
|
elif start_time == end_time: |
|
|
|
end_time = end_time + 1 |
|
|
|
|
|
|
|
if exists(video_path): |
|
|
|
url_parsed = urlparse(video_path) |
|
|
|
if url_parsed.scheme in ('file', '') and exists( |
|
|
|
url_parsed.path): # Possibly a local file |
|
|
|
vreader = VideoReader(video_path, ctx=cpu(0)) |
|
|
|
else: |
|
|
|
logger.error('non video input, output is wrong!!!') |
|
|
|
return video, video_mask |
|
|
|
try: |
|
|
|
with TemporaryDirectory() as temporary_cache_dir: |
|
|
|
random_str = uuid.uuid4().hex |
|
|
|
http_get_file( |
|
|
|
url=video_path, |
|
|
|
local_dir=temporary_cache_dir, |
|
|
|
file_name=random_str, |
|
|
|
cookies=None) |
|
|
|
temp_file_path = os.path.join(temporary_cache_dir, |
|
|
|
random_str) |
|
|
|
vreader = VideoReader(temp_file_path, ctx=cpu(0)) |
|
|
|
except Exception as ex: |
|
|
|
logger.error('non video input, output is {}!!!'.format(ex)) |
|
|
|
return video, video_mask |
|
|
|
|
|
|
|
fps = vreader.get_avg_fps() |
|
|
|
f_start = 0 if start_time is None else int(start_time * fps) |
|
|
|