You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_denoise_trainer.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.models.cv.image_denoise import NAFNetForImageDenoise
  8. from modelscope.msdatasets import MsDataset
  9. from modelscope.msdatasets.task_datasets.sidd_image_denoising import \
  10. SiddImageDenoisingDataset
  11. from modelscope.trainers import build_trainer
  12. from modelscope.utils.config import Config
  13. from modelscope.utils.constant import DownloadMode, ModelFile
  14. from modelscope.utils.logger import get_logger
  15. from modelscope.utils.test_utils import test_level
  16. logger = get_logger()
  17. class ImageDenoiseTrainerTest(unittest.TestCase):
  18. def setUp(self):
  19. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  20. self.tmp_dir = tempfile.TemporaryDirectory().name
  21. if not os.path.exists(self.tmp_dir):
  22. os.makedirs(self.tmp_dir)
  23. self.model_id = 'damo/cv_nafnet_image-denoise_sidd'
  24. self.cache_path = snapshot_download(self.model_id)
  25. self.config = Config.from_file(
  26. os.path.join(self.cache_path, ModelFile.CONFIGURATION))
  27. dataset_train = MsDataset.load(
  28. 'SIDD',
  29. namespace='huizheng',
  30. subset_name='default',
  31. split='test',
  32. download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds
  33. dataset_val = MsDataset.load(
  34. 'SIDD',
  35. namespace='huizheng',
  36. subset_name='default',
  37. split='test',
  38. download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds
  39. self.dataset_train = SiddImageDenoisingDataset(
  40. dataset_train, self.config.dataset, is_train=True)
  41. self.dataset_val = SiddImageDenoisingDataset(
  42. dataset_val, self.config.dataset, is_train=False)
  43. def tearDown(self):
  44. shutil.rmtree(self.tmp_dir, ignore_errors=True)
  45. super().tearDown()
  46. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  47. def test_trainer(self):
  48. kwargs = dict(
  49. model=self.model_id,
  50. train_dataset=self.dataset_train,
  51. eval_dataset=self.dataset_val,
  52. work_dir=self.tmp_dir)
  53. trainer = build_trainer(default_args=kwargs)
  54. trainer.train()
  55. results_files = os.listdir(self.tmp_dir)
  56. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  57. for i in range(1):
  58. self.assertIn(f'epoch_{i+1}.pth', results_files)
  59. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  60. def test_trainer_with_model_and_args(self):
  61. model = NAFNetForImageDenoise.from_pretrained(self.cache_path)
  62. kwargs = dict(
  63. cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
  64. model=model,
  65. train_dataset=self.dataset_train,
  66. eval_dataset=self.dataset_val,
  67. max_epochs=1,
  68. work_dir=self.tmp_dir)
  69. trainer = build_trainer(default_args=kwargs)
  70. trainer.train()
  71. results_files = os.listdir(self.tmp_dir)
  72. self.assertIn(f'{trainer.timestamp}.log.json', results_files)
  73. for i in range(1):
  74. self.assertIn(f'epoch_{i+1}.pth', results_files)
  75. if __name__ == '__main__':
  76. unittest.main()