baiguan.yt yingda.chen 3 years ago
parent
commit
60b1539e5d
3 changed files with 9 additions and 5 deletions
  1. +2
    -2
      modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py
  2. +2
    -1
      modelscope/pipelines/cv/image_colorization_pipeline.py
  3. +5
    -2
      modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py

+ 2
- 2
modelscope/models/cv/image_portrait_enhancement/eqface/fqa.py View File

@@ -45,8 +45,8 @@ class FQA(object):
model.load_state_dict(model_dict)

def get_face_quality(self, img):
img = torch.from_numpy(img).permute(2, 0,
1).unsqueeze(0).flip(1).cuda()
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).flip(1).to(
self.device)
img = (img - 127.5) / 128.0

# extract features & predict quality


+ 2
- 1
modelscope/pipelines/cv/image_colorization_pipeline.py View File

@@ -36,7 +36,6 @@ class ImageColorizationPipeline(Pipeline):
self.device = torch.device('cuda')
else:
self.device = torch.device('cpu')
self.size = 1024

self.orig_img = None
self.model_type = 'stable'
@@ -91,6 +90,8 @@ class ImageColorizationPipeline(Pipeline):
img = LoadImage.convert_to_img(input).convert('LA').convert('RGB')

self.wide, self.height = img.size
if self.wide * self.height < 100000:
self.size = 256
self.orig_img = img.copy()
img = img.resize((self.size, self.size), resample=PIL.Image.BILINEAR)



+ 5
- 2
modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py View File

@@ -58,7 +58,8 @@ class ImagePortraitEnhancementPipeline(Pipeline):

gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}'
self.face_enhancer.load_state_dict(
torch.load(gpen_model_path), strict=True)
torch.load(gpen_model_path, map_location=torch.device('cpu')),
strict=True)

logger.info('load face enhancer model done')

@@ -82,7 +83,9 @@ class ImagePortraitEnhancementPipeline(Pipeline):

sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth'
self.sr_model.load_state_dict(
torch.load(sr_model_path)['params_ema'], strict=True)
torch.load(sr_model_path,
map_location=torch.device('cpu'))['params_ema'],
strict=True)

logger.info('load sr model done')



Loading…
Cancel
Save