You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_base.py 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from typing import Any, Dict, List, Tuple, Union
  4. import numpy as np
  5. import PIL
  6. from modelscope.pipelines import Pipeline, pipeline
  7. from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
  8. from modelscope.utils.constant import Tasks
  9. from modelscope.utils.logger import get_logger
  10. from modelscope.utils.registry import default_group
  11. logger = get_logger()
  12. Input = Union[str, 'PIL.Image', 'numpy.ndarray']
  13. class CustomPipelineTest(unittest.TestCase):
  14. def test_abstract(self):
  15. @PIPELINES.register_module()
  16. class CustomPipeline1(Pipeline):
  17. def __init__(self,
  18. config_file: str = None,
  19. model=None,
  20. preprocessor=None,
  21. **kwargs):
  22. super().__init__(config_file, model, preprocessor, **kwargs)
  23. with self.assertRaises(TypeError):
  24. CustomPipeline1()
  25. def test_custom(self):
  26. dummy_task = 'dummy-task'
  27. @PIPELINES.register_module(
  28. group_key=dummy_task, module_name='custom-image')
  29. class CustomImagePipeline(Pipeline):
  30. def __init__(self,
  31. config_file: str = None,
  32. model=None,
  33. preprocessor=None,
  34. **kwargs):
  35. super().__init__(config_file, model, preprocessor, **kwargs)
  36. def preprocess(self, input: Union[str,
  37. 'PIL.Image']) -> Dict[str, Any]:
  38. """ Provide default implementation based on preprocess_cfg and user can reimplement it
  39. """
  40. if not isinstance(input, PIL.Image.Image):
  41. from modelscope.preprocessors import load_image
  42. data_dict = {'img': load_image(input), 'url': input}
  43. else:
  44. data_dict = {'img': input}
  45. return data_dict
  46. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  47. """ Provide default implementation using self.model and user can reimplement it
  48. """
  49. outputs = {}
  50. if 'url' in inputs:
  51. outputs['filename'] = inputs['url']
  52. img = inputs['img']
  53. new_image = img.resize((img.width // 2, img.height // 2))
  54. outputs['output_png'] = np.array(new_image)
  55. return outputs
  56. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  57. return inputs
  58. self.assertTrue('custom-image' in PIPELINES.modules[default_group])
  59. add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True)
  60. pipe = pipeline(pipeline_name='custom-image')
  61. pipe2 = pipeline(dummy_task)
  62. self.assertTrue(type(pipe) is type(pipe2))
  63. img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
  64. 'aliyuncs.com/data/test/images/image1.jpg'
  65. output = pipe(img_url)
  66. self.assertEqual(output['filename'], img_url)
  67. self.assertEqual(output['output_png'].shape, (318, 512, 3))
  68. outputs = pipe([img_url for i in range(4)])
  69. self.assertEqual(len(outputs), 4)
  70. for out in outputs:
  71. self.assertEqual(out['filename'], img_url)
  72. self.assertEqual(out['output_png'].shape, (318, 512, 3))
  73. if __name__ == '__main__':
  74. unittest.main()