|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Debugger restful api."""
- import json
-
- from flask import Blueprint, jsonify, request
-
- from mindinsight.conf import settings
- from mindinsight.debugger.debugger_server import DebuggerServer
- from mindinsight.utils.exceptions import ParamValueError
-
# Every debugger route below is mounted under the configured URL path prefix
# followed by the API prefix.
BLUEPRINT = Blueprint(
    "debugger",
    __name__,
    url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX,
)
-
-
def _initialize_debugger_server():
    """
    Initialize a debugger server instance.

    A server is built only when `settings.DEBUGGER_PORT` and
    `settings.ENABLE_DEBUGGER` are both truthy. A missing setting is treated
    as None (port) or False (enable flag).

    Returns:
        Union[DebuggerServer, None], the new server, or None when the debugger
        is not enabled or no port is configured.
    """
    debugger_port = getattr(settings, 'DEBUGGER_PORT', None)
    debugger_enabled = getattr(settings, 'ENABLE_DEBUGGER', False)
    if not (debugger_port and debugger_enabled):
        return None
    return DebuggerServer(debugger_port)
-
-
- def _read_post_request(post_request):
- """
- Extract the body of post request.
-
- Args:
- post_request (object): The post request.
-
- Returns:
- dict, the deserialized body of request.
- """
- body = post_request.stream.read()
- try:
- body = json.loads(body if body else "{}")
- except Exception:
- raise ParamValueError("Json data parse failed.")
- return body
-
-
def _wrap_reply(func, *args, **kwargs):
    """
    Invoke `func` with the given arguments and JSON-serialize its result.

    Args:
        func (Callable): The backend method to call.
        *args: Positional arguments forwarded to `func`.
        **kwargs: Keyword arguments forwarded to `func`.

    Returns:
        Response, the result of `func` passed through `jsonify`.
    """
    return jsonify(func(*args, **kwargs))
-
-
@BLUEPRINT.route("/debugger/poll_data", methods=["GET"])
def poll_data():
    """
    Wait for data to be updated on UI.

    Forwards the `pos` query argument (None when absent) to the backend
    server's `poll_data` and returns its reply.

    Returns:
        Response, the JSON-serialized updated data.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/poll_data?pos=xx
    """
    position = request.args.get('pos')
    return _wrap_reply(BACKEND_SERVER.poll_data, position)
-
-
@BLUEPRINT.route("/debugger/search", methods=["GET"])
def search():
    """
    Search nodes matching `name` in the specified watchpoint.

    Query Args:
        name (str): The search text, forwarded to the backend as-is (None when absent).
        watch_point_id (int): The watchpoint id; defaults to 0 when omitted.

    Returns:
        Response, the JSON-serialized search result from the backend server.

    Raises:
        ParamValueError: If `watch_point_id` is present but not an integer.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/search?name=node_name&watch_point_id=1
    """
    name = request.args.get('name')
    raw_watch_point_id = request.args.get('watch_point_id', 0)
    try:
        watch_point_id = int(raw_watch_point_id)
    except ValueError as err:
        # Report malformed ids (e.g. "abc" or "") as a parameter error rather
        # than letting a bare ValueError escape as a server error.
        raise ParamValueError("watch_point_id should be an integer.") from err
    reply = _wrap_reply(BACKEND_SERVER.search, name, watch_point_id)

    return reply
-
-
@BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"])
def retrieve_node_by_bfs():
    """
    Run the backend's breadth-first node retrieval for `name`.

    Query Args:
        name (str): The node name, forwarded as-is (None when absent).
        ascend (str): Only the exact string 'true' yields True. Any other
            value, or omitting it, yields False. What "ascend" controls is
            defined by the backend (presumably traversal direction; confirm).

    Returns:
        Response, the JSON-serialized data from the backend server.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true
    """
    query = request.args
    node_name = query.get('name')
    is_ascend = query.get('ascend', 'false') == 'true'
    return _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, node_name, is_ascend)
-
-
@BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"])
def tensor_comparisons():
    """
    Get tensor comparisons.

    All query arguments are forwarded to the backend as strings:
    `name` and `shape` (None when absent), `detail` (default 'data') and
    `tolerance` (default '0').

    Returns:
        Response, the JSON-serialized comparison data from the backend server.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/tensor-comparisons?name=node_name&detail=data&shape=[0, 0, :, :]&tolerance=0.5
    """
    query = request.args
    node_name = query.get('name')
    detail_level = query.get('detail', 'data')
    tensor_shape = query.get('shape')
    tolerance = query.get('tolerance', '0')
    # The backend takes (name, shape, detail, tolerance): shape comes before
    # detail here, unlike retrieve_tensor_value's (name, detail, shape).
    return _wrap_reply(BACKEND_SERVER.tensor_comparisons,
                       node_name, tensor_shape, detail_level, tolerance)
-
-
@BLUEPRINT.route("/debugger/retrieve", methods=["POST"])
def retrieve():
    """
    Retrieve data according to the `mode` and `params` fields of the JSON body.

    Both fields are read with `dict.get` (None when absent) and forwarded to
    the backend server's `retrieve`.

    Returns:
        Response, the JSON-serialized data from the backend server.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/retrieve
    """
    payload = _read_post_request(request)
    mode, params = payload.get('mode'), payload.get('params')
    return _wrap_reply(BACKEND_SERVER.retrieve, mode, params)
-
-
@BLUEPRINT.route("/debugger/retrieve_tensor_history", methods=["POST"])
def retrieve_tensor_history():
    """
    Retrieve tensor history for the `name` field of the JSON body.

    The `name` field (None when absent) is forwarded to the backend server's
    `retrieve_tensor_history`.

    Returns:
        Response, the JSON-serialized data from the backend server.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/retrieve_tensor_history
    """
    name = _read_post_request(request).get('name')
    return _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name)
-
-
@BLUEPRINT.route("/debugger/tensors", methods=["GET"])
def retrieve_tensor_value():
    """
    Retrieve tensor value according to name and shape.

    The `name`, `detail` and `shape` query arguments (each None when absent)
    are forwarded, in that order, to the backend server's
    `retrieve_tensor_value`.

    Returns:
        Response, the JSON-serialized tensor data from the backend server.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=node_name&detail=data&shape=[1,1,:,:]
    """
    query = request.args
    name, detail, shape = (query.get(key) for key in ('name', 'detail', 'shape'))
    return _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape)
-
-
@BLUEPRINT.route("/debugger/create_watchpoint", methods=["POST"])
def create_watchpoint():
    """
    Create a watchpoint from the JSON request body.

    The body fields `condition`, `watch_nodes` and `watch_point_id` (each None
    when absent) are forwarded positionally to the backend server's
    `create_watchpoint`.

    Returns:
        Response, the JSON-serialized backend reply (documented as the watchpoint id).

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.
        MindInsightException: If the backend call fails; backend exceptions
            propagate unchanged.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/create_watchpoint
    """
    body = _read_post_request(request)
    field_names = ('condition', 'watch_nodes', 'watch_point_id')
    condition, watch_nodes, watch_point_id = (body.get(field) for field in field_names)
    return _wrap_reply(BACKEND_SERVER.create_watchpoint, condition, watch_nodes, watch_point_id)
-
-
@BLUEPRINT.route("/debugger/update_watchpoint", methods=["POST"])
def update_watchpoint():
    """
    Update a watchpoint from the JSON request body.

    The body fields `watch_point_id`, `watch_nodes`, `mode` and `name` (each
    None when absent) are forwarded positionally, in that order, to the
    backend server's `update_watchpoint`.

    Returns:
        Response, the JSON-serialized reply message from the backend server.

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.
        MindInsightException: If the backend call fails; backend exceptions
            propagate unchanged.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/update_watchpoint
    """
    body = _read_post_request(request)
    ordered_fields = ('watch_point_id', 'watch_nodes', 'mode', 'name')
    field_values = [body.get(field) for field in ordered_fields]
    return _wrap_reply(BACKEND_SERVER.update_watchpoint, *field_values)
-
-
@BLUEPRINT.route("/debugger/delete_watchpoint", methods=["POST"])
def delete_watchpoint():
    """
    Delete the watchpoint given by the `watch_point_id` field of the JSON body.

    Returns:
        Response, the JSON-serialized reply message from the backend server.

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.
        MindInsightException: If the backend call fails; backend exceptions
            propagate unchanged.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/delete_watchpoint
    """
    watch_point_id = _read_post_request(request).get('watch_point_id')
    return _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id)
-
-
@BLUEPRINT.route("/debugger/control", methods=["POST"])
def control():
    """
    Forward a control request to the backend server.

    The whole deserialized JSON body is passed as the single argument to the
    backend server's `control`.

    Returns:
        Response, the JSON-serialized reply message from the backend server.

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.
        MindInsightException: If the backend call fails; backend exceptions
            propagate unchanged.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/control
    """
    control_params = _read_post_request(request)
    return _wrap_reply(BACKEND_SERVER.control, control_params)
-
-
# Built once at import time. It is None unless both `settings.DEBUGGER_PORT`
# and `settings.ENABLE_DEBUGGER` are truthy.
# NOTE(review): the route handlers in this module use it without a None check.
BACKEND_SERVER = _initialize_debugger_server()


def init_module(app):
    """
    Register the debugger blueprint and start the debugger server, if any.

    The blueprint is always registered. The server is started only when
    BACKEND_SERVER was created.

    Args:
        app (Flask): The application obj.
    """
    app.register_blueprint(BLUEPRINT)
    if not BACKEND_SERVER:
        return
    BACKEND_SERVER.start()
|