# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Debugger restful api."""
import json

from flask import Blueprint, jsonify, request

from mindinsight.conf import settings
from mindinsight.debugger.debugger_server import DebuggerServer
from mindinsight.utils.exceptions import ParamValueError

BLUEPRINT = Blueprint("debugger", __name__,
                      url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX)


def _initialize_debugger_server():
    """
    Initialize a debugger server instance.

    Returns:
        Union[DebuggerServer, None], a new server when `settings.ENABLE_DEBUGGER`
        exists and is truthy, otherwise None.
    """
    enable_debugger = getattr(settings, 'ENABLE_DEBUGGER', False)
    server = None
    if enable_debugger:
        server = DebuggerServer()
    return server


def _read_post_request(post_request):
    """
    Extract the body of post request.

    Args:
        post_request (object): The post request.

    Returns:
        dict, the deserialized body of request. An empty body yields an empty dict.

    Raises:
        ParamValueError: If the body is not valid JSON, or is valid JSON but not a JSON object.
    """
    body = post_request.stream.read()
    try:
        # An empty body is treated as an empty JSON object.
        body = json.loads(body if body else "{}")
    except (ValueError, RecursionError) as err:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses;
        # RecursionError is raised by the decoder for pathologically nested input.
        raise ParamValueError("Json data parse failed.") from err
    if not isinstance(body, dict):
        # Handlers in this module call dict methods (get/pop) on the result, so a
        # JSON array or scalar would otherwise surface as an AttributeError.
        raise ParamValueError("Json data should be an object.")
    return body


def _read_int_arg(arg_name, default):
    """
    Read an integer query argument from the current request.

    Args:
        arg_name (str): Name of the query argument.
        default (int): Value used when the argument is absent.

    Returns:
        int, the parsed value.

    Raises:
        ParamValueError: If the argument cannot be converted to int.
    """
    value = request.args.get(arg_name, default)
    try:
        return int(value)
    except (TypeError, ValueError) as err:
        raise ParamValueError("{} should be an integer.".format(arg_name)) from err


def _wrap_reply(func, *args, **kwargs):
    """
    Call `func` with the given arguments and serialize its result.

    Returns:
        flask.Response, the JSON response built by `jsonify` from the result.
    """
    reply = func(*args, **kwargs)
    return jsonify(reply)


@BLUEPRINT.route("/debugger/poll-data", methods=["GET"])
def poll_data():
    """
    Wait for data to be updated on UI.

    Get data from server and display the change on UI.

    Returns:
        str, the updated data.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/poll-data?pos=xx
    """
    pos = request.args.get('pos')

    reply = _wrap_reply(BACKEND_SERVER.poll_data, pos)

    return reply


@BLUEPRINT.route("/debugger/search", methods=["GET"])
def search():
    """
    Search nodes in specified watchpoint.

    Returns:
        str, the required data.

    Raises:
        ParamValueError: If `watch_point_id` is not an integer.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/search?name=mock_name&watch_point_id=1
    """
    name = request.args.get('name')
    graph_name = request.args.get('graph_name')
    watch_point_id = _read_int_arg('watch_point_id', 0)
    node_category = request.args.get('node_category')
    reply = _wrap_reply(BACKEND_SERVER.search, {'name': name,
                                                'graph_name': graph_name,
                                                'watch_point_id': watch_point_id,
                                                'node_category': node_category})

    return reply


@BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"])
def retrieve_node_by_bfs():
    """
    Search node by bfs.

    Returns:
        str, the required data.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true
    """
    name = request.args.get('name')
    graph_name = request.args.get('graph_name')
    # Only the exact string 'true' enables ascending order; anything else is False.
    ascend = request.args.get('ascend', 'false') == 'true'
    reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, graph_name, ascend)

    return reply


@BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"])
def tensor_comparisons():
    """
    Get tensor comparisons.

    Returns:
        str, the required data.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/tensor-comparisons?name=tensor_name&shape=[1,1,:,:]&detail=data&tolerance=0
    """
    name = request.args.get('name')
    detail = request.args.get('detail', 'data')
    shape = request.args.get('shape')
    tolerance = request.args.get('tolerance', '0')
    # NOTE: argument order here is (name, shape, detail, tolerance), unlike the
    # (name, detail, shape, ...) order used by retrieve_tensor_value.
    reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance)

    return reply


