huizheng.hz yingda.chen 2 years ago
parent
commit
a1738690c9
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      tests/trainers/test_image_denoise_trainer.py

+ 5
- 5
tests/trainers/test_image_denoise_trainer.py View File

@@ -34,14 +34,14 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
'SIDD', 'SIDD',
namespace='huizheng', namespace='huizheng',
subset_name='default', subset_name='default',
split='validation',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
split='test',
download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds
dataset_val = MsDataset.load( dataset_val = MsDataset.load(
'SIDD', 'SIDD',
namespace='huizheng', namespace='huizheng',
subset_name='default', subset_name='default',
split='test', split='test',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds
self.dataset_train = SiddImageDenoisingDataset( self.dataset_train = SiddImageDenoisingDataset(
dataset_train, self.config.dataset, is_train=True) dataset_train, self.config.dataset, is_train=True)
self.dataset_val = SiddImageDenoisingDataset( self.dataset_val = SiddImageDenoisingDataset(
@@ -51,7 +51,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
shutil.rmtree(self.tmp_dir, ignore_errors=True) shutil.rmtree(self.tmp_dir, ignore_errors=True)
super().tearDown() super().tearDown()


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer(self): def test_trainer(self):
kwargs = dict( kwargs = dict(
model=self.model_id, model=self.model_id,
@@ -65,7 +65,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
for i in range(2): for i in range(2):
self.assertIn(f'epoch_{i+1}.pth', results_files) self.assertIn(f'epoch_{i+1}.pth', results_files)


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer_with_model_and_args(self): def test_trainer_with_model_and_args(self):
model = NAFNetForImageDenoise.from_pretrained(self.cache_path) model = NAFNetForImageDenoise.from_pretrained(self.cache_path)
kwargs = dict( kwargs = dict(


Loading…
Cancel
Save