You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_visual_question_answering.py 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.hub.snapshot_download import snapshot_download
  4. from modelscope.models import Model
  5. from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
  6. from modelscope.pipelines import pipeline
  7. from modelscope.pipelines.multi_modal import VisualQuestionAnsweringPipeline
  8. from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
  9. from modelscope.utils.constant import Tasks
  10. from modelscope.utils.test_utils import test_level
  11. class VisualQuestionAnsweringTest(unittest.TestCase):
  12. model_id = 'damo/mplug_visual-question-answering_coco_large_en'
  13. input_vqa = {
  14. 'image': 'data/test/images/image_mplug_vqa.jpg',
  15. 'question': 'What is the woman doing?',
  16. }
  17. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  18. def test_run(self):
  19. cache_path = snapshot_download(self.model_id)
  20. preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path)
  21. model = MPlugForVisualQuestionAnswering(cache_path)
  22. pipeline1 = VisualQuestionAnsweringPipeline(
  23. model, preprocessor=preprocessor)
  24. pipeline2 = pipeline(
  25. Tasks.visual_question_answering,
  26. model=model,
  27. preprocessor=preprocessor)
  28. print(f"question: {self.input_vqa['question']}")
  29. print(f'pipeline1: {pipeline1(self.input_vqa)}')
  30. print(f'pipeline2: {pipeline2(self.input_vqa)}')
  31. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  32. def test_run_with_model_from_modelhub(self):
  33. model = Model.from_pretrained(self.model_id)
  34. preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
  35. model.model_dir)
  36. pipeline_vqa = pipeline(
  37. task=Tasks.visual_question_answering,
  38. model=model,
  39. preprocessor=preprocessor)
  40. print(pipeline_vqa(self.input_vqa))
  41. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  42. def test_run_with_model_name(self):
  43. pipeline_vqa = pipeline(
  44. Tasks.visual_question_answering, model=self.model_id)
  45. print(pipeline_vqa(self.input_vqa))
  46. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  47. def test_run_with_default_model(self):
  48. pipeline_vqa = pipeline(task=Tasks.visual_question_answering)
  49. print(pipeline_vqa(self.input_vqa))
  50. if __name__ == '__main__':
  51. unittest.main()