cv/language_guided_video_summarization增加finetune
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10790262
master^2
| @@ -11,6 +11,7 @@ if TYPE_CHECKING: | |||
| from .image_instance_segmentation_coco_dataset import ImageInstanceSegmentationCocoDataset | |||
| from .movie_scene_segmentation import MovieSceneSegmentationDataset | |||
| from .video_summarization_dataset import VideoSummarizationDataset | |||
| from .language_guided_video_summarization_dataset import LanguageGuidedVideoSummarizationDataset | |||
| from .image_inpainting import ImageInpaintingDataset | |||
| from .text_ranking_dataset import TextRankingDataset | |||
| from .referring_video_object_segmentation import ReferringVideoObjectSegmentationDataset | |||
| @@ -25,6 +26,8 @@ else: | |||
| 'image_instance_segmentation_coco_dataset': | |||
| ['ImageInstanceSegmentationCocoDataset'], | |||
| 'video_summarization_dataset': ['VideoSummarizationDataset'], | |||
| 'language_guided_video_summarization_dataset': | |||
| ['LanguageGuidedVideoSummarizationDataset'], | |||
| 'movie_scene_segmentation': ['MovieSceneSegmentationDataset'], | |||
| 'image_inpainting': ['ImageInpaintingDataset'], | |||
| 'sidd_image_denoising_dataset': ['SiddImageDenoisingDataset'], | |||
| @@ -0,0 +1,90 @@ | |||
| # Part of the implementation is borrowed and modified from PGL-SUM, | |||
| # publicly available at https://github.com/e-apostolidis/PGL-SUM, follow the | |||
| # license https://github.com/e-apostolidis/PGL-SUM/blob/master/LICENSE.md. | |||
| import os | |||
| import h5py | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||
| from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||
| TorchTaskDataset | |||
| from modelscope.utils.constant import Tasks | |||
| @TASK_DATASETS.register_module( | |||
| Tasks.language_guided_video_summarization, | |||
| module_name=Models.language_guided_video_summarization) | |||
| class LanguageGuidedVideoSummarizationDataset(TorchTaskDataset): | |||
| def __init__(self, mode, opt, root_dir): | |||
| self.mode = mode | |||
| self.data_filename = os.path.join(root_dir, opt.dataset_file) | |||
| self.split_filename = os.path.join(root_dir, opt.split_file) | |||
| self.split_index = opt.split_index | |||
| hdf = h5py.File(self.data_filename, 'r') | |||
| self.list_image_features = [] | |||
| self.list_text_features = [] | |||
| self.list_gtscores = [] | |||
| self.list_user_summary = [] | |||
| self.list_change_points = [] | |||
| self.list_n_frames = [] | |||
| self.list_positions = [] | |||
| with open(self.split_filename) as f: | |||
| data = json.loads(f.read()) | |||
| for i, split in enumerate(data): | |||
| if i == self.split_index: | |||
| self.split = split | |||
| break | |||
| for video_name in self.split[self.mode + '_keys']: | |||
| clip_image_features = torch.Tensor( | |||
| np.array(hdf[video_name + '/features_clip_image'])) | |||
| clip_txt_features = torch.Tensor( | |||
| np.array(hdf[video_name + '/features_clip_txt'])).reshape( | |||
| 1, -1) | |||
| clip_txt_features = clip_txt_features.repeat( | |||
| clip_image_features.size(0), 1) | |||
| gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore'])) | |||
| user_summary = np.array(hdf[f'{video_name}/user_summary']) | |||
| change_points = np.array(hdf[f'{video_name}/change_points']) | |||
| n_frames = np.array(hdf[f'{video_name}/n_frames']) | |||
| positions = np.array(hdf[f'{video_name}/picks']) | |||
| self.list_image_features.append(clip_image_features) | |||
| self.list_text_features.append(clip_txt_features) | |||
| self.list_gtscores.append(gtscore) | |||
| self.list_user_summary.append(user_summary) | |||
| self.list_change_points.append(change_points) | |||
| self.list_n_frames.append(n_frames) | |||
| self.list_positions.append(positions) | |||
| hdf.close() | |||
| def __len__(self): | |||
| self.len = len(self.split[self.mode + '_keys']) | |||
| return self.len | |||
| def __getitem__(self, index): | |||
| clip_image_features = self.list_image_features[index] | |||
| clip_txt_features = self.list_text_features[index] | |||
| gtscore = self.list_gtscores[index] | |||
| user_summary = self.list_user_summary[index] | |||
| change_points = self.list_change_points[index] | |||
| n_frames = self.list_n_frames[index] | |||
| positions = self.list_positions[index] | |||
| return dict( | |||
| frame_features=clip_image_features, | |||
| txt_features=clip_txt_features, | |||
| gtscore=gtscore, | |||
| user_summary=user_summary, | |||
| change_points=change_points, | |||
| n_frames=n_frames, | |||
| positions=positions) | |||
| @@ -0,0 +1,76 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.cv.language_guided_video_summarization import \ | |||
| ClipItVideoSummarization | |||
| from modelscope.msdatasets.task_datasets import \ | |||
| LanguageGuidedVideoSummarizationDataset | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import test_level | |||
| logger = get_logger() | |||
| class LanguageGuidedVideoSummarizationTrainerTest(unittest.TestCase): | |||
| def setUp(self): | |||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| self.model_id = 'damo/cv_clip-it_video-summarization_language-guided_en' | |||
| self.cache_path = snapshot_download(self.model_id) | |||
| self.config = Config.from_file( | |||
| os.path.join(self.cache_path, ModelFile.CONFIGURATION)) | |||
| self.dataset_train = LanguageGuidedVideoSummarizationDataset( | |||
| 'train', self.config.dataset, self.cache_path) | |||
| self.dataset_val = LanguageGuidedVideoSummarizationDataset( | |||
| 'test', self.config.dataset, self.cache_path) | |||
| def tearDown(self): | |||
| shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| kwargs = dict( | |||
| model=self.model_id, | |||
| train_dataset=self.dataset_train, | |||
| eval_dataset=self.dataset_val, | |||
| max_epochs=2, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(2): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_model_and_args(self): | |||
| model = ClipItVideoSummarization.from_pretrained(self.cache_path) | |||
| kwargs = dict( | |||
| cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION), | |||
| model=model, | |||
| train_dataset=self.dataset_train, | |||
| eval_dataset=self.dataset_val, | |||
| max_epochs=2, | |||
| work_dir=self.tmp_dir) | |||
| trainer = build_trainer(default_args=kwargs) | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(2): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||