# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
    Test query debugger restful api.
Usage:
    pytest tests/st/func/debugger/test_restful_api.py
"""
import os

import pytest

from mindinsight.conf import settings

from tests.st.func.debugger.conftest import DEBUGGER_BASE_URL
from tests.st.func.debugger.mock_ms_client import MockDebuggerClient
from tests.st.func.debugger.utils import check_waiting_state, get_request_result, \
    send_and_compare_result


def send_terminate_cmd(app_client):
    """Send terminate command to debugger client."""
    url = os.path.join(DEBUGGER_BASE_URL, 'control')
    body_data = {'mode': 'terminate'}
    send_and_compare_result(app_client, url, body_data)


class TestAscendDebugger:
    """Test debugger on Ascend backend."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='Ascend')
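
    # Common flow of the tests below: start the mock MindSpore debugger client in a
    # background thread, wait until the debugger server reaches the 'waiting' state,
    # send the restful requests under test, then send a terminate command so the
    # mock client exits cleanly.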

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_before_train_begin(self, app_client):
        """Test retrieve all."""
        url = 'retrieve'
        body_data = {'mode': 'all'}
        expect_file = 'before_train_begin.json'
        send_and_compare_result(app_client, url, body_data, expect_file)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'Default/optimizer-Momentum/Parameter[18]_7'}},
         'retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True}}, 'retrieve_single_node.json'),
        ({'mode': 'watchpoint_hit'}, 'retrieve_empty_watchpoint_hit_list')
    ])
    def test_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve when train_begin."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    def test_get_conditions(self, app_client):
        """Test get conditions for ascend."""
        url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/conditions'
        body_data = {}
        expect_file = 'get_conditions_for_ascend.json'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'multi_retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
         'multi_retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'graph_0/Default/TransData-op99',
            'single_node': True}}, 'multi_retrieve_single_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
    ])
    def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve when train_begin."""
        url = 'retrieve'
        debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)
        with debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_create_and_delete_watchpoint(self, app_client):
        """Test create and delete watchpoint."""
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            conditions = [
                {'id': 'max_gt', 'params': [{'name': 'param', 'value': 1.0, 'disable': False}]},
                {'id': 'max_lt', 'params': [{'name': 'param', 'value': -1.0, 'disable': False}]},
                {'id': 'min_gt', 'params': [{'name': 'param', 'value': 1e+32, 'disable': False}]},
                {'id': 'min_lt', 'params': [{'name': 'param', 'value': -1e+32, 'disable': False}]},
                {'id': 'max_min_gt', 'params': [{'name': 'param', 'value': 0, 'disable': False}]},
                {'id': 'max_min_lt', 'params': [{'name': 'param', 'value': 0, 'disable': False}]},
                {'id': 'mean_gt', 'params': [{'name': 'param', 'value': 0, 'disable': False}]},
                {'id': 'mean_lt', 'params': [{'name': 'param', 'value': 0, 'disable': False}]},
                {'id': 'inf', 'params': []},
                {'id': 'overflow', 'params': []},
            ]
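            # each condition should be registered as a new watchpoint, with ids
            # expected to increase from 1 in creation order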
            for idx, condition in enumerate(conditions):
                create_watchpoint(app_client, condition, idx + 1)
            # delete the 4th watchpoint
            url = 'delete_watchpoint'
            body_data = {'watch_point_id': 4}
            get_request_result(app_client, url, body_data)
            # test watchpoint list
            url = 'retrieve'
            body_data = {'mode': 'watchpoint'}
            expect_file = 'create_and_delete_watchpoint.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_update_watchpoint(self, app_client):
        """Test update watchpoint."""
        watch_point_id = 1
        leaf_node_name = 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            condition = {'id': 'inf', 'params': []}
            create_watchpoint(app_client, condition, watch_point_id)
            # update the watch node list of the watchpoint
            url = 'update_watchpoint'
            body_data = {'watch_point_id': watch_point_id,
                         'watch_nodes': [leaf_node_name],
                         'mode': 0}
            get_request_result(app_client, url, body_data)
            # get updated nodes
            url = 'search'
            body_data = {'name': leaf_node_name, 'watch_point_id': watch_point_id}
            expect_file = 'search_unwatched_leaf_node.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_watchpoint_hit(self, app_client):
        """Test retrieve watchpoint hit."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            # check watchpoint hit list
            url = 'retrieve'
            body_data = {'mode': 'watchpoint_hit'}
            expect_file = 'retrieve_watchpoint_hit.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check single watchpoint hit
            body_data = {
                'mode': 'watchpoint_hit',
                'params': {
                    'name': 'Default/TransData-op99',
                    'single_node': True,
                    'watch_point_id': 1
                }
            }
            expect_file = 'retrieve_single_watchpoint_hit.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_retrieve_tensor_value(self, app_client):
        """Test retrieve tensor value."""
        node_name = 'Default/TransData-op99'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            # prepare tensor value
            url = 'retrieve_tensor_history'
            body_data = {'name': node_name}
            expect_file = 'retrieve_empty_tensor_history.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check full tensor history from poll data
            res = get_request_result(
                app_client=app_client, url='poll_data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('node_name') == node_name
            expect_file = 'retrieve_full_tensor_history.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check tensor value
            url = 'tensors'
            body_data = {
                'name': node_name + ':0',
                'detail': 'data',
                'shape': '[1, 1:3]'
            }
            expect_file = 'retrieve_tensor_value.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_compare_tensor_value(self, app_client):
        """Test compare tensor value."""
        node_name = 'Default/args0'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            # prepare tensor values
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': 2}
            get_request_result(app_client, url, body_data)
            check_waiting_state(app_client)
            get_request_result(
                app_client=app_client, url='retrieve_tensor_history', body_data={'name': node_name})
            res = get_request_result(
                app_client=app_client, url='poll_data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('node_name') == node_name
            # get compare results
            url = 'tensor-comparisons'
            body_data = {
                'name': node_name + ':0',
                'detail': 'data',
                'shape': '[:, :]',
                'tolerance': 1
            }
            expect_file = 'compare_tensors.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'ascend': True}, 'retrieve_node_by_bfs_ascend.json'),
        ({'name': 'Default/args0', 'ascend': False}, 'retrieve_node_by_bfs.json')
    ])
    def test_retrieve_bfs_node(self, app_client, body_data, expect_file):
        """Test retrieve bfs node."""
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            # retrieve the node by bfs
            url = 'retrieve_node_by_bfs'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_pause(self, app_client):
        """Test pause the training."""
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            # send run command to execute to next node
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': -1}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
            # send pause command
            url = 'control'
            body_data = {'mode': 'pause'}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'waiting', 'enable_recheck': False}}
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("url, body_data, enable_recheck", [
        ('create_watchpoint',
         {'condition': {'id': 'inf', 'params': []},
          'watch_nodes': ['Default']}, True),
        ('update_watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'mode': 0}, True),
        ('update_watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
          'mode': 1}, True),
        ('delete_watchpoint', {}, True)
    ])
    def test_recheck(self, app_client, url, body_data, enable_recheck):
        """Test recheck."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            # send the parametrized watchpoint request
            res = get_request_result(app_client, url, body_data, method='post')
            assert res['metadata']['enable_recheck'] is enable_recheck
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_recommend_watchpoints(self, app_client):
        """Test generating recommended watchpoints."""
        original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
        settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
        try:
            with self._debugger_client.get_thread_instance():
                check_waiting_state(app_client)
                url = 'retrieve'
                body_data = {'mode': 'watchpoint'}
                expect_file = 'recommended_watchpoints_at_startup.json'
                send_and_compare_result(app_client, url, body_data, expect_file, method='post')
                send_terminate_cmd(app_client)
        finally:
            settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_graph-0.json'),
        ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
         'retrieve_tensor_graph-1.json')
    ])
    def test_retrieve_tensor_graph(self, app_client, body_data, expect_file):
        """Test retrieve tensor graph."""
        url = 'tensor_graphs'
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
            send_terminate_cmd(app_client)


class TestGPUDebugger:
    """Test debugger on GPU backend."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='GPU')

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_next_node_on_gpu(self, app_client):
        """Test get next node on GPU."""
        gpu_debugger_client = MockDebuggerClient(backend='GPU')
        with gpu_debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            # send run command to get watchpoint hit
            url = 'control'
            body_data = {'mode': 'continue',
                         'level': 'node',
                         'name': 'Default/TransData-op99'}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
            # get metadata
            check_waiting_state(app_client)
            url = 'retrieve'
            body_data = {'mode': 'all'}
            expect_file = 'retrieve_next_node_on_gpu.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("url, body_data, enable_recheck", [
        ('create_watchpoint',
         {'condition': {'id': 'inf', 'params': []},
          'watch_nodes': ['Default']}, False),
        ('create_watchpoint',
         {'condition': {'id': 'inf', 'params': []},
          'watch_nodes': ['Default/TransData-op99']}, True),
        ('update_watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'mode': 0}, True),
        ('update_watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
          'mode': 1}, False),
        ('update_watchpoint',
         [{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
           'mode': 1},
          {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
           'mode': 0}
          ], True),
        ('update_watchpoint',
         [{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
           'mode': 0},
          {'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
           'mode': 1}
          ], True),
        ('delete_watchpoint', {'watch_point_id': 1}, True)
    ])
    def test_recheck_state(self, app_client, url, body_data, enable_recheck):
        """Test update watchpoint and check the value of enable_recheck."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            if not isinstance(body_data, list):
                body_data = [body_data]
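            # a list payload is sent as consecutive requests; the enable_recheck
            # flag of the last response is what gets asserted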
            for sub_body_data in body_data:
                res = get_request_result(app_client, url, sub_body_data, method='post')
            assert res['metadata']['enable_recheck'] is enable_recheck
            send_terminate_cmd(app_client)

    def test_get_conditions(self, app_client):
        """Test get conditions for gpu."""
        url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/conditions'
        body_data = {}
        expect_file = 'get_conditions_for_gpu.json'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_recheck(self, app_client):
        """Test recheck request."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            # send recheck while recheck is disabled
            get_request_result(app_client, 'recheck', {}, method='post', expect_code=400)
            # send recheck after recheck is enabled
            create_watchpoint(app_client, {'id': 'inf', 'params': []}, 2)
            res = get_request_result(app_client, 'recheck', {}, method='post')
            assert res['metadata']['enable_recheck'] is False
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_file", [
        ({'name': 'fc', 'node_category': 'weight'}, 'search_weight.json'),
        ({'name': 'fc', 'node_category': 'gradient'}, 'search_gradient.json'),
        ({'node_category': 'activation'}, 'search_activation.json')
    ])
    def test_search_by_category(self, app_client, filter_condition, expect_file):
        """Test search by category request."""
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            send_and_compare_result(app_client, 'search', filter_condition, expect_file,
                                    method='get')
            send_terminate_cmd(app_client)


class TestMultiGraphDebugger:
    """Test debugger with multiple graphs on Ascend backend."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'multi_retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
         'multi_retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'graph_0/Default/TransData-op99',
            'single_node': True}}, 'multi_retrieve_single_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
    ])
    def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve when train_begin."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_file", [
        ({'name': '', 'node_category': 'weight'}, 'search_weight_multi_graph.json'),
        ({'node_category': 'activation'}, 'search_activation_multi_graph.json')
    ])
    def test_search_by_category_with_multi_graph(self, app_client, filter_condition, expect_file):
        """Test search by category request."""
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            send_and_compare_result(app_client, 'search', filter_condition, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_id", [
        ({'condition': {'id': 'inf'},
          'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'graph_name': 'graph_0'}, 1),
        ({'condition': {'id': 'inf'},
          'watch_nodes': ['graph_0/Default/optimizer-Momentum/ApplyMomentum[8]_1'],
          'graph_name': None}, 1)
    ])
    def test_create_watchpoint(self, app_client, filter_condition, expect_id):
        """Test create watchpoint with multiple graphs."""
        url = 'create_watchpoint'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            res = get_request_result(app_client, url, filter_condition)
            assert res.get('id') == expect_id
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("params, expect_file", [
        ({'level': 'node'}, 'multi_next_node.json'),
        ({'level': 'node', 'node_name': 'graph_0/Default/TransData-op99'}, 'multi_next_node.json'),
        ({'level': 'node', 'node_name': 'Default/TransData-op99', 'graph_name': 'graph_0'},
         'multi_next_node.json')
    ])
    def test_continue_on_gpu(self, app_client, params, expect_file):
        """Test continue to the next node on GPU with multiple graphs."""
        gpu_debugger_client = MockDebuggerClient(backend='GPU', graph_num=2)
        original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
        settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
        try:
            with gpu_debugger_client.get_thread_instance():
                check_waiting_state(app_client)
                # send run command to get watchpoint hit
                url = 'control'
                body_data = {'mode': 'continue'}
                body_data.update(params)
                res = get_request_result(app_client, url, body_data)
                assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
                # get metadata
                check_waiting_state(app_client)
                url = 'retrieve'
                body_data = {'mode': 'all'}
                send_and_compare_result(app_client, url, body_data, expect_file)
                send_terminate_cmd(app_client)
        finally:
            settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_hits-0.json'),
        ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
         'retrieve_tensor_hits-1.json')
    ])
    def test_retrieve_tensor_hits(self, app_client, body_data, expect_file):
        """Test retrieve tensor hits."""
        url = 'tensor_hits'
        with self._debugger_client.get_thread_instance():
            check_waiting_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
            send_terminate_cmd(app_client)


def create_watchpoint(app_client, condition, expect_id):
    """Create watchpoint."""
    url = 'create_watchpoint'
    body_data = {'condition': condition,
                 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7',
                                 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias',
                                 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias',
                                 'Default/TransData-op99']}
    res = get_request_result(app_client, url, body_data)
    assert res.get('id') == expect_id


def create_watchpoint_and_wait(app_client):
    """Preparation for recheck."""
    check_waiting_state(app_client)
    create_watchpoint(app_client, condition={'id': 'inf', 'params': []}, expect_id=1)
    # send run command to get watchpoint hit
    url = 'control'
    body_data = {'mode': 'continue',
                 'steps': 2}
    res = get_request_result(app_client, url, body_data)
    assert res == {'metadata': {'state': 'running', 'enable_recheck': False}}
    # wait until the server has received the watchpoint hit
    check_waiting_state(app_client)