From cc27e3a25e70b2015e5702d6ab24af3197ab1691 Mon Sep 17 00:00:00 2001 From: "qianmu.ywh" Date: Wed, 30 Nov 2022 11:53:40 +0800 Subject: [PATCH] update pipeline according to online demo requirements MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 按在线demo前端的要求,将输出改成单独一个numpy格式的图片 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10912907 --- modelscope/pipelines/cv/image_depth_estimation_pipeline.py | 5 ++++- tests/pipelines/test_image_depth_estimation.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/modelscope/pipelines/cv/image_depth_estimation_pipeline.py b/modelscope/pipelines/cv/image_depth_estimation_pipeline.py index d318ebd2..1f580733 100644 --- a/modelscope/pipelines/cv/image_depth_estimation_pipeline.py +++ b/modelscope/pipelines/cv/image_depth_estimation_pipeline.py @@ -47,6 +47,9 @@ class ImageDepthEstimationPipeline(Pipeline): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: results = self.model.postprocess(inputs) - outputs = {OutputKeys.DEPTHS: results[OutputKeys.DEPTHS]} + depths = results[OutputKeys.DEPTHS] + if isinstance(depths, torch.Tensor): + depths = depths.detach().cpu().squeeze().numpy() + outputs = {OutputKeys.DEPTHS: depths} return outputs diff --git a/tests/pipelines/test_image_depth_estimation.py b/tests/pipelines/test_image_depth_estimation.py index 856734f8..933ce7a0 100644 --- a/tests/pipelines/test_image_depth_estimation.py +++ b/tests/pipelines/test_image_depth_estimation.py @@ -25,7 +25,7 @@ class ImageDepthEstimationTest(unittest.TestCase, DemoCompatibilityCheck): estimator = pipeline(Tasks.image_depth_estimation, model=self.model_id) result = estimator(input_location) depths = result[OutputKeys.DEPTHS] - depth_viz = depth_to_color(depths[0].squeeze().cpu().numpy()) + depth_viz = depth_to_color(depths) cv2.imwrite('result.jpg', depth_viz) print('test_image_depth_estimation DONE')