Browse Source

decode shape in restful api

tags/v1.2.0-rc1
yelihua 4 years ago
parent
commit
a23190bd64
3 changed files with 26 additions and 6 deletions
  1. +21
    -2
      mindinsight/backend/debugger/debugger_api.py
  2. +2
    -2
      mindinsight/ui/src/components/debugger-tensor.vue
  3. +3
    -2
      tests/st/func/debugger/test_restful_api.py

+ 21
- 2
mindinsight/backend/debugger/debugger_api.py View File

@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""Debugger restful api.""" """Debugger restful api."""
import json import json
from urllib.parse import unquote


from flask import Blueprint, jsonify, request from flask import Blueprint, jsonify, request


@@ -34,6 +35,24 @@ def _initialize_debugger_server():
return server return server




def _unquote_param(param):
"""
Decode parameter value.

Args:
param (str): Encoded param value.

Returns:
str, decoded param value.
"""
if isinstance(param, str):
try:
param = unquote(param, errors='strict')
except UnicodeDecodeError:
raise ParamValueError('Unquote error with strict mode.')
return param


def _read_post_request(post_request): def _read_post_request(post_request):
""" """
Extract the body of post request. Extract the body of post request.
@@ -134,7 +153,7 @@ def tensor_comparisons():
""" """
name = request.args.get('name') name = request.args.get('name')
detail = request.args.get('detail', 'data') detail = request.args.get('detail', 'data')
shape = request.args.get('shape')
shape = _unquote_param(request.args.get('shape'))
tolerance = request.args.get('tolerance', '0') tolerance = request.args.get('tolerance', '0')
reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance) reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance)


@@ -190,7 +209,7 @@ def retrieve_tensor_value():
""" """
name = request.args.get('name') name = request.args.get('name')
detail = request.args.get('detail') detail = request.args.get('detail')
shape = request.args.get('shape')
shape = _unquote_param(request.args.get('shape'))
graph_name = request.args.get('graph_name') graph_name = request.args.get('graph_name')
prev = bool(request.args.get('prev') == 'true') prev = bool(request.args.get('prev') == 'true')




+ 2
- 2
mindinsight/ui/src/components/debugger-tensor.vue View File

@@ -992,7 +992,7 @@ export default {
const params = { const params = {
name: row.name, name: row.name,
detail: 'data', detail: 'data',
shape,
shape: encodeURIComponent(shape),
tolerance: this.tolerance / 100, tolerance: this.tolerance / 100,
graph_name: row.graph_name, graph_name: row.graph_name,
}; };
@@ -1085,7 +1085,7 @@ export default {
const params = { const params = {
name: row.name, name: row.name,
detail: 'data', detail: 'data',
shape,
shape: encodeURIComponent(shape),
graph_name: row.graph_name, graph_name: row.graph_name,
prev: this.gridType === 'preStep' ? true : false, prev: this.gridType === 'preStep' ? true : false,
}; };


+ 3
- 2
tests/st/func/debugger/test_restful_api.py View File

@@ -19,6 +19,7 @@ Usage:
pytest tests/st/func/debugger/test_restful_api.py pytest tests/st/func/debugger/test_restful_api.py
""" """
import os import os
from urllib.parse import quote


import pytest import pytest


@@ -204,7 +205,7 @@ class TestAscendDebugger:
body_data = { body_data = {
'name': node_name + ':0', 'name': node_name + ':0',
'detail': 'data', 'detail': 'data',
'shape': '[1, 1:3]'
'shape': quote('[1, 1:3]')
} }
expect_file = 'retrieve_tensor_value.json' expect_file = 'retrieve_tensor_value.json'
send_and_compare_result(app_client, url, body_data, expect_file, method='get') send_and_compare_result(app_client, url, body_data, expect_file, method='get')
@@ -237,7 +238,7 @@ class TestAscendDebugger:
body_data = { body_data = {
'name': node_name + ':0', 'name': node_name + ':0',
'detail': 'data', 'detail': 'data',
'shape': '[:, :]',
'shape': quote('[:, :]'),
'tolerance': 1 'tolerance': 1
} }
expect_file = 'compare_tensors.json' expect_file = 'compare_tensors.json'


Loading…
Cancel
Save