You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_image2image_generation.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from torchvision.utils import save_image
  4. from modelscope.outputs import OutputKeys
  5. from modelscope.pipelines import pipeline
  6. from modelscope.utils.constant import Tasks
  7. from modelscope.utils.test_utils import test_level
  8. class Image2ImageGenerationTest(unittest.TestCase):
  9. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  10. def test_run_modelhub(self):
  11. r"""We provide two generation modes, i.e., Similar Image Generation and Interpolation.
  12. You can pass the following parameters for different mode.
  13. 1. Similar Image Generation Mode:
  14. 2. Interpolation Mode:
  15. """
  16. img2img_gen_pipeline = pipeline(
  17. Tasks.image_to_image_generation,
  18. model='damo/cv_latent_diffusion_image2image_generate')
  19. # Similar Image Generation mode
  20. result1 = img2img_gen_pipeline('data/test/images/img2img_input.jpg')
  21. # Interpolation Mode
  22. result2 = img2img_gen_pipeline(('data/test/images/img2img_input.jpg',
  23. 'data/test/images/img2img_style.jpg'))
  24. save_image(
  25. result1[OutputKeys.OUTPUT_IMG].clamp(-1, 1),
  26. 'result1.jpg',
  27. range=(-1, 1),
  28. normalize=True,
  29. nrow=4)
  30. save_image(
  31. result2[OutputKeys.OUTPUT_IMG].clamp(-1, 1),
  32. 'result2.jpg',
  33. range=(-1, 1),
  34. normalize=True,
  35. nrow=4)
  36. if __name__ == '__main__':
  37. unittest.main()