You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_builder.py 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from asyncio import Task
  4. from typing import Any, Dict, List, Tuple, Union
  5. import numpy as np
  6. import PIL
  7. from modelscope.models.base import Model
  8. from modelscope.pipelines import Pipeline, pipeline
  9. from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
  10. from modelscope.utils.constant import Tasks
  11. from modelscope.utils.logger import get_logger
  12. from modelscope.utils.registry import default_group
# Module-level logger obtained from modelscope's ``get_logger`` helper.
logger = get_logger()
  14. @PIPELINES.register_module(
  15. group_key=Tasks.image_tagging, module_name='custom_single_model')
  16. class CustomSingleModelPipeline(Pipeline):
  17. def __init__(self,
  18. config_file: str = None,
  19. model: List[Union[str, Model]] = None,
  20. preprocessor=None,
  21. **kwargs):
  22. super().__init__(config_file, model, preprocessor, **kwargs)
  23. assert isinstance(model, str), 'model is not str'
  24. print(model)
  25. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  26. return super().postprocess(inputs)
  27. @PIPELINES.register_module(
  28. group_key=Tasks.image_tagging, module_name='model1_model2')
  29. class CustomMultiModelPipeline(Pipeline):
  30. def __init__(self,
  31. config_file: str = None,
  32. model: List[Union[str, Model]] = None,
  33. preprocessor=None,
  34. **kwargs):
  35. super().__init__(config_file, model, preprocessor, **kwargs)
  36. assert isinstance(model, list), 'model is not list'
  37. for m in model:
  38. assert isinstance(m, str), 'submodel is not str'
  39. print(m)
  40. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  41. return super().postprocess(inputs)
  42. class PipelineInterfaceTest(unittest.TestCase):
  43. def test_single_model(self):
  44. pipe = pipeline(Tasks.image_tagging, model='custom_single_model')
  45. assert isinstance(pipe, CustomSingleModelPipeline)
  46. def test_multi_model(self):
  47. pipe = pipeline(Tasks.image_tagging, model=['model1', 'model2'])
  48. assert isinstance(pipe, CustomMultiModelPipeline)
if __name__ == '__main__':
    # Run this module's TestCase classes when the file is executed directly.
    unittest.main()

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展。