You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image2image_generation.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from torchvision.utils import save_image
  4. from modelscope.pipelines import pipeline
  5. from modelscope.utils.constant import Tasks
  6. from modelscope.utils.test_utils import test_level
  class Image2ImageGenerationTest(unittest.TestCase):
      """Smoke test for the latent-diffusion image-to-image generation pipeline."""

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_run_modelhub(self):
          r"""Run both generation modes of the pipeline and save their outputs.

          The pipeline supports two modes, selected by the form of the input:

          1. Similar Image Generation mode: pass a single image path, e.g.
             ``pipeline('data/test/images/img2img_input.jpg')``.
          2. Interpolation mode: pass a tuple of two image paths, e.g.
             ``pipeline(('data/test/images/img2img_input.jpg',
             'data/test/images/img2img_style.jpg'))``.

          Each result must contain an ``'output_img'`` tensor. The tensors are
          written to ``result1.jpg`` and ``result2.jpg`` in the current
          working directory.
          """
          img2img_gen_pipeline = pipeline(
              Tasks.image_to_image_generation,
              model='damo/cv_latent_diffusion_image2image_generate')
          # Similar Image Generation mode: a single input image.
          result1 = img2img_gen_pipeline('data/test/images/img2img_input.jpg')
          # Interpolation mode: a pair of input images.
          result2 = img2img_gen_pipeline(('data/test/images/img2img_input.jpg',
                                          'data/test/images/img2img_style.jpg'))
          for result, out_path in ((result1, 'result1.jpg'),
                                   (result2, 'result2.jpg')):
              # Fail with a clear message rather than a bare KeyError if the
              # pipeline output lacks the generated image.
              self.assertIn('output_img', result)
              self._save_output(result['output_img'], out_path)

      @staticmethod
      def _save_output(output_img, out_path):
          """Save a generated image tensor to ``out_path`` as a 4-column grid.

          The tensor is clamped to [-1, 1]. With ``normalize=True`` and
          ``value_range=(-1, 1)``, that range is mapped to [0, 1] before the
          tensor is written as an image.
          """
          # ``value_range`` replaces the ``range`` keyword, which torchvision
          # deprecated and later removed. Passing ``range`` fails on recent
          # torchvision releases.
          save_image(
              output_img.clamp(-1, 1),
              out_path,
              value_range=(-1, 1),
              normalize=True,
              nrow=4)
  if '__main__' == __name__:
      # Allow running this test module directly as a script, in addition to
      # discovery by a test runner.
      unittest.main()