You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_base.py 3.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from typing import Any, Dict, List, Tuple, Union
  4. import numpy as np
  5. import PIL
  6. from modelscope.pipelines import Pipeline, pipeline
  7. from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
  8. from modelscope.pipelines.outputs import OutputKeys
  9. from modelscope.utils.constant import Tasks
  10. from modelscope.utils.logger import get_logger
  11. from modelscope.utils.registry import default_group
  # Module-level logger and input-type alias.
logger = get_logger()
# Type alias for values a pipeline may be called with (not referenced elsewhere
# in the visible code). Members are string forward references, so nothing is
# imported eagerly. 'PIL.Image.Image' is the image class; 'PIL.Image' alone
# would name the module.
Input = Union[str, 'PIL.Image.Image', 'numpy.ndarray']
  # Tests for defining user pipelines and registering them in PIPELINES.
class CustomPipelineTest(unittest.TestCase):
    """Exercise registration and construction of user-defined pipelines."""

    def test_abstract(self):
        """A Pipeline subclass that overrides only ``__init__`` cannot be
        instantiated."""

        @PIPELINES.register_module()
        class CustomPipeline1(Pipeline):

            def __init__(self,
                         config_file: str = None,
                         model=None,
                         preprocessor=None,
                         **kwargs):
                super().__init__(config_file, model, preprocessor, **kwargs)

        # NOTE(review): presumably Pipeline declares abstract hook methods that
        # CustomPipeline1 leaves unimplemented, hence the TypeError.
        with self.assertRaises(TypeError):
            CustomPipeline1()

    def test_custom(self):
        """Register a pipeline under a new task, make it the task default,
        and check single and list inputs."""
        dummy_task = 'dummy-task'

        @PIPELINES.register_module(
            group_key=dummy_task, module_name='custom-image')
        class CustomImagePipeline(Pipeline):

            def __init__(self,
                         config_file: str = None,
                         model=None,
                         preprocessor=None,
                         **kwargs):
                super().__init__(config_file, model, preprocessor, **kwargs)

            def preprocess(
                    self,
                    input: Union[str, 'PIL.Image.Image']) -> Dict[str, Any]:
                """Turn the raw input into a dict holding the image.

                Args:
                    input: an image path/URL, or an already-loaded PIL image.

                Returns:
                    ``{'img': image}``; when ``input`` was not a PIL image,
                    the image comes from ``load_image(input)`` and ``input``
                    is kept under ``'url'``.
                """
                # `import PIL` does not load the `PIL.Image` submodule, so
                # import it explicitly rather than relying on another module
                # having imported it first.
                from PIL import Image
                if isinstance(input, Image.Image):
                    return {'img': input}
                from modelscope.preprocessors import load_image
                return {'img': load_image(input), 'url': input}

            def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                """Halve the image's width and height.

                Returns a dict with the resized image as a numpy array under
                ``OutputKeys.OUTPUT_IMG`` and, when ``inputs`` has a
                ``'url'``, that value under ``'filename'``.
                """
                outputs = {}
                if 'url' in inputs:
                    outputs['filename'] = inputs['url']
                img = inputs['img']
                new_image = img.resize((img.width // 2, img.height // 2))
                outputs[OutputKeys.OUTPUT_IMG] = np.array(new_image)
                return outputs

            def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                """Return ``forward``'s output unchanged."""
                return inputs

        self.assertIn('custom-image', PIPELINES.modules[dummy_task])
        # Make 'custom-image' the default for dummy_task; the assertIs below
        # checks that pipeline(dummy_task) then builds the same class.
        # NOTE(review): this mutates global registry state and is not undone.
        add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True)
        pipe = pipeline(task=dummy_task, pipeline_name='custom-image')
        pipe2 = pipeline(dummy_task)
        self.assertIs(type(pipe), type(pipe2))

        img_url = 'data/test/images/dogs.jpg'
        # Expected shape is half the fixture image's height/width (floor
        # division in forward()) with 3 channels.
        output = pipe(img_url)
        self.assertEqual(output['filename'], img_url)
        self.assertEqual(output[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3))

        # A list input yields one output per element.
        outputs = pipe([img_url] * 4)
        self.assertEqual(len(outputs), 4)
        for out in outputs:
            self.assertEqual(out['filename'], img_url)
            self.assertEqual(out[OutputKeys.OUTPUT_IMG].shape, (318, 512, 3))
  # Script entry point.
if __name__ == '__main__':
    # Running this file directly executes the test cases with unittest's CLI.
    unittest.main()