|
|
@@ -58,7 +58,8 @@ class ImagePortraitEnhancementPipeline(Pipeline): |
|
|
|
|
|
|
|
gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' |
|
|
|
self.face_enhancer.load_state_dict( |
|
|
|
torch.load(gpen_model_path), strict=True) |
|
|
|
torch.load(gpen_model_path, map_location=torch.device('cpu')), |
|
|
|
strict=True) |
|
|
|
|
|
|
|
logger.info('load face enhancer model done') |
|
|
|
|
|
|
@@ -82,7 +83,9 @@ class ImagePortraitEnhancementPipeline(Pipeline): |
|
|
|
|
|
|
|
sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth' |
|
|
|
self.sr_model.load_state_dict( |
|
|
|
torch.load(sr_model_path)['params_ema'], strict=True) |
|
|
|
torch.load(sr_model_path, |
|
|
|
map_location=torch.device('cpu'))['params_ema'], |
|
|
|
strict=True) |
|
|
|
|
|
|
|
logger.info('load sr model done') |
|
|
|
|
|
|
|