You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_video_summarization_trainer.py 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.models.cv.video_summarization import PGLVideoSummarization
  8. from modelscope.msdatasets.task_datasets import VideoSummarizationDataset
  9. from modelscope.trainers import build_trainer
  10. from modelscope.utils.config import Config
  11. from modelscope.utils.constant import ModelFile
  12. from modelscope.utils.logger import get_logger
  13. from modelscope.utils.test_utils import test_level
  logger = get_logger()


  class VideoSummarizationTrainerTest(unittest.TestCase):
      """End-to-end training tests for the PGL video-summarization model.

      ``setUp`` fetches the ``damo/cv_googlenet_pgl-video-summarization``
      snapshot with ``snapshot_download``, loads its configuration file, and
      builds the ``'train'`` and ``'test'`` ``VideoSummarizationDataset``
      splits from it. Each test then trains through ``build_trainer`` into a
      fresh temporary work directory and checks that the JSON log and the
      per-epoch checkpoint files were written there.
      """

      # Number of ``epoch_<n>.pth`` checkpoints each test expects to find; also
      # passed as ``max_epochs`` where the test controls the epoch count.
      _EXPECTED_EPOCHS = 2

      def setUp(self):
          print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
          # Use ``mkdtemp`` rather than ``TemporaryDirectory().name``: discarding
          # the ``TemporaryDirectory`` object lets its finalizer delete the
          # directory (with a ResourceWarning) while the test still uses the path.
          self.tmp_dir = tempfile.mkdtemp()
          # Registered as a cleanup (not done in tearDown) so the directory is
          # removed even if the remainder of setUp raises -- unittest skips
          # tearDown in that case but still runs cleanups.
          self.addCleanup(shutil.rmtree, self.tmp_dir, ignore_errors=True)

          self.model_id = 'damo/cv_googlenet_pgl-video-summarization'
          self.cache_path = snapshot_download(self.model_id)
          self.config = Config.from_file(
              os.path.join(self.cache_path, ModelFile.CONFIGURATION))
          self.dataset_train = VideoSummarizationDataset(
              'train', self.config.dataset, self.cache_path)
          self.dataset_val = VideoSummarizationDataset(
              'test', self.config.dataset, self.cache_path)

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_trainer(self):
          """Train from the model id alone, using the configuration's defaults."""
          kwargs = dict(
              model=self.model_id,
              train_dataset=self.dataset_train,
              eval_dataset=self.dataset_val,
              work_dir=self.tmp_dir)
          trainer = build_trainer(default_args=kwargs)
          trainer.train()
          # NOTE(review): no ``max_epochs`` is passed here, so this assumes the
          # downloaded configuration trains for at least ``_EXPECTED_EPOCHS``
          # epochs -- confirm against the model's configuration file.
          self._assert_training_outputs(trainer, self._EXPECTED_EPOCHS)

      @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
      def test_trainer_with_model_and_args(self):
          """Train a pre-loaded model with an explicit config file and epoch count."""
          model = PGLVideoSummarization.from_pretrained(self.cache_path)
          kwargs = dict(
              cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
              model=model,
              train_dataset=self.dataset_train,
              eval_dataset=self.dataset_val,
              max_epochs=self._EXPECTED_EPOCHS,
              work_dir=self.tmp_dir)
          trainer = build_trainer(default_args=kwargs)
          trainer.train()
          self._assert_training_outputs(trainer, self._EXPECTED_EPOCHS)

      def _assert_training_outputs(self, trainer, num_epochs):
          """Assert the work dir holds the trainer's JSON log and every epoch checkpoint.

          Args:
              trainer: the trainer that just ran; its ``timestamp`` attribute
                  names the ``<timestamp>.log.json`` file.
              num_epochs: checkpoints ``epoch_1.pth`` .. ``epoch_<num_epochs>.pth``
                  must all be present in ``self.tmp_dir``.
          """
          results_files = os.listdir(self.tmp_dir)
          self.assertIn(f'{trainer.timestamp}.log.json', results_files)
          for epoch in range(1, num_epochs + 1):
              self.assertIn(f'epoch_{epoch}.pth', results_files)
  # Run this module's test cases when the file is executed as a script.
  if __name__ == '__main__':
      unittest.main()