From 5fcb8e2342c17b6c1266cda55402ced9831bc009 Mon Sep 17 00:00:00 2001 From: "dangwei.ldw" Date: Tue, 9 Aug 2022 17:46:24 +0800 Subject: [PATCH] [to #42322933]fix onnx thread error Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9669460 * fix onnx thread error --- .../models/cv/product_retrieval_embedding/item_detection.py | 6 +++++- tests/pipelines/test_product_retrieval_embedding.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modelscope/models/cv/product_retrieval_embedding/item_detection.py b/modelscope/models/cv/product_retrieval_embedding/item_detection.py index 4dd2914b..d5589969 100644 --- a/modelscope/models/cv/product_retrieval_embedding/item_detection.py +++ b/modelscope/models/cv/product_retrieval_embedding/item_detection.py @@ -21,7 +21,11 @@ class YOLOXONNX(object): self.num_classes = 13 self.onnx_path = onnx_path import onnxruntime as ort - self.ort_session = ort.InferenceSession(self.onnx_path) + options = ort.SessionOptions() + options.intra_op_num_threads = 1 + options.inter_op_num_threads = 1 + self.ort_session = ort.InferenceSession( + self.onnx_path, sess_options=options) self.with_p6 = False self.multi_detect = multi_detect diff --git a/tests/pipelines/test_product_retrieval_embedding.py b/tests/pipelines/test_product_retrieval_embedding.py index c0129ec5..c416943e 100644 --- a/tests/pipelines/test_product_retrieval_embedding.py +++ b/tests/pipelines/test_product_retrieval_embedding.py @@ -13,14 +13,14 @@ class ProductRetrievalEmbeddingTest(unittest.TestCase): model_id = 'damo/cv_resnet50_product-bag-embedding-models' img_input = 'data/test/images/product_embed_bag.jpg' - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self): product_embed = pipeline(Tasks.product_retrieval_embedding, self.model_id) result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING] print('abs sum value is: {}'.format(np.sum(np.abs(result)))) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id) product_embed = pipeline(