You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger_api.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Debugger restful api."""
  16. import json
  17. from urllib.parse import unquote
  18. from flask import Blueprint, jsonify, request
  19. from mindinsight.conf import settings
  20. from mindinsight.debugger.session_manager import SessionManager
  21. from mindinsight.utils.exceptions import ParamMissError, ParamValueError
  # Blueprint grouping every debugger REST endpoint in this module; routes are
  # mounted under settings.URL_PATH_PREFIX + settings.API_PREFIX.
  BLUEPRINT = Blueprint(
      "debugger",
      __name__,
      url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX,
  )
  def _unquote_param(param):
      """
      Percent-decode a query parameter value.

      Values that are not ``str`` (for example ``None`` when the query
      parameter is absent) are returned unchanged.

      Args:
          param (str): Percent-encoded parameter value.

      Returns:
          str, the decoded value, or ``param`` itself when it is not a str.

      Raises:
          ParamValueError: If the percent-escapes do not decode as UTF-8.
      """
      if not isinstance(param, str):
          return param
      try:
          return unquote(param, errors='strict')
      except UnicodeDecodeError as err:
          # Chain the decode error so the offending byte position survives in tracebacks.
          raise ParamValueError('Unquote error with strict mode.') from err
  def _read_post_request(post_request):
      """
      Read and deserialize the JSON body of a POST request.

      An empty body is treated as an empty JSON object (``{}``).

      Args:
          post_request (object): The incoming request; its ``stream`` is read once.

      Returns:
          dict, the deserialized body of the request.

      Raises:
          ParamValueError: If the body is not valid JSON, or is valid JSON but
              not a JSON object.
      """
      raw_body = post_request.stream.read()
      try:
          body = json.loads(raw_body if raw_body else "{}")
      except (ValueError, TypeError, RecursionError) as err:
          # ValueError covers JSONDecodeError and undecodable bytes (UnicodeDecodeError);
          # TypeError covers a non str/bytes payload; RecursionError covers absurdly
          # deep nesting. Anything else is a genuine server error and propagates.
          raise ParamValueError("Json data parse failed.") from err
      if not isinstance(body, dict):
          # Callers in this module use .get()/.pop() on the result, so a list or
          # scalar body would otherwise fail later with an AttributeError/TypeError.
          raise ParamValueError("Json data should be an object.")
      return body
  def _wrap_reply(func, *args, **kwargs):
      """
      Invoke ``func`` and serialize whatever it returns as a JSON response.

      Args:
          func (Callable): Handler producing the reply payload.
          *args: Positional arguments forwarded to ``func``.
          **kwargs: Keyword arguments forwarded to ``func``.

      Returns:
          Response, the result of ``jsonify`` applied to the payload.
      """
      return jsonify(func(*args, **kwargs))
  @BLUEPRINT.route("/debugger/sessions/<session_id>/poll-data", methods=["GET"])
  def poll_data(session_id):
      """
      Let the UI wait for data updates and then fetch the changed data.

      The ``pos`` query parameter is forwarded unchanged to the session's
      ``poll_data``.

      Args:
          session_id (str): Id of the debugger session to poll.

      Returns:
          str, the updated data.

      Examples:
          >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/poll-data?pos=xx
      """
      position = request.args.get('pos')
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.poll_data, position)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/search", methods=["GET"])
  def search(session_id):
      """
      Search nodes in the specified watchpoint.

      Query parameters ``name``, ``graph_name``, ``node_category``,
      ``watch_point_id`` (int, default 0) and ``rank_id`` (int, default 0) are
      collected into a filter-condition dict for the session's ``search``.

      Args:
          session_id (str): Id of the debugger session to search in.

      Returns:
          str, the required data.

      Examples:
          >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search?name=mock_name&watch_point_id=1
      """
      name = request.args.get('name')
      graph_name = request.args.get('graph_name')
      watch_point_id = int(request.args.get('watch_point_id', 0))
      node_category = request.args.get('node_category')
      rank_id = int(request.args.get('rank_id', 0))
      filter_condition = {
          'name': name,
          'graph_name': graph_name,
          'watch_point_id': watch_point_id,
          'node_category': node_category,
          # Previously misspelled 'rand_id', so the parsed rank id was not passed
          # under the name used everywhere else in this module.
          # NOTE(review): confirm the session's search() reads 'rank_id'.
          'rank_id': rank_id,
      }
      reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).search, filter_condition)
      return reply
  @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-comparisons", methods=["GET"])
  def tensor_comparisons(session_id):
      """
      Get tensor comparisons.

      Reads ``name``, ``detail`` (default ``'data'``), percent-encoded ``shape``,
      ``tolerance`` (default ``'0'``) and ``rank_id`` (int, default 0) from the
      query string.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, the required data.

      Examples:
          >>> Get http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-comparisons
      """
      query = request.args
      tensor_name = query.get('name')
      detail = query.get('detail', 'data')
      shape = _unquote_param(query.get('shape'))
      tolerance = query.get('tolerance', '0')
      rank_id = int(query.get('rank_id', 0))
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.tensor_comparisons, tensor_name, shape, detail, tolerance, rank_id)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/retrieve", methods=["POST"])
  def retrieve(session_id):
      """
      Retrieve data according to mode and params.

      The JSON body supplies ``mode`` and ``params``, which are forwarded to
      the session's ``retrieve``.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, the required data.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/retrieve
      """
      payload = _read_post_request(request)
      mode, params = payload.get('mode'), payload.get('params')
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.retrieve, mode, params)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-history", methods=["POST"])
  def retrieve_tensor_history(session_id):
      """
      Retrieve the tensor history of a node.

      The JSON body supplies ``name``, ``graph_name`` and ``rank_id``. A missing
      ``rank_id`` defaults to 0, matching the other endpoints in this module.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, the required data.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-history
      """
      body = _read_post_request(request)
      name = body.get('name')
      graph_name = body.get('graph_name')
      # Default to rank 0 like every other endpoint here, instead of passing None through.
      rank_id = body.get('rank_id', 0)
      reply = _wrap_reply(SessionManager.get_instance().get_session(session_id).retrieve_tensor_history, name, graph_name,
                          rank_id)
      return reply
  @BLUEPRINT.route("/debugger/sessions/<session_id>/tensors", methods=["GET"])
  def retrieve_tensor_value(session_id):
      """
      Retrieve a tensor value by name and shape.

      The ``prev`` query parameter is treated as true only when it equals the
      literal string ``'true'``. ``rank_id`` defaults to 0.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, the required data.

      Examples:
          >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensors?name=tensor_name&detail=data&shape=[1,1,:,:]
      """
      query = request.args
      tensor_name = query.get('name')
      detail = query.get('detail')
      shape = _unquote_param(query.get('shape'))
      graph_name = query.get('graph_name')
      use_prev = query.get('prev') == 'true'
      rank_id = int(query.get('rank_id', 0))
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.retrieve_tensor_value, tensor_name, detail, shape, graph_name, use_prev, rank_id)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/create-watchpoint", methods=["POST"])
  def create_watchpoint(session_id):
      """
      Create a watchpoint from the JSON request body.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, watchpoint id.

      Raises:
          MindInsightException: If method fails to be called.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/create-watchpoint
      """
      params = _read_post_request(request)
      # The body carries the condition under 'condition'; rename it to
      # 'watch_condition' (None if absent) before forwarding.
      condition = params.pop('condition', None)
      params['watch_condition'] = condition
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.create_watchpoint, params)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/update-watchpoint", methods=["POST"])
  def update_watchpoint(session_id):
      """
      Update a watchpoint; the whole JSON body is forwarded as params.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, reply message.

      Raises:
          MindInsightException: If method fails to be called.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/update-watchpoint
      """
      payload = _read_post_request(request)
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.update_watchpoint, payload)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/delete-watchpoint", methods=["POST"])
  def delete_watchpoint(session_id):
      """
      Delete the watchpoint whose ``watch_point_id`` is given in the JSON body.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, reply message.

      Raises:
          MindInsightException: If method fails to be called.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete-watchpoint
      """
      target_id = _read_post_request(request).get('watch_point_id')
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.delete_watchpoint, target_id)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/control", methods=["POST"])
  def control(session_id):
      """
      Forward a control request (the whole JSON body) to the session.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, reply message.

      Raises:
          MindInsightException: If method fails to be called.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/control
      """
      payload = _read_post_request(request)
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.control, payload)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/recheck", methods=["POST"])
  def recheck(session_id):
      """
      Trigger a recheck on the session; the request body is not read.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, reply message.

      Raises:
          MindInsightException: If method fails to be called.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/recheck
      """
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.recheck)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-graphs", methods=["GET"])
  def retrieve_tensor_graph(session_id):
      """
      Retrieve the graph associated with a tensor.

      Reads ``tensor_name``, ``graph_name`` and ``rank_id`` (int, default 0)
      from the query string and forwards them to the session's
      ``retrieve_tensor_graph``.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, the required data.

      Examples:
          >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-graphs?tensor_name=xxx&graph_name=xxx
      """
      query = request.args
      tensor_name = query.get('tensor_name')
      graph_name = query.get('graph_name')
      rank_id = int(query.get('rank_id', 0))
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.retrieve_tensor_graph, tensor_name, graph_name, rank_id)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/tensor-hits", methods=["GET"])
  def retrieve_tensor_hits(session_id):
      """
      Retrieve the watchpoint hits of a tensor.

      Reads ``tensor_name``, ``graph_name`` and ``rank_id`` (int, default 0)
      from the query string and forwards them to the session's
      ``retrieve_tensor_hits``.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, the required data.

      Examples:
          >>> GET http://xxxx/v1/mindinsight/debugger/sessions/xxxx/tensor-hits?tensor_name=xxx&graph_name=xxx
      """
      query = request.args
      tensor_name = query.get('tensor_name')
      graph_name = query.get('graph_name')
      rank_id = int(query.get('rank_id', 0))
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.retrieve_tensor_hits, tensor_name, graph_name, rank_id)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/search-watchpoint-hits", methods=["POST"])
  def search_watchpoint_hits(session_id):
      """
      Search watchpoint hits by the ``group_condition`` from the JSON body.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          str, the required data.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/search-watchpoint-hits
      """
      group_condition = _read_post_request(request).get('group_condition')
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.search_watchpoint_hits, group_condition)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/condition-collections", methods=["GET"])
  def get_condition_collections(session_id):
      """
      Get the watch condition collections of a session.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          Response, JSON-serialized result of the session's ``get_condition_collections``.
      """
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.get_condition_collections)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/set-recommended-watch-points", methods=["POST"])
  def set_recommended_watch_points(session_id):
      """
      Set recommended watch points.

      The JSON body must contain a ``requestBody`` object; its
      ``set_recommended`` value is forwarded to the session.

      Args:
          session_id (str): Id of the debugger session.

      Returns:
          Response, JSON-serialized reply of the session.

      Raises:
          ParamMissError: If ``requestBody`` is absent from the JSON body.
      """
      inner_body = _read_post_request(request).get('requestBody')
      if inner_body is None:
          raise ParamMissError('requestBody')
      flag = inner_body.get('set_recommended')
      session = SessionManager.get_instance().get_session(session_id)
      return _wrap_reply(session.set_recommended_watch_points, flag)
  @BLUEPRINT.route("/debugger/sessions", methods=["POST"])
  def create_session():
      """
      Return the session id if the session exists, otherwise create a new session.

      The JSON body supplies ``session_type`` and ``dump_dir``. Both are passed
      to the session manager's ``create_session`` in that order.

      Returns:
          str, session id.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions
      """
      payload = _read_post_request(request)
      dump_dir = payload.get('dump_dir')
      session_type = payload.get('session_type')
      manager = SessionManager.get_instance()
      return _wrap_reply(manager.create_session, session_type, dump_dir)
  @BLUEPRINT.route("/debugger/sessions", methods=["GET"])
  def get_train_jobs():
      """
      Check the current active sessions via the session manager's ``get_train_jobs``.

      Returns:
          Response, JSON-serialized result of ``get_train_jobs``.

      Examples:
          >>> GET http://xxxx/v1/mindinsight/debugger/sessions
      """
      manager = SessionManager.get_instance()
      return _wrap_reply(manager.get_train_jobs)
  @BLUEPRINT.route("/debugger/sessions/<session_id>/delete", methods=["POST"])
  def delete_session(session_id):
      """
      Delete the session identified by ``session_id``.

      Args:
          session_id (str): Id of the debugger session to delete.

      Returns:
          Response, JSON-serialized reply of the session manager.

      Examples:
          >>> POST http://xxxx/v1/mindinsight/debugger/sessions/xxxx/delete
      """
      manager = SessionManager.get_instance()
      return _wrap_reply(manager.delete_session, session_id)
  def init_module(app):
      """
      Module entry point: attach the debugger blueprint to the application.

      Args:
          app (Flask): The application object the debugger routes are registered on.
      """
      app.register_blueprint(BLUEPRINT)