Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9465050 (target branch: master)
@@ -88,9 +88,6 @@ def could_use_op(input):
         if input.device.type != 'cuda':
             return False
-        if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']):
-            return True
-
         warnings.warn(
             f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().'
         )
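With the version gate removed, every call now reaches the warning and returns False, i.e. conv2d_gradfix is effectively disabled in favor of the stock convolution. A minimal sketch of the resulting control flow, assuming the module-level `enabled` flag from the upstream conv2d_gradfix implementation:

import warnings

import torch

enabled = True  # module-level switch, as in the upstream conv2d_gradfix


def could_use_op(input):
    if enabled and torch.backends.cudnn.enabled:
        if input.device.type != 'cuda':
            return False
        # The 1.7/1.8 fast path was removed, so every CUDA input warns
        # and falls back to torch.nn.functional.conv2d().
        warnings.warn(
            f'conv2d_gradfix not supported on PyTorch {torch.__version__}. '
            'Falling back to torch.nn.functional.conv2d().')
    return False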
@@ -70,7 +70,8 @@ TASK_OUTPUTS = {
     Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
     Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
     Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
-    Tasks.image_restoration: [OutputKeys.OUTPUT_IMG],
+    Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG],
+    Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG],

     # action recognition result for single video
     # {
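Both new tasks expose a single `OutputKeys.OUTPUT_IMG` field. A hedged usage sketch (the input file name is hypothetical; the model id is taken from the DEFAULT_MODEL_FOR_PIPELINE hunk below):

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

sr = pipeline(
    Tasks.image_super_resolution, model='damo/cv_rrdb_image-super-resolution')
result = sr('low_res.png')  # hypothetical local test image
# Assumes the output image is a NumPy array in the layout cv2 expects.
cv2.imwrite('high_res.png', result[OutputKeys.OUTPUT_IMG])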
@@ -69,7 +69,11 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     (Pipelines.text_to_image_synthesis,
      'damo/cv_imagen_text-to-image-synthesis_tiny'),
     Tasks.style_transfer: (Pipelines.style_transfer,
-                           'damo/cv_aams_style-transfer_damo')
+                           'damo/cv_aams_style-transfer_damo'),
+    Tasks.face_image_generation: (Pipelines.face_image_generation,
+                                  'damo/cv_gan_face-image-generation'),
+    Tasks.image_super_resolution: (Pipelines.image_super_resolution,
+                                   'damo/cv_rrdb_image-super-resolution'),
 }
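Registering defaults here is what lets the new tests at the bottom of this review build both pipelines from the task name alone:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# With no model id, pipeline() resolves the defaults registered above.
face_generation = pipeline(Tasks.face_image_generation)
super_resolution = pipeline(Tasks.image_super_resolution)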
@@ -20,16 +20,17 @@ logger = get_logger()

 @PIPELINES.register_module(
-    Tasks.image_generation, module_name=Pipelines.face_image_generation)
+    Tasks.face_image_generation, module_name=Pipelines.face_image_generation)
 class FaceImageGenerationPipeline(Pipeline):

     def __init__(self, model: str):
         """
-        use `model` and `preprocessor` to create a kws pipeline for prediction
+        use `model` to create a kws pipeline for prediction
         Args:
             model: model id on modelscope hub.
         """
         super().__init__(model=model)
+        self.device = 'cpu'
         self.size = 1024
         self.latent = 512
         self.n_mlp = 8
@@ -40,7 +41,7 @@ class FaceImageGenerationPipeline(Pipeline):
             self.size,
             self.latent,
             self.n_mlp,
-            channel_multiplier=self.channel_multiplier)
+            channel_multiplier=self.channel_multiplier).to(self.device)
         self.model_file = f'{model}/{ModelFile.TORCH_MODEL_FILE}'
@@ -63,7 +64,7 @@ class FaceImageGenerationPipeline(Pipeline):
             torch.cuda.manual_seed_all(input)
         self.generator.eval()
         with torch.no_grad():
-            sample_z = torch.randn(1, self.latent)
+            sample_z = torch.randn(1, self.latent).to(self.device)
             sample, _ = self.generator([sample_z],
                                        truncation=self.truncation,
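The pipeline seeds the RNG from its integer input (`torch.cuda.manual_seed_all(input)` above), so the same input should reproduce the same face, and the latent is now created on the same device as the generator weights. A hedged end-to-end sketch:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

face_generation = pipeline(
    Tasks.face_image_generation, model='damo/cv_gan_face-image-generation')
result = face_generation(10)  # the input doubles as the RNG seed
img = result[OutputKeys.OUTPUT_IMG]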
@@ -19,7 +19,7 @@ logger = get_logger()

 @PIPELINES.register_module(
-    Tasks.image_restoration, module_name=Pipelines.image_super_resolution)
+    Tasks.image_super_resolution, module_name=Pipelines.image_super_resolution)
 class ImageSuperResolutionPipeline(Pipeline):

     def __init__(self, model: str):
@@ -29,6 +29,7 @@ class ImageSuperResolutionPipeline(Pipeline):
             model: model id on modelscope hub.
         """
         super().__init__(model=model)
+        self.device = 'cpu'
         self.num_feat = 64
         self.num_block = 23
         self.scale = 4
@@ -38,7 +39,7 @@ class ImageSuperResolutionPipeline(Pipeline):
             num_feat=self.num_feat,
             num_block=self.num_block,
             num_grow_ch=32,
-            scale=self.scale)
+            scale=self.scale).to(self.device)
         model_path = f'{self.model}/{ModelFile.TORCH_MODEL_FILE}'
         self.sr_model.load_state_dict(torch.load(model_path), strict=True)
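One caveat with the hard-coded `self.device = 'cpu'`: `torch.load(model_path)` restores tensors to the device they were saved from, so a checkpoint serialized from CUDA tensors would fail to load on a CPU-only host. A more defensive variant, as a suggestion not present in this diff:

self.sr_model.load_state_dict(
    torch.load(model_path, map_location=self.device), strict=True)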
@@ -58,7 +59,8 @@ class ImageSuperResolutionPipeline(Pipeline):
             raise TypeError(f'input should be either str, PIL.Image,'
                             f' np.array, but got {type(input)}')

-        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) / 255.
+        img = torch.from_numpy(img).to(self.device).permute(
+            2, 0, 1).unsqueeze(0) / 255.
         result = {'img': img}
         return result
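The preprocessing maps an HWC uint8 image in [0, 255] to an NCHW float tensor in [0, 1] on the target device. A sketch of that mapping and its inverse (the postprocess half is an assumption; it is not shown in this review):

import numpy as np
import torch


def preprocess(img: np.ndarray, device: str = 'cpu') -> torch.Tensor:
    # HWC uint8 [0, 255] -> NCHW float [0, 1], as in the pipeline above.
    # Integer division by 255. promotes the uint8 tensor to float32.
    return torch.from_numpy(img).to(device).permute(2, 0, 1).unsqueeze(0) / 255.


def postprocess(output: torch.Tensor) -> np.ndarray:
    # NCHW float [0, 1] -> HWC uint8 [0, 255]; assumed inverse mapping.
    img = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    return (img * 255.0).round().astype(np.uint8)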
@@ -27,7 +27,8 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
-    image_restoration = 'image-restoration'
+    face_image_generation = 'face-image-generation'
+    image_super_resolution = 'image-super-resolution'
     style_transfer = 'style-transfer'
@@ -27,11 +27,17 @@ class FaceGenerationTest(unittest.TestCase):
     def test_run_modelhub(self):
         seed = 10
         face_generation = pipeline(
-            Tasks.image_generation,
+            Tasks.face_image_generation,
            model=self.model_id,
         )
         self.pipeline_inference(face_generation, seed)

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        seed = 10
+        face_generation = pipeline(Tasks.face_image_generation)
+        self.pipeline_inference(face_generation, seed)
+

 if __name__ == '__main__':
     unittest.main()
@@ -28,10 +28,15 @@ class ImageSuperResolutionTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_modelhub(self):
         super_resolution = pipeline(
-            Tasks.image_restoration, model=self.model_id)
+            Tasks.image_super_resolution, model=self.model_id)
         self.pipeline_inference(super_resolution, self.img)

+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        super_resolution = pipeline(Tasks.image_super_resolution)
+        self.pipeline_inference(super_resolution, self.img)
+

 if __name__ == '__main__':
     unittest.main()