cv/language_guided_video_summarization增加finetune Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10790262master^2
@@ -11,6 +11,7 @@ if TYPE_CHECKING: | |||||
from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | ||||
from .movie_scene_segmentation import MovieSceneSegmentationDataset | from .movie_scene_segmentation import MovieSceneSegmentationDataset | ||||
from .video_summarization_dataset import VideoSummarizationDataset | from .video_summarization_dataset import VideoSummarizationDataset | ||||
from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset | |||||
from .image_inpainting import ImageInpaintingDataset | from .image_inpainting import ImageInpaintingDataset | ||||
from .text_ranking_dataset import TextRankingDataset | from .text_ranking_dataset import TextRankingDataset | ||||
from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset | from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset | ||||
@@ -25,6 +26,8 @@ else: | |||||
'image_instance_segmentation_coco_dataset': | 'image_instance_segmentation_coco_dataset': | ||||
['ImageInstanceSegmentationCocoDataset'], | ['ImageInstanceSegmentationCocoDataset'], | ||||
'video_summarization_dataset': ['VideoSummarizationDataset'], | 'video_summarization_dataset': ['VideoSummarizationDataset'], | ||||
'language_guided_video_summarization_dataset': | |||||
['LanguageGuidedVideoSummarizationDataset'], | |||||
'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | ||||
'image_inpainting': ['ImageInpaintingDataset'], | 'image_inpainting': ['ImageInpaintingDataset'], | ||||
'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | ||||
@@ -0,0 +1,90 @@ | |||||
# Part of the implementation is borrowed and modified from PGL-SUM, | |||||
# publicly available at https://github.com/e-apostolidis/PGL-SUM, follow the | |||||
# license https://github.com/e-apostolidis/PGL-SUM/blob/master/LICENSE.md. | |||||
import os | |||||
import h5py | |||||
import json | |||||
import numpy as np | |||||
import torch | |||||
from modelscope.metainfo import Models | |||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||||
TorchTaskDataset | |||||
from modelscope.utils.constant import Tasks | |||||
@TASK_DATASETS.register_module(
    Tasks.language_guided_video_summarization,
    module_name=Models.language_guided_video_summarization)
class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset):
    """Dataset that preloads per-video features for language-guided video
    summarization from an HDF5 file.

    The videos to load are selected from a JSON split file: the entry at
    position ``opt.split_index`` is used, and its ``'<mode>_keys'`` item lists
    the video names to read from the HDF5 file. Everything is read eagerly in
    ``__init__`` and kept in memory.
    """

    def __init__(self, mode, opt, root_dir):
        """Load all samples of the requested split into memory.

        Args:
            mode: Prefix of the key looked up in the split entry
                (``mode + '_keys'``).
            opt: Config object providing ``dataset_file``, ``split_file`` and
                ``split_index``.
            root_dir: Directory that ``opt.dataset_file`` and
                ``opt.split_file`` are joined onto.

        Raises:
            ValueError: If the split file has no entry at ``opt.split_index``.
        """
        self.mode = mode
        self.data_filename = os.path.join(root_dir, opt.dataset_file)
        self.split_filename = os.path.join(root_dir, opt.split_file)
        self.split_index = opt.split_index

        self.list_image_features = []
        self.list_text_features = []
        self.list_gtscores = []
        self.list_user_summary = []
        self.list_change_points = []
        self.list_n_frames = []
        self.list_positions = []

        # The context manager guarantees the HDF5 handle is released even if
        # reading the split file or any dataset raises.
        with h5py.File(self.data_filename, 'r') as hdf:
            with open(self.split_filename) as f:
                data = json.loads(f.read())

            for i, split in enumerate(data):
                if i == self.split_index:
                    self.split = split
                    break
            else:
                # Fail with a clear message instead of a later AttributeError
                # on the never-assigned ``self.split``.
                raise ValueError(
                    f'split_index {self.split_index} not found in '
                    f'{self.split_filename}')

            for video_name in self.split[self.mode + '_keys']:
                clip_image_features = torch.Tensor(
                    np.array(hdf[video_name + '/features_clip_image']))
                # Flatten the text feature to a single row, then repeat it so
                # it has one row per row of the image features.
                clip_txt_features = torch.Tensor(
                    np.array(hdf[video_name + '/features_clip_txt'])).reshape(
                        1, -1)
                clip_txt_features = clip_txt_features.repeat(
                    clip_image_features.size(0), 1)

                gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore']))
                user_summary = np.array(hdf[f'{video_name}/user_summary'])
                change_points = np.array(hdf[f'{video_name}/change_points'])
                n_frames = np.array(hdf[f'{video_name}/n_frames'])
                positions = np.array(hdf[f'{video_name}/picks'])

                self.list_image_features.append(clip_image_features)
                self.list_text_features.append(clip_txt_features)
                self.list_gtscores.append(gtscore)
                self.list_user_summary.append(user_summary)
                self.list_change_points.append(change_points)
                self.list_n_frames.append(n_frames)
                self.list_positions.append(positions)

    def __len__(self):
        """Return the number of video names in the selected split."""
        # ``self.len`` is kept as an attribute because the original
        # implementation exposed it.
        self.len = len(self.split[self.mode + '_keys'])
        return self.len

    def __getitem__(self, index):
        """Return the preloaded sample at ``index`` as a dict.

        Keys: ``frame_features``, ``txt_features``, ``gtscore`` (torch
        tensors) and ``user_summary``, ``change_points``, ``n_frames``,
        ``positions`` (numpy arrays).
        """
        return dict(
            frame_features=self.list_image_features[index],
            txt_features=self.list_text_features[index],
            gtscore=self.list_gtscores[index],
            user_summary=self.list_user_summary[index],
            change_points=self.list_change_points[index],
            n_frames=self.list_n_frames[index],
            positions=self.list_positions[index])
