|
|
@@ -34,14 +34,14 @@ class ImageDenoiseTrainerTest(unittest.TestCase): |
|
|
|
'SIDD', |
|
|
|
namespace='huizheng', |
|
|
|
subset_name='default', |
|
|
|
split='validation', |
|
|
|
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds |
|
|
|
split='test', |
|
|
|
download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds |
|
|
|
dataset_val = MsDataset.load( |
|
|
|
'SIDD', |
|
|
|
namespace='huizheng', |
|
|
|
subset_name='default', |
|
|
|
split='test', |
|
|
|
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds |
|
|
|
download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds |
|
|
|
self.dataset_train = SiddImageDenoisingDataset( |
|
|
|
dataset_train, self.config.dataset, is_train=True) |
|
|
|
self.dataset_val = SiddImageDenoisingDataset( |
|
|
@@ -51,7 +51,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): |
|
|
|
shutil.rmtree(self.tmp_dir, ignore_errors=True) |
|
|
|
super().tearDown() |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
def test_trainer(self): |
|
|
|
kwargs = dict( |
|
|
|
model=self.model_id, |
|
|
@@ -65,7 +65,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase): |
|
|
|
for i in range(2): |
|
|
|
self.assertIn(f'epoch_{i+1}.pth', results_files) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
def test_trainer_with_model_and_args(self): |
|
|
|
model = NAFNetForImageDenoise.from_pretrained(self.cache_path) |
|
|
|
kwargs = dict( |
|
|
|