You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger_api.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Debugger restful api."""
  16. import json
  17. from flask import Blueprint, jsonify, request
  18. from mindinsight.conf import settings
  19. from mindinsight.debugger.debugger_server import DebuggerServer
  20. from mindinsight.utils.exceptions import ParamValueError
  21. BLUEPRINT = Blueprint("debugger", __name__,
  22. url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX)
  23. def _initialize_debugger_server():
  24. """Initialize a debugger server instance."""
  25. port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else None
  26. enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False
  27. server = None
  28. if port and enable_debugger:
  29. server = DebuggerServer(port)
  30. return server
  31. def _read_post_request(post_request):
  32. """
  33. Extract the body of post request.
  34. Args:
  35. post_request (object): The post request.
  36. Returns:
  37. dict, the deserialized body of request.
  38. """
  39. body = post_request.stream.read()
  40. try:
  41. body = json.loads(body if body else "{}")
  42. except Exception:
  43. raise ParamValueError("Json data parse failed.")
  44. return body
  45. def _wrap_reply(func, *args, **kwargs):
  46. """Serialize reply."""
  47. reply = func(*args, **kwargs)
  48. return jsonify(reply)
  49. @BLUEPRINT.route("/debugger/poll_data", methods=["GET"])
  50. def poll_data():
  51. """
  52. Wait for data to be updated on UI.
  53. Get data from server and display the change on UI.
  54. Returns:
  55. str, the updated data.
  56. Examples:
  57. >>> Get http://xxxx/v1/mindinsight/debugger/poll_data?pos=xx
  58. """
  59. pos = request.args.get('pos')
  60. reply = _wrap_reply(BACKEND_SERVER.poll_data, pos)
  61. return reply
  62. @BLUEPRINT.route("/debugger/search", methods=["GET"])
  63. def search():
  64. """
  65. Search nodes in specified watchpoint.
  66. Returns:
  67. str, the required data.
  68. Examples:
  69. >>> Get http://xxxx/v1/mindinsight/debugger/search?name=mock_name&watch_point_id=1
  70. """
  71. name = request.args.get('name')
  72. graph_name = request.args.get('graph_name')
  73. watch_point_id = int(request.args.get('watch_point_id', 0))
  74. node_category = request.args.get('node_category')
  75. reply = _wrap_reply(BACKEND_SERVER.search, {'name': name,
  76. 'graph_name': graph_name,
  77. 'watch_point_id': watch_point_id,
  78. 'node_category': node_category})
  79. return reply
  80. @BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"])
  81. def retrieve_node_by_bfs():
  82. """
  83. Search node by bfs.
  84. Returns:
  85. str, the required data.
  86. Examples:
  87. >>> Get http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true
  88. """
  89. name = request.args.get('name')
  90. graph_name = request.args.get('graph_name')
  91. ascend = request.args.get('ascend', 'false')
  92. ascend = ascend == 'true'
  93. reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, graph_name, ascend)
  94. return reply
  95. @BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"])
  96. def tensor_comparisons():
  97. """
  98. Get tensor comparisons.
  99. Returns:
  100. str, the required data.
  101. Examples:
  102. >>> Get http://xxxx/v1/mindinsight/debugger/tensor-comparisons
  103. """
  104. name = request.args.get('name')
  105. detail = request.args.get('detail', 'data')
  106. shape = request.args.get('shape')
  107. tolerance = request.args.get('tolerance', '0')
  108. reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance)
  109. return reply
  110. @BLUEPRINT.route("/debugger/retrieve", methods=["POST"])
  111. def retrieve():
  112. """
  113. Retrieve data according to mode and params.
  114. Returns:
  115. str, the required data.
  116. Examples:
  117. >>> POST http://xxxx/v1/mindinsight/debugger/retrieve
  118. """
  119. body = _read_post_request(request)
  120. mode = body.get('mode')
  121. params = body.get('params')
  122. reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params)
  123. return reply
  124. @BLUEPRINT.route("/debugger/retrieve_tensor_history", methods=["POST"])
  125. def retrieve_tensor_history():
  126. """
  127. Retrieve data according to mode and params.
  128. Returns:
  129. str, the required data.
  130. Examples:
  131. >>> POST http://xxxx/v1/mindinsight/debugger/retrieve_tensor_history
  132. """
  133. body = _read_post_request(request)
  134. name = body.get('name')
  135. graph_name = body.get('graph_name')
  136. reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name, graph_name)
  137. return reply
  138. @BLUEPRINT.route("/debugger/tensors", methods=["GET"])
  139. def retrieve_tensor_value():
  140. """
  141. Retrieve tensor value according to name and shape.
  142. Returns:
  143. str, the required data.
  144. Examples:
  145. >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=tensor_name&detail=data&shape=[1,1,:,:]
  146. """
  147. name = request.args.get('name')
  148. detail = request.args.get('detail')
  149. shape = request.args.get('shape')
  150. graph_name = request.args.get('graph_name')
  151. prev = bool(request.args.get('prev') == 'true')
  152. reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape, graph_name, prev)
  153. return reply
  154. @BLUEPRINT.route("/debugger/create_watchpoint", methods=["POST"])
  155. def create_watchpoint():
  156. """
  157. Create watchpoint.
  158. Returns:
  159. str, watchpoint id.
  160. Raises:
  161. MindInsightException: If method fails to be called.
  162. Examples:
  163. >>> POST http://xxxx/v1/mindinsight/debugger/create_watchpoint
  164. """
  165. body = _read_post_request(request)
  166. condition = body.get('condition')
  167. graph_name = body.get('graph_name')
  168. watch_nodes = body.get('watch_nodes')
  169. watch_point_id = body.get('watch_point_id')
  170. search_pattern = body.get('search_pattern')
  171. reply = _wrap_reply(BACKEND_SERVER.create_watchpoint,
  172. condition, watch_nodes, watch_point_id, search_pattern, graph_name)
  173. return reply
  174. @BLUEPRINT.route("/debugger/update_watchpoint", methods=["POST"])
  175. def update_watchpoint():
  176. """
  177. Update watchpoint.
  178. Returns:
  179. str, reply message.
  180. Raises:
  181. MindInsightException: If method fails to be called.
  182. Examples:
  183. >>> POST http://xxxx/v1/mindinsight/debugger/update_watchpoint
  184. """
  185. body = _read_post_request(request)
  186. watch_point_id = body.get('watch_point_id')
  187. watch_nodes = body.get('watch_nodes')
  188. graph_name = body.get('graph_name')
  189. mode = body.get('mode')
  190. pattern = body.get('search_pattern')
  191. reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, watch_point_id, watch_nodes, mode, pattern, graph_name)
  192. return reply
  193. @BLUEPRINT.route("/debugger/delete_watchpoint", methods=["POST"])
  194. def delete_watchpoint():
  195. """
  196. delete watchpoint.
  197. Returns:
  198. str, reply message.
  199. Raises:
  200. MindInsightException: If method fails to be called.
  201. Examples:
  202. >>> POST http://xxxx/v1/mindinsight/debugger/delete_watchpoint
  203. """
  204. body = _read_post_request(request)
  205. watch_point_id = body.get('watch_point_id')
  206. reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id)
  207. return reply
  208. @BLUEPRINT.route("/debugger/control", methods=["POST"])
  209. def control():
  210. """
  211. Control request.
  212. Returns:
  213. str, reply message.
  214. Raises:
  215. MindInsightException: If method fails to be called.
  216. Examples:
  217. >>> POST http://xxxx/v1/mindinsight/debugger/control
  218. """
  219. params = _read_post_request(request)
  220. reply = _wrap_reply(BACKEND_SERVER.control, params)
  221. return reply
  222. @BLUEPRINT.route("/debugger/recheck", methods=["POST"])
  223. def recheck():
  224. """
  225. Recheck request.
  226. Returns:
  227. str, reply message.
  228. Raises:
  229. MindInsightException: If method fails to be called.
  230. Examples:
  231. >>> POST http://xxxx/v1/mindinsight/debugger/recheck
  232. """
  233. reply = _wrap_reply(BACKEND_SERVER.recheck)
  234. return reply
  235. @BLUEPRINT.route("/debugger/tensor_graphs", methods=["GET"])
  236. def retrieve_tensor_graph():
  237. """
  238. Retrieve tensor value according to name and shape.
  239. Returns:
  240. str, the required data.
  241. Examples:
  242. >>> GET http://xxxx/v1/mindinsight/debugger/tensor_graphs?tensor_name=tensor_name$graph_name=graph_name
  243. """
  244. tensor_name = request.args.get('tensor_name')
  245. graph_name = request.args.get('graph_name')
  246. reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_graph, tensor_name, graph_name)
  247. return reply
  248. @BLUEPRINT.route("/debugger/tensor_hits", methods=["GET"])
  249. def retrieve_tensor_hits():
  250. """
  251. Retrieve tensor value according to name and shape.
  252. Returns:
  253. str, the required data.
  254. Examples:
  255. >>> GET http://xxxx/v1/mindinsight/debugger/tensor_hits?tensor_name=tensor_name$graph_name=graph_name
  256. """
  257. tensor_name = request.args.get('tensor_name')
  258. graph_name = request.args.get('graph_name')
  259. reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_hits, tensor_name, graph_name)
  260. return reply
  261. BACKEND_SERVER = _initialize_debugger_server()
  262. def init_module(app):
  263. """
  264. Init module entry.
  265. Args:
  266. app (Flask): The application obj.
  267. """
  268. app.register_blueprint(BLUEPRINT)
  269. if BACKEND_SERVER:
  270. BACKEND_SERVER.start()