You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_inpainting_trainer.py 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.metainfo import Trainers
  8. from modelscope.models.cv.image_inpainting import FFTInpainting
  9. from modelscope.msdatasets import MsDataset
  10. from modelscope.trainers import build_trainer
  11. from modelscope.utils.config import Config, ConfigDict
  12. from modelscope.utils.constant import ModelFile
  13. from modelscope.utils.logger import get_logger
  14. from modelscope.utils.test_utils import test_level
  15. logger = get_logger()
  16. class ImageInpaintingTrainerTest(unittest.TestCase):
  17. def setUp(self):
  18. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  19. self.tmp_dir = tempfile.TemporaryDirectory().name
  20. if not os.path.exists(self.tmp_dir):
  21. os.makedirs(self.tmp_dir)
  22. self.model_id = 'damo/cv_fft_inpainting_lama'
  23. self.cache_path = snapshot_download(self.model_id)
  24. cfg = Config.from_file(
  25. os.path.join(self.cache_path, ModelFile.CONFIGURATION))
  26. train_data_cfg = ConfigDict(
  27. name='PlacesToydataset',
  28. split='train',
  29. mask_gen_kwargs=cfg.dataset.mask_gen_kwargs,
  30. out_size=cfg.dataset.train_out_size,
  31. test_mode=False)
  32. test_data_cfg = ConfigDict(
  33. name='PlacesToydataset',
  34. split='test',
  35. mask_gen_kwargs=cfg.dataset.mask_gen_kwargs,
  36. out_size=cfg.dataset.val_out_size,
  37. test_mode=True)
  38. self.train_dataset = MsDataset.load(
  39. dataset_name=train_data_cfg.name,
  40. split=train_data_cfg.split,
  41. mask_gen_kwargs=train_data_cfg.mask_gen_kwargs,
  42. out_size=train_data_cfg.out_size,
  43. test_mode=train_data_cfg.test_mode)
  44. assert next(
  45. iter(self.train_dataset.config_kwargs['split_config'].values()))
  46. self.test_dataset = MsDataset.load(
  47. dataset_name=test_data_cfg.name,
  48. split=test_data_cfg.split,
  49. mask_gen_kwargs=test_data_cfg.mask_gen_kwargs,
  50. out_size=test_data_cfg.out_size,
  51. test_mode=test_data_cfg.test_mode)
  52. assert next(
  53. iter(self.test_dataset.config_kwargs['split_config'].values()))
  54. def tearDown(self):
  55. shutil.rmtree(self.tmp_dir, ignore_errors=True)
  56. super().tearDown()
  57. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  58. def test_trainer(self):
  59. kwargs = dict(
  60. model=self.model_id,
  61. train_dataset=self.train_dataset,
  62. eval_dataset=self.test_dataset)
  63. trainer = build_trainer(
  64. name=Trainers.image_inpainting, default_args=kwargs)
  65. trainer.train()
  66. results_files = os.listdir(trainer.work_dir)
  67. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  68. if __name__ == '__main__':
  69. unittest.main()