You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_restful_api.py 32 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
  17. Test query debugger restful api.
  18. Usage:
  19. pytest tests/st/func/debugger/test_restful_api.py
  20. """
  21. import os
  22. import pytest
  23. from mindinsight.conf import settings
  24. from mindinsight.debugger.common.utils import ServerStatus
  25. from tests.st.func.debugger.conftest import DEBUGGER_BASE_URL
  26. from tests.st.func.debugger.mock_ms_client import MockDebuggerClient
  27. from tests.st.func.debugger.utils import check_state, get_request_result, \
  28. send_and_compare_result
  29. def send_terminate_cmd(app_client):
  30. """Send terminate command to debugger client."""
  31. url = os.path.join(DEBUGGER_BASE_URL, 'control')
  32. body_data = {'mode': 'terminate'}
  33. send_and_compare_result(app_client, url, body_data)
  34. class TestAscendDebugger:
  35. """Test debugger on Ascend backend."""
  36. @classmethod
  37. def setup_class(cls):
  38. """Setup class."""
  39. cls._debugger_client = MockDebuggerClient(backend='Ascend')
  40. @pytest.mark.level0
  41. @pytest.mark.env_single
  42. @pytest.mark.platform_x86_cpu
  43. @pytest.mark.platform_arm_ascend_training
  44. @pytest.mark.platform_x86_gpu_training
  45. @pytest.mark.platform_x86_ascend_training
  46. def test_before_train_begin(self, app_client):
  47. """Test retrieve all."""
  48. url = 'retrieve'
  49. body_data = {'mode': 'all'}
  50. expect_file = 'before_train_begin.json'
  51. send_and_compare_result(app_client, url, body_data, expect_file)
  52. @pytest.mark.level0
  53. @pytest.mark.env_single
  54. @pytest.mark.platform_x86_cpu
  55. @pytest.mark.platform_arm_ascend_training
  56. @pytest.mark.platform_x86_gpu_training
  57. @pytest.mark.platform_x86_ascend_training
  58. @pytest.mark.parametrize("body_data, expect_file", [
  59. ({'mode': 'all'}, 'retrieve_all.json'),
  60. ({'mode': 'node', 'params': {'name': 'Default'}}, 'retrieve_scope_node.json'),
  61. ({'mode': 'node', 'params': {'name': 'Default/optimizer-Momentum/Parameter[18]_7'}},
  62. 'retrieve_aggregation_scope_node.json'),
  63. ({'mode': 'node', 'params': {
  64. 'name': 'Default/TransData-op99',
  65. 'single_node': True}}, 'retrieve_single_node.json'),
  66. ({'mode': 'watchpoint_hit'}, 'retrieve_empty_watchpoint_hit_list')
  67. ])
  68. def test_retrieve_when_train_begin(self, app_client, body_data, expect_file):
  69. """Test retrieve when train_begin."""
  70. url = 'retrieve'
  71. with self._debugger_client.get_thread_instance():
  72. check_state(app_client)
  73. send_and_compare_result(app_client, url, body_data, expect_file)
  74. send_terminate_cmd(app_client)
  75. def test_get_conditions(self, app_client):
  76. """Test get conditions for ascend."""
  77. url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections'
  78. body_data = {}
  79. expect_file = 'get_conditions_for_ascend.json'
  80. with self._debugger_client.get_thread_instance():
  81. check_state(app_client)
  82. send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
  83. send_terminate_cmd(app_client)
  84. @pytest.mark.level0
  85. @pytest.mark.env_single
  86. @pytest.mark.platform_x86_cpu
  87. @pytest.mark.platform_arm_ascend_training
  88. @pytest.mark.platform_x86_gpu_training
  89. @pytest.mark.platform_x86_ascend_training
  90. @pytest.mark.parametrize("body_data, expect_file", [
  91. ({'mode': 'all'}, 'multi_retrieve_all.json'),
  92. ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
  93. ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
  94. ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
  95. 'multi_retrieve_aggregation_scope_node.json'),
  96. ({'mode': 'node', 'params': {
  97. 'name': 'graph_0/Default/TransData-op99',
  98. 'single_node': True}}, 'multi_retrieve_single_node.json'),
  99. ({'mode': 'node', 'params': {
  100. 'name': 'Default/TransData-op99',
  101. 'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
  102. ])
  103. def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
  104. """Test retrieve when train_begin."""
  105. url = 'retrieve'
  106. debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)
  107. with debugger_client.get_thread_instance():
  108. check_state(app_client)
  109. send_and_compare_result(app_client, url, body_data, expect_file)
  110. send_terminate_cmd(app_client)
  111. @pytest.mark.level0
  112. @pytest.mark.env_single
  113. @pytest.mark.platform_x86_cpu
  114. @pytest.mark.platform_arm_ascend_training
  115. @pytest.mark.platform_x86_gpu_training
  116. @pytest.mark.platform_x86_ascend_training
  117. def test_create_and_delete_watchpoint(self, app_client):
  118. """Test create and delete watchpoint."""
  119. with self._debugger_client.get_thread_instance():
  120. check_state(app_client)
  121. conditions = [
  122. {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  123. {'id': 'tensor_too_small', 'params': [{'name': 'max_lt', 'value': -1.0}]},
  124. {'id': 'tensor_too_large', 'params': [{'name': 'min_gt', 'value': 1e+32}]},
  125. {'id': 'tensor_too_small', 'params': [{'name': 'min_lt', 'value': -1e+32}]},
  126. {'id': 'tensor_too_large', 'params': [{'name': 'mean_gt', 'value': 0}]},
  127. {'id': 'tensor_too_small', 'params': [{'name': 'mean_lt', 'value': 0}]}
  128. ]
  129. for idx, condition in enumerate(conditions):
  130. create_watchpoint(app_client, condition, idx + 1)
  131. # delete 4-th watchpoint
  132. url = 'delete_watchpoint'
  133. body_data = {'watch_point_id': 4}
  134. get_request_result(app_client, url, body_data)
  135. # test watchpoint list
  136. url = 'retrieve'
  137. body_data = {'mode': 'watchpoint'}
  138. expect_file = 'create_and_delete_watchpoint.json'
  139. send_and_compare_result(app_client, url, body_data, expect_file)
  140. send_terminate_cmd(app_client)
  141. @pytest.mark.level0
  142. @pytest.mark.env_single
  143. @pytest.mark.platform_x86_cpu
  144. @pytest.mark.platform_arm_ascend_training
  145. @pytest.mark.platform_x86_gpu_training
  146. @pytest.mark.platform_x86_ascend_training
  147. def test_update_watchpoint(self, app_client):
  148. """Test retrieve when train_begin."""
  149. watch_point_id = 1
  150. leaf_node_name = 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias'
  151. with self._debugger_client.get_thread_instance():
  152. check_state(app_client)
  153. condition = {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}
  154. create_watchpoint(app_client, condition, watch_point_id)
  155. # update watchpoint watchpoint list
  156. url = 'update_watchpoint'
  157. body_data = {'watch_point_id': watch_point_id,
  158. 'watch_nodes': [leaf_node_name],
  159. 'mode': 0}
  160. get_request_result(app_client, url, body_data)
  161. # get updated nodes
  162. url = 'search'
  163. body_data = {'name': leaf_node_name, 'watch_point_id': watch_point_id}
  164. expect_file = 'search_unwatched_leaf_node.json'
  165. send_and_compare_result(app_client, url, body_data, expect_file, method='get')
  166. send_terminate_cmd(app_client)
  167. @pytest.mark.level0
  168. @pytest.mark.env_single
  169. @pytest.mark.platform_x86_cpu
  170. @pytest.mark.platform_arm_ascend_training
  171. @pytest.mark.platform_x86_gpu_training
  172. @pytest.mark.platform_x86_ascend_training
  173. def test_watchpoint_hit(self, app_client):
  174. """Test retrieve watchpoint hit."""
  175. with self._debugger_client.get_thread_instance():
  176. create_watchpoint_and_wait(app_client)
  177. # check watchpoint hit list
  178. url = 'retrieve'
  179. body_data = {'mode': 'watchpoint_hit'}
  180. expect_file = 'retrieve_watchpoint_hit.json'
  181. send_and_compare_result(app_client, url, body_data, expect_file)
  182. # check single watchpoint hit
  183. body_data = {
  184. 'mode': 'watchpoint_hit',
  185. 'params': {
  186. 'name': 'Default/TransData-op99',
  187. 'single_node': True,
  188. 'watch_point_id': 1
  189. }
  190. }
  191. expect_file = 'retrieve_single_watchpoint_hit.json'
  192. send_and_compare_result(app_client, url, body_data, expect_file)
  193. send_terminate_cmd(app_client)
  194. @pytest.mark.level0
  195. @pytest.mark.env_single
  196. @pytest.mark.platform_x86_cpu
  197. @pytest.mark.platform_arm_ascend_training
  198. @pytest.mark.platform_x86_gpu_training
  199. @pytest.mark.platform_x86_ascend_training
  200. def test_retrieve_tensor_value(self, app_client):
  201. """Test retrieve tensor value."""
  202. node_name = 'Default/TransData-op99'
  203. with self._debugger_client.get_thread_instance():
  204. check_state(app_client)
  205. # prepare tensor value
  206. url = 'retrieve_tensor_history'
  207. body_data = {'name': node_name}
  208. expect_file = 'retrieve_empty_tensor_history.json'
  209. send_and_compare_result(app_client, url, body_data, expect_file)
  210. # check full tensor history from poll data
  211. res = get_request_result(
  212. app_client=app_client, url='poll_data', body_data={'pos': 0}, method='get')
  213. assert res.get('receive_tensor', {}).get('node_name') == node_name
  214. expect_file = 'retrieve_full_tensor_history.json'
  215. send_and_compare_result(app_client, url, body_data, expect_file)
  216. # check tensor value
  217. url = 'tensors'
  218. body_data = {
  219. 'name': node_name + ':0',
  220. 'detail': 'data',
  221. 'shape': '[1, 1:3]'
  222. }
  223. expect_file = 'retrieve_tensor_value.json'
  224. send_and_compare_result(app_client, url, body_data, expect_file, method='get')
  225. send_terminate_cmd(app_client)
  226. @pytest.mark.level0
  227. @pytest.mark.env_single
  228. @pytest.mark.platform_x86_cpu
  229. @pytest.mark.platform_arm_ascend_training
  230. @pytest.mark.platform_x86_gpu_training
  231. @pytest.mark.platform_x86_ascend_training
  232. def test_compare_tensor_value(self, app_client):
  233. """Test compare tensor value."""
  234. node_name = 'Default/args0'
  235. with self._debugger_client.get_thread_instance():
  236. check_state(app_client)
  237. # prepare tensor values
  238. url = 'control'
  239. body_data = {'mode': 'continue',
  240. 'steps': 2}
  241. get_request_result(app_client, url, body_data)
  242. check_state(app_client)
  243. get_request_result(
  244. app_client=app_client, url='retrieve_tensor_history', body_data={'name': node_name})
  245. res = get_request_result(
  246. app_client=app_client, url='poll_data', body_data={'pos': 0}, method='get')
  247. assert res.get('receive_tensor', {}).get('node_name') == node_name
  248. # get compare results
  249. url = 'tensor-comparisons'
  250. body_data = {
  251. 'name': node_name + ':0',
  252. 'detail': 'data',
  253. 'shape': '[:, :]',
  254. 'tolerance': 1
  255. }
  256. expect_file = 'compare_tensors.json'
  257. send_and_compare_result(app_client, url, body_data, expect_file, method='get')
  258. send_terminate_cmd(app_client)
  259. @pytest.mark.level0
  260. @pytest.mark.env_single
  261. @pytest.mark.platform_x86_cpu
  262. @pytest.mark.platform_arm_ascend_training
  263. @pytest.mark.platform_x86_gpu_training
  264. @pytest.mark.platform_x86_ascend_training
  265. @pytest.mark.parametrize("body_data, expect_file", [
  266. ({'ascend': True}, 'retrieve_node_by_bfs_ascend.json'),
  267. ({'name': 'Default/args0', 'ascend': False}, 'retrieve_node_by_bfs.json')
  268. ])
  269. def test_retrieve_bfs_node(self, app_client, body_data, expect_file):
  270. """Test retrieve bfs node."""
  271. with self._debugger_client.get_thread_instance():
  272. check_state(app_client)
  273. # prepare tensor values
  274. url = 'retrieve_node_by_bfs'
  275. send_and_compare_result(app_client, url, body_data, expect_file, method='get')
  276. send_terminate_cmd(app_client)
  277. @pytest.mark.level0
  278. @pytest.mark.env_single
  279. @pytest.mark.platform_x86_cpu
  280. @pytest.mark.platform_arm_ascend_training
  281. @pytest.mark.platform_x86_gpu_training
  282. @pytest.mark.platform_x86_ascend_training
  283. def test_pause(self, app_client):
  284. """Test pause the training."""
  285. with self._debugger_client.get_thread_instance():
  286. check_state(app_client)
  287. # send run command to execute to next node
  288. url = 'control'
  289. body_data = {'mode': 'continue',
  290. 'steps': -1}
  291. res = get_request_result(app_client, url, body_data)
  292. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  293. # send pause command
  294. check_state(app_client, 'running')
  295. url = 'control'
  296. body_data = {'mode': 'pause'}
  297. res = get_request_result(app_client, url, body_data)
  298. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  299. send_terminate_cmd(app_client)
  300. @pytest.mark.level0
  301. @pytest.mark.env_single
  302. @pytest.mark.platform_x86_cpu
  303. @pytest.mark.platform_arm_ascend_training
  304. @pytest.mark.platform_x86_gpu_training
  305. @pytest.mark.platform_x86_ascend_training
  306. @pytest.mark.parametrize("url, body_data, enable_recheck", [
  307. ('create_watchpoint',
  308. {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  309. 'watch_nodes': ['Default']}, True),
  310. ('update_watchpoint',
  311. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
  312. 'mode': 0}, True),
  313. ('update_watchpoint',
  314. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
  315. 'mode': 1}, True),
  316. ('delete_watchpoint', {}, True)
  317. ])
  318. def test_recheck(self, app_client, url, body_data, enable_recheck):
  319. """Test recheck."""
  320. with self._debugger_client.get_thread_instance():
  321. create_watchpoint_and_wait(app_client)
  322. # create watchpoint
  323. res = get_request_result(app_client, url, body_data, method='post')
  324. assert res['metadata']['enable_recheck'] is enable_recheck
  325. send_terminate_cmd(app_client)
  326. @pytest.mark.level0
  327. @pytest.mark.env_single
  328. @pytest.mark.platform_x86_cpu
  329. @pytest.mark.platform_arm_ascend_training
  330. @pytest.mark.platform_x86_gpu_training
  331. @pytest.mark.platform_x86_ascend_training
  332. def test_recommend_watchpoints(self, app_client):
  333. """Test generating recommended watchpoints."""
  334. original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
  335. settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
  336. try:
  337. with self._debugger_client.get_thread_instance():
  338. check_state(app_client)
  339. url = 'retrieve'
  340. body_data = {'mode': 'watchpoint'}
  341. expect_file = 'recommended_watchpoints_at_startup.json'
  342. send_and_compare_result(app_client, url, body_data, expect_file, method='post')
  343. send_terminate_cmd(app_client)
  344. finally:
  345. settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value
  346. @pytest.mark.level0
  347. @pytest.mark.env_single
  348. @pytest.mark.platform_x86_cpu
  349. @pytest.mark.platform_arm_ascend_training
  350. @pytest.mark.platform_x86_gpu_training
  351. @pytest.mark.platform_x86_ascend_training
  352. @pytest.mark.parametrize("body_data, expect_file", [
  353. ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_graph-0.json'),
  354. ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
  355. 'retrieve_tensor_graph-1.json')
  356. ])
  357. def test_retrieve_tensor_graph(self, app_client, body_data, expect_file):
  358. """Test retrieve tensor graph."""
  359. url = 'tensor-graphs'
  360. with self._debugger_client.get_thread_instance():
  361. create_watchpoint_and_wait(app_client)
  362. get_request_result(app_client, url, body_data, method='GET')
  363. # check full tensor history from poll data
  364. res = get_request_result(
  365. app_client=app_client, url='poll_data', body_data={'pos': 0}, method='get')
  366. assert res.get('receive_tensor', {}).get('tensor_name') == body_data.get('tensor_name')
  367. send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
  368. send_terminate_cmd(app_client)
  369. class TestGPUDebugger:
  370. """Test debugger on Ascend backend."""
  371. @classmethod
  372. def setup_class(cls):
  373. """Setup class."""
  374. cls._debugger_client = MockDebuggerClient(backend='GPU')
  375. @pytest.mark.level0
  376. @pytest.mark.env_single
  377. @pytest.mark.platform_x86_cpu
  378. @pytest.mark.platform_arm_ascend_training
  379. @pytest.mark.platform_x86_gpu_training
  380. @pytest.mark.platform_x86_ascend_training
  381. def test_next_node_on_gpu(self, app_client):
  382. """Test get next node on GPU."""
  383. gpu_debugger_client = MockDebuggerClient(backend='GPU')
  384. with gpu_debugger_client.get_thread_instance():
  385. check_state(app_client)
  386. # send run command to get watchpoint hit
  387. url = 'control'
  388. body_data = {'mode': 'continue',
  389. 'level': 'node',
  390. 'name': 'Default/TransData-op99'}
  391. res = get_request_result(app_client, url, body_data)
  392. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  393. # get metadata
  394. check_state(app_client)
  395. url = 'retrieve'
  396. body_data = {'mode': 'all'}
  397. expect_file = 'retrieve_next_node_on_gpu.json'
  398. send_and_compare_result(app_client, url, body_data, expect_file)
  399. send_terminate_cmd(app_client)
  400. @pytest.mark.level0
  401. @pytest.mark.env_single
  402. @pytest.mark.platform_x86_cpu
  403. @pytest.mark.platform_arm_ascend_training
  404. @pytest.mark.platform_x86_gpu_training
  405. @pytest.mark.platform_x86_ascend_training
  406. @pytest.mark.parametrize("url, body_data, enable_recheck", [
  407. ('create_watchpoint',
  408. {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  409. 'watch_nodes': ['Default']}, True),
  410. ('create_watchpoint',
  411. {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  412. 'watch_nodes': ['Default/TransData-op99']}, True),
  413. ('update_watchpoint',
  414. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
  415. 'mode': 0}, True),
  416. ('update_watchpoint',
  417. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
  418. 'mode': 1}, True),
  419. ('update_watchpoint',
  420. [{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
  421. 'mode': 1},
  422. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
  423. 'mode': 0}
  424. ], True),
  425. ('update_watchpoint',
  426. [{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
  427. 'mode': 0},
  428. {'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
  429. 'mode': 1}
  430. ], True),
  431. ('delete_watchpoint', {'watch_point_id': 1}, True)
  432. ])
  433. def test_recheck_state(self, app_client, url, body_data, enable_recheck):
  434. """Test update watchpoint and check the value of enable_recheck."""
  435. with self._debugger_client.get_thread_instance():
  436. create_watchpoint_and_wait(app_client)
  437. if not isinstance(body_data, list):
  438. body_data = [body_data]
  439. for sub_body_data in body_data:
  440. res = get_request_result(app_client, url, sub_body_data, method='post')
  441. assert res['metadata']['enable_recheck'] is enable_recheck
  442. send_terminate_cmd(app_client)
  443. def test_get_conditions(self, app_client):
  444. """Test get conditions for gpu."""
  445. url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections'
  446. body_data = {}
  447. expect_file = 'get_conditions_for_gpu.json'
  448. with self._debugger_client.get_thread_instance():
  449. check_state(app_client)
  450. send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
  451. send_terminate_cmd(app_client)
  452. @pytest.mark.level0
  453. @pytest.mark.env_single
  454. @pytest.mark.platform_x86_cpu
  455. @pytest.mark.platform_arm_ascend_training
  456. @pytest.mark.platform_x86_gpu_training
  457. @pytest.mark.platform_x86_ascend_training
  458. def test_recheck(self, app_client):
  459. """Test recheck request."""
  460. with self._debugger_client.get_thread_instance():
  461. create_watchpoint_and_wait(app_client)
  462. # send recheck when disable to do recheck
  463. get_request_result(app_client, 'recheck', {}, method='post', expect_code=400)
  464. # send recheck when enable to do recheck
  465. create_watchpoint(app_client, {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}, 2)
  466. res = get_request_result(app_client, 'recheck', {}, method='post')
  467. assert res['metadata']['enable_recheck'] is False
  468. send_terminate_cmd(app_client)
  469. @pytest.mark.level0
  470. @pytest.mark.env_single
  471. @pytest.mark.platform_x86_cpu
  472. @pytest.mark.platform_arm_ascend_training
  473. @pytest.mark.platform_x86_gpu_training
  474. @pytest.mark.platform_x86_ascend_training
  475. @pytest.mark.parametrize("filter_condition, expect_file", [
  476. ({'name': 'fc', 'node_category': 'weight'}, 'search_weight.json'),
  477. ({'name': 'fc', 'node_category': 'gradient'}, 'search_gradient.json'),
  478. ({'node_category': 'activation'}, 'search_activation.json')
  479. ])
  480. def test_search_by_category(self, app_client, filter_condition, expect_file):
  481. """Test recheck request."""
  482. with self._debugger_client.get_thread_instance():
  483. check_state(app_client)
  484. send_and_compare_result(app_client, 'search', filter_condition, expect_file,
  485. method='get')
  486. send_terminate_cmd(app_client)
  487. class TestMultiGraphDebugger:
  488. """Test debugger on Ascend backend for multi_graph."""
  489. @classmethod
  490. def setup_class(cls):
  491. """Setup class."""
  492. cls._debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)
  493. @pytest.mark.level0
  494. @pytest.mark.env_single
  495. @pytest.mark.platform_x86_cpu
  496. @pytest.mark.platform_arm_ascend_training
  497. @pytest.mark.platform_x86_gpu_training
  498. @pytest.mark.platform_x86_ascend_training
  499. @pytest.mark.parametrize("body_data, expect_file", [
  500. ({'mode': 'all'}, 'multi_retrieve_all.json'),
  501. ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
  502. ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
  503. ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
  504. 'multi_retrieve_aggregation_scope_node.json'),
  505. ({'mode': 'node', 'params': {
  506. 'name': 'graph_0/Default/TransData-op99',
  507. 'single_node': True}}, 'multi_retrieve_single_node.json'),
  508. ({'mode': 'node', 'params': {
  509. 'name': 'Default/TransData-op99',
  510. 'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
  511. ])
  512. def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
  513. """Test retrieve when train_begin."""
  514. url = 'retrieve'
  515. with self._debugger_client.get_thread_instance():
  516. check_state(app_client)
  517. send_and_compare_result(app_client, url, body_data, expect_file)
  518. send_terminate_cmd(app_client)
  519. @pytest.mark.level0
  520. @pytest.mark.env_single
  521. @pytest.mark.platform_x86_cpu
  522. @pytest.mark.platform_arm_ascend_training
  523. @pytest.mark.platform_x86_gpu_training
  524. @pytest.mark.platform_x86_ascend_training
  525. @pytest.mark.parametrize("filter_condition, expect_file", [
  526. ({'name': '', 'node_category': 'weight'}, 'search_weight_multi_graph.json'),
  527. ({'node_category': 'activation'}, 'search_activation_multi_graph.json'),
  528. ({'node_category': 'gradient'}, 'search_gradient_multi_graph.json')
  529. ])
  530. def test_search_by_category_with_multi_graph(self, app_client, filter_condition, expect_file):
  531. """Test search by category request."""
  532. with self._debugger_client.get_thread_instance():
  533. check_state(app_client)
  534. send_and_compare_result(app_client, 'search', filter_condition, expect_file, method='get')
  535. send_terminate_cmd(app_client)
  536. @pytest.mark.level0
  537. @pytest.mark.env_single
  538. @pytest.mark.platform_x86_cpu
  539. @pytest.mark.platform_arm_ascend_training
  540. @pytest.mark.platform_x86_gpu_training
  541. @pytest.mark.platform_x86_ascend_training
  542. @pytest.mark.parametrize("filter_condition, expect_id", [
  543. ({'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  544. 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
  545. 'graph_name': 'graph_0'}, 1),
  546. ({'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  547. 'watch_nodes': ['graph_0/Default/optimizer-Momentum/ApplyMomentum[8]_1'],
  548. 'graph_name': None}, 1)
  549. ])
  550. def test_create_watchpoint(self, app_client, filter_condition, expect_id):
  551. """Test create watchpoint with multiple graphs."""
  552. url = 'create_watchpoint'
  553. with self._debugger_client.get_thread_instance():
  554. check_state(app_client)
  555. res = get_request_result(app_client, url, filter_condition)
  556. assert res.get('id') == expect_id
  557. send_terminate_cmd(app_client)
  558. @pytest.mark.level0
  559. @pytest.mark.env_single
  560. @pytest.mark.platform_x86_cpu
  561. @pytest.mark.platform_arm_ascend_training
  562. @pytest.mark.platform_x86_gpu_training
  563. @pytest.mark.platform_x86_ascend_training
  564. @pytest.mark.parametrize("params, expect_file", [
  565. ({'level': 'node'}, 'multi_next_node.json'),
  566. ({'level': 'node', 'node_name': 'graph_0/Default/TransData-op99'}, 'multi_next_node.json'),
  567. ({'level': 'node', 'node_name': 'Default/TransData-op99', 'graph_name': 'graph_0'},
  568. 'multi_next_node.json')
  569. ])
  570. def test_continue_on_gpu(self, app_client, params, expect_file):
  571. """Test get next node on GPU."""
  572. gpu_debugger_client = MockDebuggerClient(backend='GPU', graph_num=2)
  573. original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
  574. settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
  575. try:
  576. with gpu_debugger_client.get_thread_instance():
  577. check_state(app_client)
  578. # send run command to get watchpoint hit
  579. url = 'control'
  580. body_data = {'mode': 'continue'}
  581. body_data.update(params)
  582. res = get_request_result(app_client, url, body_data)
  583. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  584. # get metadata
  585. check_state(app_client)
  586. url = 'retrieve'
  587. body_data = {'mode': 'all'}
  588. send_and_compare_result(app_client, url, body_data, expect_file)
  589. send_terminate_cmd(app_client)
  590. finally:
  591. settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value
  592. @pytest.mark.level0
  593. @pytest.mark.env_single
  594. @pytest.mark.platform_x86_cpu
  595. @pytest.mark.platform_arm_ascend_training
  596. @pytest.mark.platform_x86_gpu_training
  597. @pytest.mark.platform_x86_ascend_training
  598. @pytest.mark.parametrize("body_data, expect_file", [
  599. ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_hits-0.json'),
  600. ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
  601. 'retrieve_tensor_hits-1.json')
  602. ])
  603. def test_retrieve_tensor_hits(self, app_client, body_data, expect_file):
  604. """Test retrieve tensor graph."""
  605. url = 'tensor-hits'
  606. with self._debugger_client.get_thread_instance():
  607. check_state(app_client)
  608. send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
  609. send_terminate_cmd(app_client)
  610. def create_watchpoint(app_client, condition, expect_id):
  611. """Create watchpoint."""
  612. url = 'create_watchpoint'
  613. body_data = {'condition': condition,
  614. 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7',
  615. 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias',
  616. 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias',
  617. 'Default/TransData-op99']}
  618. res = get_request_result(app_client, url, body_data)
  619. assert res.get('id') == expect_id
  620. def create_watchpoint_and_wait(app_client):
  621. """Preparation for recheck."""
  622. check_state(app_client)
  623. create_watchpoint(app_client, condition={'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  624. expect_id=1)
  625. # send run command to get watchpoint hit
  626. url = 'control'
  627. body_data = {'mode': 'continue',
  628. 'steps': 2}
  629. res = get_request_result(app_client, url, body_data)
  630. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  631. # wait for server has received watchpoint hit
  632. check_state(app_client)
  633. class TestMismatchDebugger:
  634. """Test debugger when Mindinsight and Mindspore is mismatched."""
  635. @classmethod
  636. def setup_class(cls):
  637. """Setup class."""
  638. cls._debugger_client = MockDebuggerClient(backend='Ascend', ms_version='1.0.0')
  639. @pytest.mark.level0
  640. @pytest.mark.env_single
  641. @pytest.mark.platform_x86_cpu
  642. @pytest.mark.platform_arm_ascend_training
  643. @pytest.mark.platform_x86_gpu_training
  644. @pytest.mark.platform_x86_ascend_training
  645. @pytest.mark.parametrize("body_data, expect_file", [
  646. ({'mode': 'all'}, 'version_mismatch.json')
  647. ])
  648. def test_retrieve_when_version_mismatch(self, app_client, body_data, expect_file):
  649. """Test retrieve when train_begin."""
  650. url = 'retrieve'
  651. with self._debugger_client.get_thread_instance():
  652. check_state(app_client, ServerStatus.MISMATCH.value)
  653. send_and_compare_result(app_client, url, body_data, expect_file)
  654. send_terminate_cmd(app_client)