You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_movie_scene_segmentation_trainer.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. import zipfile
  7. from modelscope.hub.snapshot_download import snapshot_download
  8. from modelscope.metainfo import Trainers
  9. from modelscope.models.cv.movie_scene_segmentation import \
  10. MovieSceneSegmentationModel
  11. from modelscope.msdatasets import MsDataset
  12. from modelscope.trainers import build_trainer
  13. from modelscope.utils.config import Config, ConfigDict
  14. from modelscope.utils.constant import ModelFile
  15. from modelscope.utils.test_utils import test_level
  16. class TestImageInstanceSegmentationTrainer(unittest.TestCase):
  17. model_id = 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'
  18. def setUp(self):
  19. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  20. cache_path = snapshot_download(self.model_id)
  21. config_path = os.path.join(cache_path, ModelFile.CONFIGURATION)
  22. cfg = Config.from_file(config_path)
  23. max_epochs = cfg.train.max_epochs
  24. train_data_cfg = ConfigDict(
  25. name='movie_scene_seg_toydata',
  26. split='train',
  27. cfg=cfg.preprocessor,
  28. test_mode=False)
  29. test_data_cfg = ConfigDict(
  30. name='movie_scene_seg_toydata',
  31. split='test',
  32. cfg=cfg.preprocessor,
  33. test_mode=True)
  34. self.train_dataset = MsDataset.load(
  35. dataset_name=train_data_cfg.name,
  36. split=train_data_cfg.split,
  37. namespace=train_data_cfg.namespace,
  38. cfg=train_data_cfg.cfg,
  39. test_mode=train_data_cfg.test_mode)
  40. assert next(
  41. iter(self.train_dataset.config_kwargs['split_config'].values()))
  42. self.test_dataset = MsDataset.load(
  43. dataset_name=test_data_cfg.name,
  44. split=test_data_cfg.split,
  45. namespace=test_data_cfg.namespace,
  46. cfg=test_data_cfg.cfg,
  47. test_mode=test_data_cfg.test_mode)
  48. assert next(
  49. iter(self.test_dataset.config_kwargs['split_config'].values()))
  50. self.max_epochs = max_epochs
  51. self.tmp_dir = tempfile.TemporaryDirectory().name
  52. if not os.path.exists(self.tmp_dir):
  53. os.makedirs(self.tmp_dir)
  54. def tearDown(self):
  55. shutil.rmtree(self.tmp_dir)
  56. super().tearDown()
  57. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  58. def test_trainer(self):
  59. kwargs = dict(
  60. model=self.model_id,
  61. train_dataset=self.train_dataset,
  62. eval_dataset=self.test_dataset,
  63. work_dir=self.tmp_dir)
  64. trainer = build_trainer(
  65. name=Trainers.movie_scene_segmentation, default_args=kwargs)
  66. trainer.train()
  67. results_files = os.listdir(trainer.work_dir)
  68. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  69. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  70. def test_trainer_with_model_and_args(self):
  71. tmp_dir = tempfile.TemporaryDirectory().name
  72. if not os.path.exists(tmp_dir):
  73. os.makedirs(tmp_dir)
  74. cache_path = snapshot_download(self.model_id)
  75. model = MovieSceneSegmentationModel.from_pretrained(cache_path)
  76. kwargs = dict(
  77. cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
  78. model=model,
  79. train_dataset=self.train_dataset,
  80. eval_dataset=self.test_dataset,
  81. work_dir=tmp_dir)
  82. trainer = build_trainer(
  83. name=Trainers.movie_scene_segmentation, default_args=kwargs)
  84. trainer.train()
  85. results_files = os.listdir(trainer.work_dir)
  86. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  87. if __name__ == '__main__':
  88. unittest.main()