You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_language_guided_video_summarization_trainer.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.models.cv.language_guided_video_summarization import \
  8. ClipItVideoSummarization
  9. from modelscope.msdatasets.task_datasets import \
  10. LanguageGuidedVideoSummarizationDataset
  11. from modelscope.trainers import build_trainer
  12. from modelscope.utils.config import Config
  13. from modelscope.utils.constant import ModelFile
  14. from modelscope.utils.logger import get_logger
  15. from modelscope.utils.test_utils import test_level
  16. logger = get_logger()
  17. class LanguageGuidedVideoSummarizationTrainerTest(unittest.TestCase):
  18. def setUp(self):
  19. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  20. self.tmp_dir = tempfile.TemporaryDirectory().name
  21. if not os.path.exists(self.tmp_dir):
  22. os.makedirs(self.tmp_dir)
  23. self.model_id = 'damo/cv_clip-it_video-summarization_language-guided_en'
  24. self.cache_path = snapshot_download(self.model_id)
  25. self.config = Config.from_file(
  26. os.path.join(self.cache_path, ModelFile.CONFIGURATION))
  27. self.dataset_train = LanguageGuidedVideoSummarizationDataset(
  28. 'train', self.config.dataset, self.cache_path)
  29. self.dataset_val = LanguageGuidedVideoSummarizationDataset(
  30. 'test', self.config.dataset, self.cache_path)
  31. def tearDown(self):
  32. shutil.rmtree(self.tmp_dir, ignore_errors=True)
  33. super().tearDown()
  34. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  35. def test_trainer(self):
  36. kwargs = dict(
  37. model=self.model_id,
  38. train_dataset=self.dataset_train,
  39. eval_dataset=self.dataset_val,
  40. max_epochs=2,
  41. work_dir=self.tmp_dir)
  42. trainer = build_trainer(default_args=kwargs)
  43. trainer.train()
  44. results_files = os.listdir(self.tmp_dir)
  45. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  46. for i in range(2):
  47. self.assertIn(f'epoch_{i+1}.pth', results_files)
  48. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  49. def test_trainer_with_model_and_args(self):
  50. model = ClipItVideoSummarization.from_pretrained(self.cache_path)
  51. kwargs = dict(
  52. cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
  53. model=model,
  54. train_dataset=self.dataset_train,
  55. eval_dataset=self.dataset_val,
  56. max_epochs=2,
  57. work_dir=self.tmp_dir)
  58. trainer = build_trainer(default_args=kwargs)
  59. trainer.train()
  60. results_files = os.listdir(self.tmp_dir)
  61. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  62. for i in range(2):
  63. self.assertIn(f'epoch_{i+1}.pth', results_files)
  64. if __name__ == '__main__':
  65. unittest.main()