|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426 |
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Debugger restful api."""
- import json
- from urllib.parse import unquote
-
- from flask import Blueprint, jsonify, request
-
- from mindinsight.conf import settings
- from mindinsight.debugger.session_manager import SessionManager
- from mindinsight.utils.exceptions import ParamMissError, ParamValueError
-
# Blueprint that groups every debugger REST endpoint below the configured URL/API prefix.
BLUEPRINT = Blueprint(
    "debugger",
    __name__,
    url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX,
)
-
-
- def _unquote_param(param):
- """
- Decode parameter value.
-
- Args:
- param (str): Encoded param value.
-
- Returns:
- str, decoded param value.
- """
- if isinstance(param, str):
- try:
- param = unquote(param, errors='strict')
- except UnicodeDecodeError:
- raise ParamValueError('Unquote error with strict mode.')
- return param
-
-
- def _read_post_request(post_request):
- """
- Extract the body of post request.
-
- Args:
- post_request (object): The post request.
-
- Returns:
- dict, the deserialized body of request.
- """
- body = post_request.stream.read()
- try:
- body = json.loads(body if body else "{}")
- except Exception:
- raise ParamValueError("Json data parse failed.")
- return body
-
-
def _wrap_reply(func, *args, **kwargs):
    """Invoke ``func(*args, **kwargs)`` and return its result as a Flask JSON response."""
    return jsonify(func(*args, **kwargs))
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/poll-data", methods=["GET"])
def poll_data(session_id):
    """
    Wait for data to be updated on UI.

    The ``pos`` query argument (``None`` when absent) is handed unchanged to the
    session's ``poll_data`` method.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``poll_data`` result serialized as JSON.

    Examples:
        >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/poll-data?pos=xx
    """
    position = request.args.get('pos')
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.poll_data, position)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/search", methods=["GET"])
def search(session_id):
    """
    Search nodes in specified watchpoint.

    Query args ``name``, ``graph_name`` and ``node_category`` are passed through as
    strings (or ``None``); ``watch_point_id`` and ``rank_id`` are integers defaulting to 0.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``search`` result serialized as JSON.

    Raises:
        ParamValueError: If ``watch_point_id`` or ``rank_id`` is not an integer.

    Examples:
        >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search?name=mock_name&watch_point_id=1
    """
    name = request.args.get('name')
    graph_name = request.args.get('graph_name')
    # A bare int() would let a malformed query value escape as ValueError (HTTP 500).
    try:
        watch_point_id = int(request.args.get('watch_point_id', 0))
    except ValueError as err:
        raise ParamValueError("watch_point_id should be an integer.") from err
    node_category = request.args.get('node_category')
    try:
        rank_id = int(request.args.get('rank_id', 0))
    except ValueError as err:
        raise ParamValueError("rank_id should be an integer.") from err
    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search,
                        {'name': name,
                         'graph_name': graph_name,
                         'watch_point_id': watch_point_id,
                         'node_category': node_category,
                         'rank_id': rank_id})
    return reply
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-comparisons", methods=["GET"])
def tensor_comparisons(session_id):
    """
    Get tensor comparisons.

    Query args: ``name``; ``detail`` (default ``'data'``); ``shape`` (decoded with
    ``_unquote_param``); ``tolerance`` (passed through as a string, default ``'0'``);
    ``rank_id`` (integer, default 0).

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``tensor_comparisons`` result serialized as JSON.

    Raises:
        ParamValueError: If ``rank_id`` is not an integer, or ``shape`` cannot be decoded.

    Examples:
        >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-comparisons
    """
    name = request.args.get('name')
    detail = request.args.get('detail', 'data')
    shape = _unquote_param(request.args.get('shape'))
    tolerance = request.args.get('tolerance', '0')
    # A bare int() would let a malformed query value escape as ValueError (HTTP 500).
    try:
        rank_id = int(request.args.get('rank_id', 0))
    except ValueError as err:
        raise ParamValueError("rank_id should be an integer.") from err
    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).tensor_comparisons, name, shape, detail,
                        tolerance, rank_id)
    return reply
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/retrieve", methods=["POST"])
def retrieve(session_id):
    """
    Retrieve data according to mode and params.

    The JSON body's ``mode`` and ``params`` fields (``None`` when absent) are passed to
    the session's ``retrieve`` method.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``retrieve`` result serialized as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/retrieve
    """
    payload = _read_post_request(request)
    mode, params = payload.get('mode'), payload.get('params')
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.retrieve, mode, params)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-history", methods=["POST"])
def retrieve_tensor_history(session_id):
    """
    Retrieve the tensor history of a node.

    The JSON body provides ``name`` and ``graph_name`` (``None`` when absent) and
    ``rank_id``, which defaults to 0 like the rank_id of the GET endpoints in this module.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``retrieve_tensor_history`` result serialized as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-history
    """
    body = _read_post_request(request)
    name = body.get('name')
    graph_name = body.get('graph_name')
    # Default to rank 0 rather than None, consistent with the other endpoints.
    rank_id = body.get('rank_id', 0)
    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_history, name, graph_name,
                        rank_id)
    return reply
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/tensors", methods=["GET"])
def retrieve_tensor_value(session_id):
    """
    Retrieve tensor value according to name and shape.

    Query args: ``name``, ``detail``, ``graph_name`` (passed through); ``shape`` (decoded
    with ``_unquote_param``); ``prev`` (True only for the literal ``'true'``); ``rank_id``
    (integer, default 0).

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``retrieve_tensor_value`` result serialized as JSON.

    Raises:
        ParamValueError: If ``rank_id`` is not an integer, or ``shape`` cannot be decoded.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensors?name=tensor_name&detail=data&shape=[1,1,:,:]
    """
    name = request.args.get('name')
    detail = request.args.get('detail')
    shape = _unquote_param(request.args.get('shape'))
    graph_name = request.args.get('graph_name')
    prev = request.args.get('prev') == 'true'
    # A bare int() would let a malformed query value escape as ValueError (HTTP 500).
    try:
        rank_id = int(request.args.get('rank_id', 0))
    except ValueError as err:
        raise ParamValueError("rank_id should be an integer.") from err

    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_value, name, detail,
                        shape, graph_name, prev, rank_id)
    return reply
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/create-watchpoint", methods=["POST"])
def create_watchpoint(session_id):
    """
    Create watchpoint.

    The JSON body is forwarded to the session, with its ``condition`` key renamed to
    ``watch_condition`` (set to ``None`` when ``condition`` is absent).

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``create_watchpoint`` result serialized as JSON.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/create-watchpoint
    """
    watchpoint_params = _read_post_request(request)
    # The REST payload names the field ``condition``; the session expects ``watch_condition``.
    condition = watchpoint_params.pop('condition', None)
    watchpoint_params['watch_condition'] = condition
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.create_watchpoint, watchpoint_params)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/update-watchpoint", methods=["POST"])
def update_watchpoint(session_id):
    """
    Update watchpoint.

    The whole JSON body is forwarded unchanged to the session's ``update_watchpoint``.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's reply serialized as JSON.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/update-watchpoint
    """
    update_params = _read_post_request(request)
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.update_watchpoint, update_params)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/delete-watchpoint", methods=["POST"])
def delete_watchpoint(session_id):
    """
    Delete watchpoint.

    The JSON body's ``watch_point_id`` (``None`` when absent) is passed to the session.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's reply serialized as JSON.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete-watchpoint
    """
    watch_point_id = _read_post_request(request).get('watch_point_id')
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.delete_watchpoint, watch_point_id)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/control", methods=["POST"])
def control(session_id):
    """
    Control request.

    The whole JSON body is forwarded unchanged to the session's ``control`` method.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's reply serialized as JSON.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/control
    """
    control_params = _read_post_request(request)
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.control, control_params)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/recheck", methods=["POST"])
def recheck(session_id):
    """
    Recheck request.

    Takes no body; simply triggers the session's ``recheck`` method.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's reply serialized as JSON.

    Raises:
        MindInsightException: If method fails to be called.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/recheck
    """
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.recheck)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-graphs", methods=["GET"])
def retrieve_tensor_graph(session_id):
    """
    Retrieve the graph related to a tensor.

    Query args: ``tensor_name`` and ``graph_name`` (passed through); ``rank_id``
    (integer, default 0).

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``retrieve_tensor_graph`` result serialized as JSON.

    Raises:
        ParamValueError: If ``rank_id`` is not an integer.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-graphs?tensor_name=xxx&graph_name=xxx
    """
    tensor_name = request.args.get('tensor_name')
    graph_name = request.args.get('graph_name')
    # A bare int() would let a malformed query value escape as ValueError (HTTP 500).
    try:
        rank_id = int(request.args.get('rank_id', 0))
    except ValueError as err:
        raise ParamValueError("rank_id should be an integer.") from err
    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_graph, tensor_name,
                        graph_name, rank_id)
    return reply
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-hits", methods=["GET"])
def retrieve_tensor_hits(session_id):
    """
    Retrieve the watchpoint hits related to a tensor.

    Query args: ``tensor_name`` and ``graph_name`` (passed through); ``rank_id``
    (integer, default 0).

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``retrieve_tensor_hits`` result serialized as JSON.

    Raises:
        ParamValueError: If ``rank_id`` is not an integer.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-hits?tensor_name=xxx&graph_name=xxx
    """
    tensor_name = request.args.get('tensor_name')
    graph_name = request.args.get('graph_name')
    # A bare int() would let a malformed query value escape as ValueError (HTTP 500).
    try:
        rank_id = int(request.args.get('rank_id', 0))
    except ValueError as err:
        raise ParamValueError("rank_id should be an integer.") from err
    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_hits, tensor_name,
                        graph_name, rank_id)
    return reply
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/search-watchpoint-hits", methods=["POST"])
def search_watchpoint_hits(session_id):
    """
    Search watchpoint hits by group condition.

    The JSON body's ``group_condition`` (``None`` when absent) is passed to the session.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``search_watchpoint_hits`` result serialized as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search-watchpoint-hits
    """
    group_condition = _read_post_request(request).get('group_condition')
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.search_watchpoint_hits, group_condition)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/condition-collections", methods=["GET"])
def get_condition_collections(session_id):
    """
    Get condition collections.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's ``get_condition_collections`` result serialized as JSON.
    """
    session = SessionManager.get_instance().get_session(session_id)
    return _wrap_reply(session.get_condition_collections)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/set-recommended-watch-points", methods=["POST"])
def set_recommended_watch_points(session_id):
    """
    Set recommended watch points.

    Expects a JSON body of the form ``{"requestBody": {"set_recommended": ...}}``; the
    ``set_recommended`` value (``None`` when absent) is passed to the session.

    Args:
        session_id (str): Identifier of the debugger session, taken from the URL.

    Returns:
        Response, the session's reply serialized as JSON.

    Raises:
        ParamMissError: If ``requestBody`` is missing.
        ParamValueError: If ``requestBody`` is not a JSON object.
    """
    body = _read_post_request(request)
    request_body = body.get('requestBody')
    if request_body is None:
        raise ParamMissError('requestBody')
    if not isinstance(request_body, dict):
        # Guard the ``.get`` below: a non-object value would otherwise raise AttributeError (HTTP 500).
        raise ParamValueError('requestBody should be an object.')

    set_recommended = request_body.get('set_recommended')
    reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).set_recommended_watch_points,
                        set_recommended)
    return reply
-
-
@BLUEPRINT.route("/debugger/sessions", methods=["POST"])
def create_session():
    """
    Get session id if session exist, else create a session.

    Reads ``session_type`` and ``dump_dir`` from the JSON body (``None`` when absent)
    and forwards them, in that order, to ``SessionManager.create_session``.

    Returns:
        Response, the manager's reply (the session id) serialized as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions
    """
    body = _read_post_request(request)
    dump_dir = body.get('dump_dir')
    session_type = body.get('session_type')
    manager = SessionManager.get_instance()
    return _wrap_reply(manager.create_session, session_type, dump_dir)
-
-
@BLUEPRINT.route("/debugger/sessions", methods=["GET"])
def get_train_jobs():
    """
    Check the current active sessions.

    Returns:
        Response, the result of ``SessionManager.get_train_jobs`` serialized as JSON.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/sessions
    """
    manager = SessionManager.get_instance()
    return _wrap_reply(manager.get_train_jobs)
-
-
@BLUEPRINT.route("/debugger/sessions/<session_id>/delete", methods=["POST"])
def delete_session(session_id):
    """
    Delete session by session id.

    Args:
        session_id (str): Identifier of the debugger session to delete, taken from the URL.

    Returns:
        Response, the result of ``SessionManager.delete_session`` serialized as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete
    """
    manager = SessionManager.get_instance()
    return _wrap_reply(manager.delete_session, session_id)
-
-
def init_module(app):
    """
    Attach the debugger blueprint (and therefore all its routes) to the application.

    Args:
        app (Flask): The application obj.
    """
    blueprint = BLUEPRINT
    app.register_blueprint(blueprint)
|