Browse Source

decode shape in restful api

tags/v1.2.0-rc1
yelihua 4 years ago
parent
commit
a23190bd64
3 changed files with 26 additions and 6 deletions
  1. +21
    -2
      mindinsight/backend/debugger/debugger_api.py
  2. +2
    -2
      mindinsight/ui/src/components/debugger-tensor.vue
  3. +3
    -2
      tests/st/func/debugger/test_restful_api.py

+ 21
- 2
mindinsight/backend/debugger/debugger_api.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""Debugger restful api."""
import json
from urllib.parse import unquote

from flask import Blueprint, jsonify, request

@@ -34,6 +35,24 @@ def _initialize_debugger_server():
return server


def _unquote_param(param):
"""
Decode parameter value.

Args:
param (str): Encoded param value.

Returns:
str, decoded param value.
"""
if isinstance(param, str):
try:
param = unquote(param, errors='strict')
except UnicodeDecodeError:
raise ParamValueError('Unquote error with strict mode.')
return param


def _read_post_request(post_request):
"""
Extract the body of post request.
@@ -134,7 +153,7 @@ def tensor_comparisons():
"""
name = request.args.get('name')
detail = request.args.get('detail', 'data')
shape = request.args.get('shape')
shape = _unquote_param(request.args.get('shape'))
tolerance = request.args.get('tolerance', '0')
reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance)

@@ -190,7 +209,7 @@ def retrieve_tensor_value():
"""
name = request.args.get('name')
detail = request.args.get('detail')
shape = request.args.get('shape')
shape = _unquote_param(request.args.get('shape'))
graph_name = request.args.get('graph_name')
prev = bool(request.args.get('prev') == 'true')



+ 2
- 2
mindinsight/ui/src/components/debugger-tensor.vue View File

@@ -992,7 +992,7 @@ export default {
const params = {
name: row.name,
detail: 'data',
shape,
shape: encodeURIComponent(shape),
tolerance: this.tolerance / 100,
graph_name: row.graph_name,
};
@@ -1085,7 +1085,7 @@ export default {
const params = {
name: row.name,
detail: 'data',
shape,
shape: encodeURIComponent(shape),
graph_name: row.graph_name,
prev: this.gridType === 'preStep' ? true : false,
};


+ 3
- 2
tests/st/func/debugger/test_restful_api.py View File

@@ -19,6 +19,7 @@ Usage:
pytest tests/st/func/debugger/test_restful_api.py
"""
import os
from urllib.parse import quote

import pytest

@@ -204,7 +205,7 @@ class TestAscendDebugger:
body_data = {
'name': node_name + ':0',
'detail': 'data',
'shape': '[1, 1:3]'
'shape': quote('[1, 1:3]')
}
expect_file = 'retrieve_tensor_value.json'
send_and_compare_result(app_client, url, body_data, expect_file, method='get')
@@ -237,7 +238,7 @@ class TestAscendDebugger:
body_data = {
'name': node_name + ':0',
'detail': 'data',
'shape': '[:, :]',
'shape': quote('[:, :]'),
'tolerance': 1
}
expect_file = 'compare_tensors.json'


Loading…
Cancel
Save