From a23190bd646eaea4c70b4b38d40d8f9e4255d2c8 Mon Sep 17 00:00:00 2001 From: yelihua Date: Mon, 4 Jan 2021 11:43:39 +0800 Subject: [PATCH] decode shape in restful api --- mindinsight/backend/debugger/debugger_api.py | 23 +++++++++++++++++-- .../ui/src/components/debugger-tensor.vue | 4 ++-- tests/st/func/debugger/test_restful_api.py | 5 ++-- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/mindinsight/backend/debugger/debugger_api.py b/mindinsight/backend/debugger/debugger_api.py index fe6c7e46..0e8aea63 100644 --- a/mindinsight/backend/debugger/debugger_api.py +++ b/mindinsight/backend/debugger/debugger_api.py @@ -14,6 +14,7 @@ # ============================================================================ """Debugger restful api.""" import json +from urllib.parse import unquote from flask import Blueprint, jsonify, request @@ -34,6 +35,24 @@ def _initialize_debugger_server(): return server +def _unquote_param(param): + """ + Decode parameter value. + + Args: + param (str): Encoded param value. + + Returns: + str, decoded param value. + """ + if isinstance(param, str): + try: + param = unquote(param, errors='strict') + except UnicodeDecodeError: + raise ParamValueError('Unquote error with strict mode.') + return param + + def _read_post_request(post_request): """ Extract the body of post request. @@ -134,7 +153,7 @@ def tensor_comparisons(): """ name = request.args.get('name') detail = request.args.get('detail', 'data') - shape = request.args.get('shape') + shape = _unquote_param(request.args.get('shape')) tolerance = request.args.get('tolerance', '0') reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance) @@ -190,7 +209,7 @@ def retrieve_tensor_value(): """ name = request.args.get('name') detail = request.args.get('detail') - shape = request.args.get('shape') + shape = _unquote_param(request.args.get('shape')) graph_name = request.args.get('graph_name') prev = bool(request.args.get('prev') == 'true') diff --git a/mindinsight/ui/src/components/debugger-tensor.vue b/mindinsight/ui/src/components/debugger-tensor.vue index 7f6e3138..6999d59e 100644 --- a/mindinsight/ui/src/components/debugger-tensor.vue +++ b/mindinsight/ui/src/components/debugger-tensor.vue @@ -992,7 +992,7 @@ export default { const params = { name: row.name, detail: 'data', - shape, + shape: encodeURIComponent(shape), tolerance: this.tolerance / 100, graph_name: row.graph_name, }; @@ -1085,7 +1085,7 @@ export default { const params = { name: row.name, detail: 'data', - shape, + shape: encodeURIComponent(shape), graph_name: row.graph_name, prev: this.gridType === 'preStep' ? true : false, }; diff --git a/tests/st/func/debugger/test_restful_api.py b/tests/st/func/debugger/test_restful_api.py index 00e7b4a3..16b624fd 100644 --- a/tests/st/func/debugger/test_restful_api.py +++ b/tests/st/func/debugger/test_restful_api.py @@ -19,6 +19,7 @@ Usage: pytest tests/st/func/debugger/test_restful_api.py """ import os +from urllib.parse import quote import pytest @@ -204,7 +205,7 @@ class TestAscendDebugger: body_data = { 'name': node_name + ':0', 'detail': 'data', - 'shape': '[1, 1:3]' + 'shape': quote('[1, 1:3]') } expect_file = 'retrieve_tensor_value.json' send_and_compare_result(app_client, url, body_data, expect_file, method='get') @@ -237,7 +238,7 @@ class TestAscendDebugger: body_data = { 'name': node_name + ':0', 'detail': 'data', - 'shape': '[:, :]', + 'shape': quote('[:, :]'), 'tolerance': 1 } expect_file = 'compare_tensors.json'