Browse Source

[to #42322933]cv/cvdet_fix_outputs->master fix outputs

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10421413

    * fix outputs
master
wendi.hwd yingda.chen 3 years ago
parent
commit
674e1a7878
2 changed files with 10 additions and 18 deletions
  1. +6
    -2
      modelscope/pipelines/cv/image_detection_pipeline.py
  2. +4
    -16
      tests/pipelines/test_object_detection.py

+ 6
- 2
modelscope/pipelines/cv/image_detection_pipeline.py View File

@@ -43,11 +43,15 @@ class ImageDetectionPipeline(Pipeline):

bboxes, scores, labels = self.model.postprocess(inputs['data'])
if bboxes is None:
return None
outputs = {
OutputKeys.SCORES: [],
OutputKeys.LABELS: [],
OutputKeys.BOXES: []
}
return outputs
outputs = {
OutputKeys.SCORES: scores,
OutputKeys.LABELS: labels,
OutputKeys.BOXES: bboxes
}

return outputs

+ 4
- 16
tests/pipelines/test_object_detection.py View File

@@ -19,20 +19,14 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
model_id = 'damo/cv_vit_object-detection_coco'
object_detect = pipeline(Tasks.image_object_detection, model=model_id)
result = object_detect(input_location)
if result:
print(result)
else:
raise ValueError('process error')
print(result)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_object_detection_with_default_task(self):
input_location = 'data/test/images/image_detection.jpg'
object_detect = pipeline(Tasks.image_object_detection)
result = object_detect(input_location)
if result:
print(result)
else:
raise ValueError('process error')
print(result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_human_detection(self):
@@ -40,20 +34,14 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
model_id = 'damo/cv_resnet18_human-detection'
human_detect = pipeline(Tasks.human_detection, model=model_id)
result = human_detect(input_location)
if result:
print(result)
else:
raise ValueError('process error')
print(result)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_human_detection_with_default_task(self):
input_location = 'data/test/images/image_detection.jpg'
human_detect = pipeline(Tasks.human_detection)
result = human_detect(input_location)
if result:
print(result)
else:
raise ValueError('process error')
print(result)

@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):


Loading…
Cancel
Save