You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_general_image_classification.py 1.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.pipelines import pipeline
  4. from modelscope.utils.constant import Tasks
  5. from modelscope.utils.demo_utils import DemoCompatibilityCheck
  6. from modelscope.utils.test_utils import test_level
  7. class GeneralImageClassificationTest(unittest.TestCase,
  8. DemoCompatibilityCheck):
  9. def setUp(self) -> None:
  10. self.task = Tasks.image_classification
  11. self.model_id = 'damo/cv_vit-base_image-classification_Dailylife-labels'
  12. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  13. def test_run_ImageNet(self):
  14. general_image_classification = pipeline(
  15. Tasks.image_classification,
  16. model='damo/cv_vit-base_image-classification_ImageNet-labels')
  17. result = general_image_classification('data/test/images/bird.JPEG')
  18. print(result)
  19. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  20. def test_run_Dailylife(self):
  21. general_image_classification = pipeline(
  22. Tasks.image_classification,
  23. model='damo/cv_vit-base_image-classification_Dailylife-labels')
  24. result = general_image_classification('data/test/images/bird.JPEG')
  25. print(result)
  26. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  27. def test_run_Dailylife_default(self):
  28. general_image_classification = pipeline(Tasks.image_classification)
  29. result = general_image_classification('data/test/images/bird.JPEG')
  30. print(result)
  31. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  32. def test_demo_compatibility(self):
  33. self.compatibility_check()
  34. if __name__ == '__main__':
  35. unittest.main()