@BLUEPRINT.route("/debugger/retrieve", methods=["POST"])
def retrieve():
    """
    Retrieve data according to mode and params.

    Returns:
        str, the required data.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/retrieve
    """
    body = _read_post_request(request)
    mode = body.get('mode')
    params = body.get('params')
    reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params)
    return reply


@BLUEPRINT.route("/debugger/tensor-history", methods=["POST"])
def retrieve_tensor_history():
    """
    Retrieve tensor history for the given node name and graph name.

    The JSON body is expected to carry `name` and `graph_name`.

    Returns:
        str, the required data.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/tensor-history
    """
    body = _read_post_request(request)
    name = body.get('name')
    graph_name = body.get('graph_name')
    reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name, graph_name)
    return reply


@BLUEPRINT.route("/debugger/tensors", methods=["GET"])
def retrieve_tensor_value():
    """
    Retrieve tensor value according to name and shape.

    Returns:
        str, the required data.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=tensor_name&detail=data&shape=[1,1,:,:]
    """
    name = request.args.get('name')
    detail = request.args.get('detail')
    shape = request.args.get('shape')
    graph_name = request.args.get('graph_name')
    # Only the exact string 'true' selects the previous value; anything else is False.
    prev = request.args.get('prev') == 'true'

    reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape, graph_name, prev)
    return reply


@BLUEPRINT.route("/debugger/create-watchpoint", methods=["POST"])
def create_watchpoint():
    """
    Create watchpoint.

    The request field `condition` is forwarded to the server as `watch_condition`.

    Returns:
        str, watchpoint id.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/create-watchpoint
    """
    params = _read_post_request(request)
    params['watch_condition'] = params.pop('condition', None)
    reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, params)
    return reply


@BLUEPRINT.route("/debugger/update-watchpoint", methods=["POST"])
def update_watchpoint():
    """
    Update watchpoint.

    Returns:
        str, reply message.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/update-watchpoint
    """
    params = _read_post_request(request)
    reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, params)
    return reply


@BLUEPRINT.route("/debugger/delete-watchpoint", methods=["POST"])
def delete_watchpoint():
    """
    Delete watchpoint.

    Returns:
        str, reply message.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/delete-watchpoint
    """
    body = _read_post_request(request)

    watch_point_id = body.get('watch_point_id')

    reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id)

    return reply


@BLUEPRINT.route("/debugger/control", methods=["POST"])
def control():
    """
    Control request.

    Returns:
        str, reply message.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/control
    """
    params = _read_post_request(request)
    reply = _wrap_reply(BACKEND_SERVER.control, params)

    return reply


@BLUEPRINT.route("/debugger/recheck", methods=["POST"])
def recheck():
    """
    Recheck request.

    Returns:
        str, reply message.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/recheck
    """
    reply = _wrap_reply(BACKEND_SERVER.recheck)

    return reply


@BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"])
def retrieve_tensor_graph():
    """
    Retrieve the tensor graph for the given tensor name and graph name.

    Returns:
        str, the required data.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name
    """
    tensor_name = request.args.get('tensor_name')
    graph_name = request.args.get('graph_name')
    reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_graph, tensor_name, graph_name)
    return reply


@BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"])
def retrieve_tensor_hits():
    """
    Retrieve the tensor hits for the given tensor name and graph name.

    Returns:
        str, the required data.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name
    """
    tensor_name = request.args.get('tensor_name')
    graph_name = request.args.get('graph_name')
    reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_hits, tensor_name, graph_name)
    return reply


# NOTE(review): when the debugger is disabled BACKEND_SERVER is None, yet
# init_module still registers BLUEPRINT, so the routes above would fail with an
# AttributeError on None. Confirm whether the blueprint should only be
# registered when the server exists.
BACKEND_SERVER = _initialize_debugger_server()


def init_module(app):
    """
    Init module entry.

    Registers the debugger blueprint and starts the backend server if one was created.

    Args:
        app (Flask): The application obj.
    """
    app.register_blueprint(BLUEPRINT)
    if BACKEND_SERVER:
        BACKEND_SERVER.start()