You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

debugger_api.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Debugger restful api."""
  16. import json
  17. from urllib.parse import unquote
  18. from flask import Blueprint, jsonify, request
  19. from mindinsight.conf import settings
  20. from mindinsight.debugger.debugger_server import DebuggerServer
  21. from mindinsight.utils.exceptions import ParamValueError
# Blueprint that groups the debugger REST routes declared in this module;
# all of them are served under the configured URL path prefix + API prefix.
BLUEPRINT = Blueprint("debugger", __name__,
                      url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX)
  24. def _initialize_debugger_server():
  25. """Initialize a debugger server instance."""
  26. enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
  27. server = None
  28. if enable_debugger:
  29. server = DebuggerServer()
  30. return server
  31. def _unquote_param(param):
  32. """
  33. Decode parameter value.
  34. Args:
  35. param (str): Encoded param value.
  36. Returns:
  37. str, decoded param value.
  38. """
  39. if isinstance(param, str):
  40. try:
  41. param = unquote(param, errors='strict')
  42. except UnicodeDecodeError:
  43. raise ParamValueError('Unquote error with strict mode.')
  44. return param
  45. def _read_post_request(post_request):
  46. """
  47. Extract the body of post request.
  48. Args:
  49. post_request (object): The post request.
  50. Returns:
  51. dict, the deserialized body of request.
  52. """
  53. body = post_request.stream.read()
  54. try:
  55. body = json.loads(body if body else "{}")
  56. except Exception:
  57. raise ParamValueError("Json data parse failed.")
  58. return body
  59. def _wrap_reply(func, *args, **kwargs):
  60. """Serialize reply."""
  61. reply = func(*args, **kwargs)
  62. return jsonify(reply)
  63. @BLUEPRINT.route("/debugger/poll-data", methods=["GET"])
  64. def poll_data():
  65. """
  66. Wait for data to be updated on UI.
  67. Get data from server and display the change on UI.
  68. Returns:
  69. str, the updated data.
  70. Examples:
  71. >>> Get http://xxxx/v1/mindinsight/debugger/poll-data?pos=xx
  72. """
  73. pos = request.args.get('pos')
  74. reply = _wrap_reply(BACKEND_SERVER.poll_data, pos)
  75. return reply
  76. @BLUEPRINT.route("/debugger/search", methods=["GET"])
  77. def search():
  78. """
  79. Search nodes in specified watchpoint.
  80. Returns:
  81. str, the required data.
  82. Examples:
  83. >>> Get http://xxxx/v1/mindinsight/debugger/search?name=mock_name&watch_point_id=1
  84. """
  85. name = request.args.get('name')
  86. graph_name = request.args.get('graph_name')
  87. watch_point_id = int(request.args.get('watch_point_id', 0))
  88. node_category = request.args.get('node_category')
  89. reply = _wrap_reply(BACKEND_SERVER.search, {'name': name,
  90. 'graph_name': graph_name,
  91. 'watch_point_id': watch_point_id,
  92. 'node_category': node_category})
  93. return reply
  94. @BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"])
  95. def retrieve_node_by_bfs():
  96. """
  97. Search node by bfs.
  98. Returns:
  99. str, the required data.
  100. Examples:
  101. >>> Get http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true
  102. """
  103. name = request.args.get('name')
  104. graph_name = request.args.get('graph_name')
  105. ascend = request.args.get('ascend', 'false')
  106. ascend = ascend == 'true'
  107. reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, graph_name, ascend)
  108. return reply
  109. @BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"])
  110. def tensor_comparisons():
  111. """
  112. Get tensor comparisons.
  113. Returns:
  114. str, the required data.
  115. Examples:
  116. >>> Get http://xxxx/v1/mindinsight/debugger/tensor-comparisons
  117. """
  118. name = request.args.get('name')
  119. detail = request.args.get('detail', 'data')
  120. shape = _unquote_param(request.args.get('shape'))
  121. tolerance = request.args.get('tolerance', '0')
  122. reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance)
  123. return reply
  124. @BLUEPRINT.route("/debugger/retrieve", methods=["POST"])
  125. def retrieve():
  126. """
  127. Retrieve data according to mode and params.
  128. Returns:
  129. str, the required data.
  130. Examples:
  131. >>> POST http://xxxx/v1/mindinsight/debugger/retrieve
  132. """
  133. body = _read_post_request(request)
  134. mode = body.get('mode')
  135. params = body.get('params')
  136. reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params)
  137. return reply
  138. @BLUEPRINT.route("/debugger/tensor-history", methods=["POST"])
  139. def retrieve_tensor_history():
  140. """
  141. Retrieve data according to mode and params.
  142. Returns:
  143. str, the required data.
  144. Examples:
  145. >>> POST http://xxxx/v1/mindinsight/debugger/tensor-history
  146. """
  147. body = _read_post_request(request)
  148. name = body.get('name')
  149. graph_name = body.get('graph_name')
  150. reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name, graph_name)
  151. return reply
  152. @BLUEPRINT.route("/debugger/tensors", methods=["GET"])
  153. def retrieve_tensor_value():
  154. """
  155. Retrieve tensor value according to name and shape.
  156. Returns:
  157. str, the required data.
  158. Examples:
  159. >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=tensor_name&detail=data&shape=[1,1,:,:]
  160. """
  161. name = request.args.get('name')
  162. detail = request.args.get('detail')
  163. shape = _unquote_param(request.args.get('shape'))
  164. graph_name = request.args.get('graph_name')
  165. prev = bool(request.args.get('prev') == 'true')
  166. reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape, graph_name, prev)
  167. return reply
  168. @BLUEPRINT.route("/debugger/create-watchpoint", methods=["POST"])
  169. def create_watchpoint():
  170. """
  171. Create watchpoint.
  172. Returns:
  173. str, watchpoint id.
  174. Raises:
  175. MindInsightException: If method fails to be called.
  176. Examples:
  177. >>> POST http://xxxx/v1/mindinsight/debugger/create-watchpoint
  178. """
  179. params = _read_post_request(request)
  180. params['watch_condition'] = params.pop('condition', None)
  181. reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, params)
  182. return reply
  183. @BLUEPRINT.route("/debugger/update-watchpoint", methods=["POST"])
  184. def update_watchpoint():
  185. """
  186. Update watchpoint.
  187. Returns:
  188. str, reply message.
  189. Raises:
  190. MindInsightException: If method fails to be called.
  191. Examples:
  192. >>> POST http://xxxx/v1/mindinsight/debugger/update-watchpoint
  193. """
  194. params = _read_post_request(request)
  195. reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, params)
  196. return reply
  197. @BLUEPRINT.route("/debugger/delete-watchpoint", methods=["POST"])
  198. def delete_watchpoint():
  199. """
  200. delete watchpoint.
  201. Returns:
  202. str, reply message.
  203. Raises:
  204. MindInsightException: If method fails to be called.
  205. Examples:
  206. >>> POST http://xxxx/v1/mindinsight/debugger/delete-watchpoint
  207. """
  208. body = _read_post_request(request)
  209. watch_point_id = body.get('watch_point_id')
  210. reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id)
  211. return reply
  212. @BLUEPRINT.route("/debugger/control", methods=["POST"])
  213. def control():
  214. """
  215. Control request.
  216. Returns:
  217. str, reply message.
  218. Raises:
  219. MindInsightException: If method fails to be called.
  220. Examples:
  221. >>> POST http://xxxx/v1/mindinsight/debugger/control
  222. """
  223. params = _read_post_request(request)
  224. reply = _wrap_reply(BACKEND_SERVER.control, params)
  225. return reply
  226. @BLUEPRINT.route("/debugger/recheck", methods=["POST"])
  227. def recheck():
  228. """
  229. Recheck request.
  230. Returns:
  231. str, reply message.
  232. Raises:
  233. MindInsightException: If method fails to be called.
  234. Examples:
  235. >>> POST http://xxxx/v1/mindinsight/debugger/recheck
  236. """
  237. reply = _wrap_reply(BACKEND_SERVER.recheck)
  238. return reply
  239. @BLUEPRINT.route("/debugger/tensor-graphs", methods=["GET"])
  240. def retrieve_tensor_graph():
  241. """
  242. Retrieve tensor value according to name and shape.
  243. Returns:
  244. str, the required data.
  245. Examples:
  246. >>> GET http://xxxx/v1/mindinsight/debugger/tensor-graphs?tensor_name=tensor_name&graph_name=graph_name
  247. """
  248. tensor_name = request.args.get('tensor_name')
  249. graph_name = request.args.get('graph_name')
  250. reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_graph, tensor_name, graph_name)
  251. return reply
  252. @BLUEPRINT.route("/debugger/tensor-hits", methods=["GET"])
  253. def retrieve_tensor_hits():
  254. """
  255. Retrieve tensor value according to name and shape.
  256. Returns:
  257. str, the required data.
  258. Examples:
  259. >>> GET http://xxxx/v1/mindinsight/debugger/tensor-hits?tensor_name=tensor_name&graph_name=graph_name
  260. """
  261. tensor_name = request.args.get('tensor_name')
  262. graph_name = request.args.get('graph_name')
  263. reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_hits, tensor_name, graph_name)
  264. return reply
  265. @BLUEPRINT.route("/debugger/search-watchpoint-hits", methods=["POST"])
  266. def search_watchpoint_hits():
  267. """
  268. Search watchpoint hits by group condition.
  269. Returns:
  270. str, the required data.
  271. Examples:
  272. >>> POST http://xxxx/v1/mindinsight/debugger/search-watchpoint-hits
  273. """
  274. body = _read_post_request(request)
  275. group_condition = body.get('group_condition')
  276. reply = _wrap_reply(BACKEND_SERVER.search_watchpoint_hits, group_condition)
  277. return reply
# Module-level backend used by every route handler in this file. It is created
# once at import time and is None when `_initialize_debugger_server` does not
# build a server.
BACKEND_SERVER = _initialize_debugger_server()
  279. def init_module(app):
  280. """
  281. Init module entry.
  282. Args:
  283. app (Flask): The application obj.
  284. """
  285. app.register_blueprint(BLUEPRINT)
  286. if BACKEND_SERVER:
  287. BACKEND_SERVER.start()