You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_base.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from typing import Any, Dict, Union
  4. import numpy as np
  5. from PIL import Image
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.pipelines import Pipeline, pipeline
  8. from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
  9. from modelscope.utils.logger import get_logger
  10. logger = get_logger()
  11. Input = Union[str, 'PIL.Image', 'numpy.ndarray']
  12. class CustomPipelineTest(unittest.TestCase):
  13. def test_abstract(self):
  14. @PIPELINES.register_module()
  15. class CustomPipeline1(Pipeline):
  16. def __init__(self,
  17. config_file: str = None,
  18. model=None,
  19. preprocessor=None,
  20. **kwargs):
  21. super().__init__(config_file, model, preprocessor, **kwargs)
  22. with self.assertRaises(TypeError):
  23. CustomPipeline1()
  24. def test_custom(self):
  25. dummy_task = 'dummy-task'
  26. @PIPELINES.register_module(
  27. group_key=dummy_task, module_name='custom-image')
  28. class CustomImagePipeline(Pipeline):
  29. def __init__(self,
  30. config_file: str = None,
  31. model=None,
  32. preprocessor=None,
  33. **kwargs):
  34. super().__init__(config_file, model, preprocessor, **kwargs)
  35. def preprocess(self, input: Union[str,
  36. 'PIL.Image']) -> Dict[str, Any]:
  37. """ Provide default implementation based on preprocess_cfg and user can reimplement it
  38. """
  39. if not isinstance(input, Image.Image):
  40. from modelscope.preprocessors import load_image
  41. data_dict = {'img': load_image(input), 'url': input}
  42. else:
  43. data_dict = {'img': input}
  44. return data_dict
  45. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  46. """ Provide default implementation using self.model and user can reimplement it
  47. """
  48. outputs = {}
  49. if 'url' in inputs:
  50. outputs['filename'] = inputs['url']
  51. img = inputs['img']
  52. new_image = img.resize((img.width // 2, img.height // 2))
  53. outputs[OutputKeys.OUTPUT_IMG] = np.array(new_image)
  54. return outputs
  55. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  56. return inputs
  57. self.assertTrue('custom-image' in PIPELINES.modules[dummy_task])
  58. add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True)
  59. pipe = pipeline(task=dummy_task, pipeline_name='custom-image')
  60. pipe2 = pipeline(dummy_task)
  61. self.assertTrue(type(pipe) is type(pipe2))
  62. img_url = 'data/test/images/dogs.jpg'
  63. output = pipe(img_url)
  64. self.assertEqual(output['filename'], img_url)
  65. self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3))
  66. outputs = pipe([img_url for i in range(4)])
  67. self.assertEqual(len(outputs), 4)
  68. for out in outputs:
  69. self.assertEqual(out['filename'], img_url)
  70. self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3))
  71. if __name__ == '__main__':
  72. unittest.main()