|
|
@@ -19,7 +19,7 @@ logger = get_logger() |
|
|
|
|
|
|
|
|
|
|
|
@PIPELINES.register_module( |
|
|
|
Tasks.image_restoration, module_name=Pipelines.image_super_resolution) |
|
|
|
Tasks.image_super_resolution, module_name=Pipelines.image_super_resolution) |
|
|
|
class ImageSuperResolutionPipeline(Pipeline): |
|
|
|
|
|
|
|
def __init__(self, model: str): |
|
|
@@ -29,6 +29,7 @@ class ImageSuperResolutionPipeline(Pipeline): |
|
|
|
model: model id on modelscope hub. |
|
|
|
""" |
|
|
|
super().__init__(model=model) |
|
|
|
self.device = 'cpu' |
|
|
|
self.num_feat = 64 |
|
|
|
self.num_block = 23 |
|
|
|
self.scale = 4 |
|
|
@@ -38,7 +39,7 @@ class ImageSuperResolutionPipeline(Pipeline): |
|
|
|
num_feat=self.num_feat, |
|
|
|
num_block=self.num_block, |
|
|
|
num_grow_ch=32, |
|
|
|
scale=self.scale) |
|
|
|
scale=self.scale).to(self.device) |
|
|
|
|
|
|
|
model_path = f'{self.model}/{ModelFile.TORCH_MODEL_FILE}' |
|
|
|
self.sr_model.load_state_dict(torch.load(model_path), strict=True) |
|
|
@@ -58,7 +59,8 @@ class ImageSuperResolutionPipeline(Pipeline): |
|
|
|
raise TypeError(f'input should be either str, PIL.Image,' |
|
|
|
f' np.array, but got {type(input)}') |
|
|
|
|
|
|
|
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) / 255. |
|
|
|
img = torch.from_numpy(img).to(self.device).permute( |
|
|
|
2, 0, 1).unsqueeze(0) / 255. |
|
|
|
result = {'img': img} |
|
|
|
|
|
|
|
return result |
|
|
|