You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_base.py 3.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from typing import Any, Dict, List, Tuple, Union
  4. import numpy as np
  5. import PIL
  6. from maas_lib.pipelines import Pipeline, pipeline
  7. from maas_lib.pipelines.builder import PIPELINES
  8. from maas_lib.utils.constant import Tasks
  9. from maas_lib.utils.logger import get_logger
  10. from maas_lib.utils.registry import default_group
  11. logger = get_logger()
  12. Input = Union[str, 'PIL.Image', 'numpy.ndarray']
  13. class CustomPipelineTest(unittest.TestCase):
  14. def test_abstract(self):
  15. @PIPELINES.register_module()
  16. class CustomPipeline1(Pipeline):
  17. def __init__(self,
  18. config_file: str = None,
  19. model=None,
  20. preprocessor=None,
  21. **kwargs):
  22. super().__init__(config_file, model, preprocessor, **kwargs)
  23. with self.assertRaises(TypeError):
  24. CustomPipeline1()
  25. def test_custom(self):
  26. @PIPELINES.register_module(
  27. group_key=Tasks.image_tagging, module_name='custom-image')
  28. class CustomImagePipeline(Pipeline):
  29. def __init__(self,
  30. config_file: str = None,
  31. model=None,
  32. preprocessor=None,
  33. **kwargs):
  34. super().__init__(config_file, model, preprocessor, **kwargs)
  35. def preprocess(self, input: Union[str,
  36. 'PIL.Image']) -> Dict[str, Any]:
  37. """ Provide default implementation based on preprocess_cfg and user can reimplement it
  38. """
  39. if not isinstance(input, PIL.Image.Image):
  40. from maas_lib.preprocessors import load_image
  41. data_dict = {'img': load_image(input), 'url': input}
  42. else:
  43. data_dict = {'img': input}
  44. return data_dict
  45. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  46. """ Provide default implementation using self.model and user can reimplement it
  47. """
  48. outputs = {}
  49. if 'url' in inputs:
  50. outputs['filename'] = inputs['url']
  51. img = inputs['img']
  52. new_image = img.resize((img.width // 2, img.height // 2))
  53. outputs['resize_image'] = np.array(new_image)
  54. outputs['dummy_result'] = 'dummy_result'
  55. return outputs
  56. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  57. return inputs
  58. self.assertTrue('custom-image' in PIPELINES.modules[default_group])
  59. pipe = pipeline(pipeline_name='custom-image')
  60. pipe2 = pipeline(Tasks.image_tagging)
  61. self.assertTrue(type(pipe) is type(pipe2))
  62. img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
  63. 'aliyuncs.com/data/test/images/image1.jpg'
  64. output = pipe(img_url)
  65. self.assertEqual(output['filename'], img_url)
  66. self.assertEqual(output['resize_image'].shape, (318, 512, 3))
  67. self.assertEqual(output['dummy_result'], 'dummy_result')
  68. outputs = pipe([img_url for i in range(4)])
  69. self.assertEqual(len(outputs), 4)
  70. for out in outputs:
  71. self.assertEqual(out['filename'], img_url)
  72. self.assertEqual(out['resize_image'].shape, (318, 512, 3))
  73. self.assertEqual(out['dummy_result'], 'dummy_result')
  74. if __name__ == '__main__':
  75. unittest.main()

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展