You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_general_image_classification.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.pipelines import pipeline
  4. from modelscope.utils.constant import Tasks
  5. from modelscope.utils.demo_utils import DemoCompatibilityCheck
  6. from modelscope.utils.test_utils import test_level
  7. class GeneralImageClassificationTest(unittest.TestCase,
  8. DemoCompatibilityCheck):
  9. def setUp(self) -> None:
  10. self.task = Tasks.image_classification
  11. self.model_id = 'damo/cv_vit-base_image-classification_Dailylife-labels'
  12. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  13. def test_run_ImageNet(self):
  14. general_image_classification = pipeline(
  15. Tasks.image_classification,
  16. model='damo/cv_vit-base_image-classification_ImageNet-labels')
  17. result = general_image_classification('data/test/images/bird.JPEG')
  18. print(result)
  19. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  20. def test_run_Dailylife(self):
  21. general_image_classification = pipeline(
  22. Tasks.image_classification,
  23. model='damo/cv_vit-base_image-classification_Dailylife-labels')
  24. result = general_image_classification('data/test/images/bird.JPEG')
  25. print(result)
  26. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  27. def test_run_nextvit(self):
  28. nexit_image_classification = pipeline(
  29. Tasks.image_classification,
  30. model='damo/cv_nextvit-small_image-classification_Dailylife-labels'
  31. )
  32. result = nexit_image_classification('data/test/images/bird.JPEG')
  33. print(result)
  34. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  35. def test_run_Dailylife_default(self):
  36. general_image_classification = pipeline(Tasks.image_classification)
  37. result = general_image_classification('data/test/images/bird.JPEG')
  38. print(result)
  39. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  40. def test_demo_compatibility(self):
  41. self.compatibility_check()
  42. if __name__ == '__main__':
  43. unittest.main()