@@ -0,0 +1,76 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import os | |||||
import shutil | |||||
import tempfile | |||||
import unittest | |||||
from modelscope.hub.snapshot_download import snapshot_download | |||||
from modelscope.models.cv.language_guided_video_summarization import \ | |||||
ClipItVideoSummarization | |||||
from modelscope.msdatasets.task_datasets import \ | |||||
LanguageGuidedVideoSummarizationDataset | |||||
from modelscope.trainers import build_trainer | |||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import ModelFile | |||||
from modelscope.utils.logger import get_logger | |||||
from modelscope.utils.test_utils import test_level | |||||
logger = get_logger() | |||||
class LanguageGuidedVideoSummarizationTrainerTest(unittest.TestCase):
    """Trainer tests for the language-guided video summarization model."""

    def setUp(self):
        """Create a work directory, download the model and build datasets."""
        super().setUp()
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        # mkdtemp() creates a directory that persists until tearDown removes
        # it. TemporaryDirectory().name pointed at a directory whose owning
        # object was discarded at once, so it had to be recreated by hand.
        self.tmp_dir = tempfile.mkdtemp()

        self.model_id = 'damo/cv_clip-it_video-summarization_language-guided_en'
        self.cache_path = snapshot_download(self.model_id)
        self.config = Config.from_file(
            os.path.join(self.cache_path, ModelFile.CONFIGURATION))
        self.dataset_train = LanguageGuidedVideoSummarizationDataset(
            'train', self.config.dataset, self.cache_path)
        self.dataset_val = LanguageGuidedVideoSummarizationDataset(
            'test', self.config.dataset, self.cache_path)

    def tearDown(self):
        """Remove the temporary work directory."""
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    def _assert_training_outputs(self, trainer, max_epochs):
        """Check that the work dir holds the JSON log and one checkpoint per epoch."""
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(max_epochs):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        """Train for two epochs with the trainer built from the model id."""
        kwargs = dict(
            model=self.model_id,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            max_epochs=2,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        self._assert_training_outputs(trainer, 2)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        """Train for two epochs with a preloaded model and an explicit config file."""
        model = ClipItVideoSummarization.from_pretrained(self.cache_path)
        kwargs = dict(
            cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            max_epochs=2,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        self._assert_training_outputs(trainer, 2)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()