You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_portrait_enhancement_trainer.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. import shutil
  5. import tempfile
  6. import unittest
  7. from typing import Callable, List, Optional, Tuple, Union
  8. import cv2
  9. import torch
  10. from torch.utils import data as data
  11. from modelscope.hub.snapshot_download import snapshot_download
  12. from modelscope.metainfo import Trainers
  13. from modelscope.models.cv.image_portrait_enhancement import \
  14. ImagePortraitEnhancement
  15. from modelscope.msdatasets import MsDataset
  16. from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \
  17. ImagePortraitEnhancementDataset
  18. from modelscope.trainers import build_trainer
  19. from modelscope.utils.constant import DownloadMode, ModelFile
  20. from modelscope.utils.test_utils import test_level
  21. class TestImagePortraitEnhancementTrainer(unittest.TestCase):
  22. def setUp(self):
  23. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  24. self.tmp_dir = tempfile.TemporaryDirectory().name
  25. if not os.path.exists(self.tmp_dir):
  26. os.makedirs(self.tmp_dir)
  27. self.model_id = 'damo/cv_gpen_image-portrait-enhancement'
  28. dataset_train = MsDataset.load(
  29. 'image-portrait-enhancement-dataset',
  30. namespace='modelscope',
  31. subset_name='default',
  32. split='test',
  33. download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds
  34. dataset_val = MsDataset.load(
  35. 'image-portrait-enhancement-dataset',
  36. namespace='modelscope',
  37. subset_name='default',
  38. split='test',
  39. download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds
  40. self.dataset_train = ImagePortraitEnhancementDataset(
  41. dataset_train, is_train=True)
  42. self.dataset_val = ImagePortraitEnhancementDataset(
  43. dataset_val, is_train=False)
  44. def tearDown(self):
  45. shutil.rmtree(self.tmp_dir, ignore_errors=True)
  46. super().tearDown()
  47. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  48. def test_trainer(self):
  49. kwargs = dict(
  50. model=self.model_id,
  51. train_dataset=self.dataset_train,
  52. eval_dataset=self.dataset_val,
  53. device='gpu',
  54. work_dir=self.tmp_dir)
  55. trainer = build_trainer(
  56. name=Trainers.image_portrait_enhancement, default_args=kwargs)
  57. trainer.train()
  58. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  59. def test_trainer_with_model_and_args(self):
  60. tmp_dir = tempfile.TemporaryDirectory().name
  61. if not os.path.exists(tmp_dir):
  62. os.makedirs(tmp_dir)
  63. cache_path = snapshot_download(self.model_id)
  64. model = ImagePortraitEnhancement.from_pretrained(cache_path)
  65. kwargs = dict(
  66. cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
  67. model=model,
  68. train_dataset=self.dataset_train,
  69. eval_dataset=self.dataset_val,
  70. device='gpu',
  71. max_epochs=2,
  72. work_dir=self.tmp_dir)
  73. trainer = build_trainer(
  74. name=Trainers.image_portrait_enhancement, default_args=kwargs)
  75. trainer.train()
  76. if __name__ == '__main__':
  77. unittest.main()