You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_restful_api.py 31 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
  17. Test query debugger restful api.
  18. Usage:
  19. pytest tests/st/func/debugger/test_restful_api.py
  20. """
import os
import posixpath

import pytest

from mindinsight.conf import settings
from mindinsight.debugger.common.utils import ServerStatus

from tests.st.func.debugger.conftest import DEBUGGER_BASE_URL
from tests.st.func.debugger.mock_ms_client import MockDebuggerClient
from tests.st.func.debugger.utils import check_state, get_request_result, \
    send_and_compare_result
  29. def send_terminate_cmd(app_client):
  30. """Send terminate command to debugger client."""
  31. url = os.path.join(DEBUGGER_BASE_URL, 'control')
  32. body_data = {'mode': 'terminate'}
  33. send_and_compare_result(app_client, url, body_data)
  34. class TestAscendDebugger:
  35. """Test debugger on Ascend backend."""
  36. @classmethod
  37. def setup_class(cls):
  38. """Setup class."""
  39. cls._debugger_client = MockDebuggerClient(backend='Ascend')
  40. @pytest.mark.level0
  41. @pytest.mark.env_single
  42. @pytest.mark.platform_x86_cpu
  43. @pytest.mark.platform_arm_ascend_training
  44. @pytest.mark.platform_x86_gpu_training
  45. @pytest.mark.platform_x86_ascend_training
  46. def test_before_train_begin(self, app_client):
  47. """Test retrieve all."""
  48. url = 'retrieve'
  49. body_data = {'mode': 'all'}
  50. expect_file = 'before_train_begin.json'
  51. send_and_compare_result(app_client, url, body_data, expect_file)
  52. @pytest.mark.level0
  53. @pytest.mark.env_single
  54. @pytest.mark.platform_x86_cpu
  55. @pytest.mark.platform_arm_ascend_training
  56. @pytest.mark.platform_x86_gpu_training
  57. @pytest.mark.platform_x86_ascend_training
  58. @pytest.mark.parametrize("body_data, expect_file", [
  59. ({'mode': 'all'}, 'retrieve_all.json'),
  60. ({'mode': 'node', 'params': {'name': 'Default'}}, 'retrieve_scope_node.json'),
  61. ({'mode': 'node', 'params': {'name': 'Default/optimizer-Momentum/Parameter[18]_7'}},
  62. 'retrieve_aggregation_scope_node.json'),
  63. ({'mode': 'node', 'params': {
  64. 'name': 'Default/TransData-op99',
  65. 'single_node': True}}, 'retrieve_single_node.json')
  66. ])
  67. def test_retrieve_when_train_begin(self, app_client, body_data, expect_file):
  68. """Test retrieve when train_begin."""
  69. url = 'retrieve'
  70. with self._debugger_client.get_thread_instance():
  71. check_state(app_client)
  72. send_and_compare_result(app_client, url, body_data, expect_file)
  73. send_terminate_cmd(app_client)
  74. def test_get_conditions(self, app_client):
  75. """Test get conditions for ascend."""
  76. url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections'
  77. body_data = {}
  78. expect_file = 'get_conditions_for_ascend.json'
  79. with self._debugger_client.get_thread_instance():
  80. check_state(app_client)
  81. send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
  82. send_terminate_cmd(app_client)
  83. @pytest.mark.level0
  84. @pytest.mark.env_single
  85. @pytest.mark.platform_x86_cpu
  86. @pytest.mark.platform_arm_ascend_training
  87. @pytest.mark.platform_x86_gpu_training
  88. @pytest.mark.platform_x86_ascend_training
  89. @pytest.mark.parametrize("body_data, expect_file", [
  90. ({'mode': 'all'}, 'multi_retrieve_all.json'),
  91. ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
  92. ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
  93. ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
  94. 'multi_retrieve_aggregation_scope_node.json'),
  95. ({'mode': 'node', 'params': {
  96. 'name': 'graph_0/Default/TransData-op99',
  97. 'single_node': True}}, 'multi_retrieve_single_node.json'),
  98. ({'mode': 'node', 'params': {
  99. 'name': 'Default/TransData-op99',
  100. 'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
  101. ])
  102. def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
  103. """Test retrieve when train_begin."""
  104. url = 'retrieve'
  105. debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)
  106. with debugger_client.get_thread_instance():
  107. check_state(app_client)
  108. send_and_compare_result(app_client, url, body_data, expect_file)
  109. send_terminate_cmd(app_client)
  110. @pytest.mark.level0
  111. @pytest.mark.env_single
  112. @pytest.mark.platform_x86_cpu
  113. @pytest.mark.platform_arm_ascend_training
  114. @pytest.mark.platform_x86_gpu_training
  115. @pytest.mark.platform_x86_ascend_training
  116. def test_create_and_delete_watchpoint(self, app_client):
  117. """Test create and delete watchpoint."""
  118. with self._debugger_client.get_thread_instance():
  119. check_state(app_client)
  120. conditions = [
  121. {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  122. {'id': 'tensor_too_small', 'params': [{'name': 'max_lt', 'value': -1.0}]},
  123. {'id': 'tensor_too_large', 'params': [{'name': 'min_gt', 'value': 1e+32}]},
  124. {'id': 'tensor_too_small', 'params': [{'name': 'min_lt', 'value': -1e+32}]},
  125. {'id': 'tensor_too_large', 'params': [{'name': 'mean_gt', 'value': 0}]},
  126. {'id': 'tensor_too_small', 'params': [{'name': 'mean_lt', 'value': 0}]}
  127. ]
  128. for idx, condition in enumerate(conditions):
  129. create_watchpoint(app_client, condition, idx + 1)
  130. # delete 4-th watchpoint
  131. url = 'delete-watchpoint'
  132. body_data = {'watch_point_id': 4}
  133. get_request_result(app_client, url, body_data)
  134. # test watchpoint list
  135. url = 'retrieve'
  136. body_data = {'mode': 'watchpoint'}
  137. expect_file = 'create_and_delete_watchpoint.json'
  138. send_and_compare_result(app_client, url, body_data, expect_file)
  139. send_terminate_cmd(app_client)
  140. @pytest.mark.level0
  141. @pytest.mark.env_single
  142. @pytest.mark.platform_x86_cpu
  143. @pytest.mark.platform_arm_ascend_training
  144. @pytest.mark.platform_x86_gpu_training
  145. @pytest.mark.platform_x86_ascend_training
  146. def test_update_watchpoint(self, app_client):
  147. """Test retrieve when train_begin."""
  148. watch_point_id = 1
  149. leaf_node_name = 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias'
  150. with self._debugger_client.get_thread_instance():
  151. check_state(app_client)
  152. condition = {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}
  153. create_watchpoint(app_client, condition, watch_point_id)
  154. # update watchpoint watchpoint list
  155. url = 'update-watchpoint'
  156. body_data = {'watch_point_id': watch_point_id,
  157. 'watch_nodes': [leaf_node_name],
  158. 'mode': 1}
  159. get_request_result(app_client, url, body_data)
  160. # get updated nodes
  161. url = 'search'
  162. body_data = {'name': leaf_node_name, 'watch_point_id': watch_point_id}
  163. expect_file = 'search_unwatched_leaf_node.json'
  164. send_and_compare_result(app_client, url, body_data, expect_file, method='get')
  165. send_terminate_cmd(app_client)
  166. @pytest.mark.level0
  167. @pytest.mark.env_single
  168. @pytest.mark.platform_x86_cpu
  169. @pytest.mark.platform_arm_ascend_training
  170. @pytest.mark.platform_x86_gpu_training
  171. @pytest.mark.platform_x86_ascend_training
  172. def test_retrieve_tensor_value(self, app_client):
  173. """Test retrieve tensor value."""
  174. node_name = 'Default/TransData-op99'
  175. with self._debugger_client.get_thread_instance():
  176. check_state(app_client)
  177. # prepare tensor value
  178. url = 'tensor-history'
  179. body_data = {'name': node_name}
  180. expect_file = 'retrieve_empty_tensor_history.json'
  181. send_and_compare_result(app_client, url, body_data, expect_file)
  182. # check full tensor history from poll data
  183. res = get_request_result(
  184. app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
  185. assert res.get('receive_tensor', {}).get('node_name') == node_name
  186. expect_file = 'retrieve_full_tensor_history.json'
  187. send_and_compare_result(app_client, url, body_data, expect_file)
  188. # check tensor value
  189. url = 'tensors'
  190. body_data = {
  191. 'name': node_name + ':0',
  192. 'detail': 'data',
  193. 'shape': '[1, 1:3]'
  194. }
  195. expect_file = 'retrieve_tensor_value.json'
  196. send_and_compare_result(app_client, url, body_data, expect_file, method='get')
  197. send_terminate_cmd(app_client)
  198. @pytest.mark.level0
  199. @pytest.mark.env_single
  200. @pytest.mark.platform_x86_cpu
  201. @pytest.mark.platform_arm_ascend_training
  202. @pytest.mark.platform_x86_gpu_training
  203. @pytest.mark.platform_x86_ascend_training
  204. def test_compare_tensor_value(self, app_client):
  205. """Test compare tensor value."""
  206. node_name = 'Default/args0'
  207. with self._debugger_client.get_thread_instance():
  208. check_state(app_client)
  209. # prepare tensor values
  210. url = 'control'
  211. body_data = {'mode': 'continue',
  212. 'steps': 2}
  213. get_request_result(app_client, url, body_data)
  214. check_state(app_client)
  215. get_request_result(
  216. app_client=app_client, url='tensor-history', body_data={'name': node_name})
  217. res = get_request_result(
  218. app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
  219. assert res.get('receive_tensor', {}).get('node_name') == node_name
  220. # get compare results
  221. url = 'tensor-comparisons'
  222. body_data = {
  223. 'name': node_name + ':0',
  224. 'detail': 'data',
  225. 'shape': '[:, :]',
  226. 'tolerance': 1
  227. }
  228. expect_file = 'compare_tensors.json'
  229. send_and_compare_result(app_client, url, body_data, expect_file, method='get')
  230. send_terminate_cmd(app_client)
  231. @pytest.mark.level0
  232. @pytest.mark.env_single
  233. @pytest.mark.platform_x86_cpu
  234. @pytest.mark.platform_arm_ascend_training
  235. @pytest.mark.platform_x86_gpu_training
  236. @pytest.mark.platform_x86_ascend_training
  237. @pytest.mark.parametrize("body_data, expect_file", [
  238. ({'ascend': True}, 'retrieve_node_by_bfs_ascend.json'),
  239. ({'name': 'Default/args0', 'ascend': False}, 'retrieve_node_by_bfs.json')
  240. ])
  241. def test_retrieve_bfs_node(self, app_client, body_data, expect_file):
  242. """Test retrieve bfs node."""
  243. with self._debugger_client.get_thread_instance():
  244. check_state(app_client)
  245. # prepare tensor values
  246. url = 'retrieve_node_by_bfs'
  247. send_and_compare_result(app_client, url, body_data, expect_file, method='get')
  248. send_terminate_cmd(app_client)
  249. @pytest.mark.level0
  250. @pytest.mark.env_single
  251. @pytest.mark.platform_x86_cpu
  252. @pytest.mark.platform_arm_ascend_training
  253. @pytest.mark.platform_x86_gpu_training
  254. @pytest.mark.platform_x86_ascend_training
  255. def test_pause(self, app_client):
  256. """Test pause the training."""
  257. with self._debugger_client.get_thread_instance():
  258. check_state(app_client)
  259. # send run command to execute to next node
  260. url = 'control'
  261. body_data = {'mode': 'continue',
  262. 'steps': -1}
  263. res = get_request_result(app_client, url, body_data)
  264. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  265. # send pause command
  266. check_state(app_client, 'running')
  267. url = 'control'
  268. body_data = {'mode': 'pause'}
  269. res = get_request_result(app_client, url, body_data)
  270. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  271. send_terminate_cmd(app_client)
  272. @pytest.mark.level0
  273. @pytest.mark.env_single
  274. @pytest.mark.platform_x86_cpu
  275. @pytest.mark.platform_arm_ascend_training
  276. @pytest.mark.platform_x86_gpu_training
  277. @pytest.mark.platform_x86_ascend_training
  278. @pytest.mark.parametrize("url, body_data, enable_recheck", [
  279. ('create-watchpoint',
  280. {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  281. 'watch_nodes': ['Default']}, True),
  282. ('update-watchpoint',
  283. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
  284. 'mode': 1}, True),
  285. ('update-watchpoint',
  286. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
  287. 'mode': 1}, True),
  288. ('delete-watchpoint', {}, True)
  289. ])
  290. def test_recheck(self, app_client, url, body_data, enable_recheck):
  291. """Test recheck."""
  292. with self._debugger_client.get_thread_instance():
  293. create_watchpoint_and_wait(app_client)
  294. # create watchpoint
  295. res = get_request_result(app_client, url, body_data, method='post')
  296. assert res['metadata']['enable_recheck'] is enable_recheck
  297. send_terminate_cmd(app_client)
  298. @pytest.mark.level0
  299. @pytest.mark.env_single
  300. @pytest.mark.platform_x86_cpu
  301. @pytest.mark.platform_arm_ascend_training
  302. @pytest.mark.platform_x86_gpu_training
  303. @pytest.mark.platform_x86_ascend_training
  304. def test_recommend_watchpoints(self, app_client):
  305. """Test generating recommended watchpoints."""
  306. original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
  307. settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
  308. try:
  309. with self._debugger_client.get_thread_instance():
  310. check_state(app_client)
  311. url = 'retrieve'
  312. body_data = {'mode': 'watchpoint'}
  313. expect_file = 'recommended_watchpoints_at_startup.json'
  314. send_and_compare_result(app_client, url, body_data, expect_file, method='post')
  315. send_terminate_cmd(app_client)
  316. finally:
  317. settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value
  318. @pytest.mark.level0
  319. @pytest.mark.env_single
  320. @pytest.mark.platform_x86_cpu
  321. @pytest.mark.platform_arm_ascend_training
  322. @pytest.mark.platform_x86_gpu_training
  323. @pytest.mark.platform_x86_ascend_training
  324. @pytest.mark.parametrize("body_data, expect_file", [
  325. ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_graph-0.json'),
  326. ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
  327. 'retrieve_tensor_graph-1.json')
  328. ])
  329. def test_retrieve_tensor_graph(self, app_client, body_data, expect_file):
  330. """Test retrieve tensor graph."""
  331. url = 'tensor-graphs'
  332. with self._debugger_client.get_thread_instance():
  333. create_watchpoint_and_wait(app_client)
  334. get_request_result(app_client, url, body_data, method='GET')
  335. # check full tensor history from poll data
  336. res = get_request_result(
  337. app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
  338. assert res.get('receive_tensor', {}).get('tensor_name') == body_data.get('tensor_name')
  339. send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
  340. send_terminate_cmd(app_client)
  341. class TestGPUDebugger:
  342. """Test debugger on Ascend backend."""
  343. @classmethod
  344. def setup_class(cls):
  345. """Setup class."""
  346. cls._debugger_client = MockDebuggerClient(backend='GPU')
  347. @pytest.mark.level0
  348. @pytest.mark.env_single
  349. @pytest.mark.platform_x86_cpu
  350. @pytest.mark.platform_arm_ascend_training
  351. @pytest.mark.platform_x86_gpu_training
  352. @pytest.mark.platform_x86_ascend_training
  353. def test_next_node_on_gpu(self, app_client):
  354. """Test get next node on GPU."""
  355. gpu_debugger_client = MockDebuggerClient(backend='GPU')
  356. with gpu_debugger_client.get_thread_instance():
  357. check_state(app_client)
  358. # send run command to get watchpoint hit
  359. url = 'control'
  360. body_data = {'mode': 'continue',
  361. 'level': 'node',
  362. 'name': 'Default/TransData-op99'}
  363. res = get_request_result(app_client, url, body_data)
  364. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  365. # get metadata
  366. check_state(app_client)
  367. url = 'retrieve'
  368. body_data = {'mode': 'all'}
  369. expect_file = 'retrieve_next_node_on_gpu.json'
  370. send_and_compare_result(app_client, url, body_data, expect_file)
  371. send_terminate_cmd(app_client)
  372. @pytest.mark.level0
  373. @pytest.mark.env_single
  374. @pytest.mark.platform_x86_cpu
  375. @pytest.mark.platform_arm_ascend_training
  376. @pytest.mark.platform_x86_gpu_training
  377. @pytest.mark.platform_x86_ascend_training
  378. @pytest.mark.parametrize("url, body_data, enable_recheck", [
  379. ('create-watchpoint',
  380. {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  381. 'watch_nodes': ['Default']}, True),
  382. ('create-watchpoint',
  383. {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  384. 'watch_nodes': ['Default/TransData-op99']}, True),
  385. ('update-watchpoint',
  386. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
  387. 'mode': 1}, True),
  388. ('update-watchpoint',
  389. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
  390. 'mode': 1}, True),
  391. ('update-watchpoint',
  392. [{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
  393. 'mode': 1},
  394. {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
  395. 'mode': 0}
  396. ], True),
  397. ('update-watchpoint',
  398. [{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
  399. 'mode': 1},
  400. {'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
  401. 'mode': 0}
  402. ], True),
  403. ('delete-watchpoint', {'watch_point_id': 1}, True)
  404. ])
  405. def test_recheck_state(self, app_client, url, body_data, enable_recheck):
  406. """Test update watchpoint and check the value of enable_recheck."""
  407. with self._debugger_client.get_thread_instance():
  408. create_watchpoint_and_wait(app_client)
  409. if not isinstance(body_data, list):
  410. body_data = [body_data]
  411. for sub_body_data in body_data:
  412. res = get_request_result(app_client, url, sub_body_data, method='post')
  413. assert res['metadata']['enable_recheck'] is enable_recheck
  414. send_terminate_cmd(app_client)
  415. def test_get_conditions(self, app_client):
  416. """Test get conditions for gpu."""
  417. url = '/v1/mindinsight/conditionmgr/train-jobs/train-id/condition-collections'
  418. body_data = {}
  419. expect_file = 'get_conditions_for_gpu.json'
  420. with self._debugger_client.get_thread_instance():
  421. check_state(app_client)
  422. send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
  423. send_terminate_cmd(app_client)
  424. @pytest.mark.level0
  425. @pytest.mark.env_single
  426. @pytest.mark.platform_x86_cpu
  427. @pytest.mark.platform_arm_ascend_training
  428. @pytest.mark.platform_x86_gpu_training
  429. @pytest.mark.platform_x86_ascend_training
  430. def test_recheck(self, app_client):
  431. """Test recheck request."""
  432. with self._debugger_client.get_thread_instance():
  433. create_watchpoint_and_wait(app_client)
  434. # send recheck when disable to do recheck
  435. get_request_result(app_client, 'recheck', {}, method='post', expect_code=400)
  436. # send recheck when enable to do recheck
  437. create_watchpoint(app_client, {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}, 2)
  438. res = get_request_result(app_client, 'recheck', {}, method='post')
  439. assert res['metadata']['enable_recheck'] is False
  440. send_terminate_cmd(app_client)
  441. @pytest.mark.level0
  442. @pytest.mark.env_single
  443. @pytest.mark.platform_x86_cpu
  444. @pytest.mark.platform_arm_ascend_training
  445. @pytest.mark.platform_x86_gpu_training
  446. @pytest.mark.platform_x86_ascend_training
  447. @pytest.mark.parametrize("filter_condition, expect_file", [
  448. ({'name': 'fc', 'node_category': 'weight'}, 'search_weight.json'),
  449. ({'name': 'fc', 'node_category': 'gradient'}, 'search_gradient.json'),
  450. ({'node_category': 'activation'}, 'search_activation.json')
  451. ])
  452. def test_search_by_category(self, app_client, filter_condition, expect_file):
  453. """Test recheck request."""
  454. with self._debugger_client.get_thread_instance():
  455. check_state(app_client)
  456. send_and_compare_result(app_client, 'search', filter_condition, expect_file,
  457. method='get')
  458. send_terminate_cmd(app_client)
  459. class TestMultiGraphDebugger:
  460. """Test debugger on Ascend backend for multi_graph."""
  461. @classmethod
  462. def setup_class(cls):
  463. """Setup class."""
  464. cls._debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)
  465. @pytest.mark.level0
  466. @pytest.mark.env_single
  467. @pytest.mark.platform_x86_cpu
  468. @pytest.mark.platform_arm_ascend_training
  469. @pytest.mark.platform_x86_gpu_training
  470. @pytest.mark.platform_x86_ascend_training
  471. @pytest.mark.parametrize("body_data, expect_file", [
  472. ({'mode': 'all'}, 'multi_retrieve_all.json'),
  473. ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
  474. ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
  475. ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
  476. 'multi_retrieve_aggregation_scope_node.json'),
  477. ({'mode': 'node', 'params': {
  478. 'name': 'graph_0/Default/TransData-op99',
  479. 'single_node': True}}, 'multi_retrieve_single_node.json'),
  480. ({'mode': 'node', 'params': {
  481. 'name': 'Default/TransData-op99',
  482. 'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
  483. ])
  484. def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
  485. """Test retrieve when train_begin."""
  486. url = 'retrieve'
  487. with self._debugger_client.get_thread_instance():
  488. check_state(app_client)
  489. send_and_compare_result(app_client, url, body_data, expect_file)
  490. send_terminate_cmd(app_client)
  491. @pytest.mark.level0
  492. @pytest.mark.env_single
  493. @pytest.mark.platform_x86_cpu
  494. @pytest.mark.platform_arm_ascend_training
  495. @pytest.mark.platform_x86_gpu_training
  496. @pytest.mark.platform_x86_ascend_training
  497. @pytest.mark.parametrize("filter_condition, expect_file", [
  498. ({'name': '', 'node_category': 'weight'}, 'search_weight_multi_graph.json'),
  499. ({'node_category': 'activation'}, 'search_activation_multi_graph.json'),
  500. ({'node_category': 'gradient'}, 'search_gradient_multi_graph.json')
  501. ])
  502. def test_search_by_category_with_multi_graph(self, app_client, filter_condition, expect_file):
  503. """Test search by category request."""
  504. with self._debugger_client.get_thread_instance():
  505. check_state(app_client)
  506. send_and_compare_result(app_client, 'search', filter_condition, expect_file, method='get')
  507. send_terminate_cmd(app_client)
  508. @pytest.mark.level0
  509. @pytest.mark.env_single
  510. @pytest.mark.platform_x86_cpu
  511. @pytest.mark.platform_arm_ascend_training
  512. @pytest.mark.platform_x86_gpu_training
  513. @pytest.mark.platform_x86_ascend_training
  514. @pytest.mark.parametrize("filter_condition, expect_id", [
  515. ({'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  516. 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
  517. 'graph_name': 'graph_0'}, 1),
  518. ({'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  519. 'watch_nodes': ['graph_0/Default/optimizer-Momentum/ApplyMomentum[8]_1'],
  520. 'graph_name': None}, 1)
  521. ])
  522. def test_create_watchpoint(self, app_client, filter_condition, expect_id):
  523. """Test create watchpoint with multiple graphs."""
  524. url = 'create-watchpoint'
  525. with self._debugger_client.get_thread_instance():
  526. check_state(app_client)
  527. res = get_request_result(app_client, url, filter_condition)
  528. assert res.get('id') == expect_id
  529. send_terminate_cmd(app_client)
  530. @pytest.mark.level0
  531. @pytest.mark.env_single
  532. @pytest.mark.platform_x86_cpu
  533. @pytest.mark.platform_arm_ascend_training
  534. @pytest.mark.platform_x86_gpu_training
  535. @pytest.mark.platform_x86_ascend_training
  536. @pytest.mark.parametrize("params, expect_file", [
  537. ({'level': 'node'}, 'multi_next_node.json'),
  538. ({'level': 'node', 'node_name': 'graph_0/Default/TransData-op99'}, 'multi_next_node.json'),
  539. ({'level': 'node', 'node_name': 'Default/TransData-op99', 'graph_name': 'graph_0'},
  540. 'multi_next_node.json')
  541. ])
  542. def test_continue_on_gpu(self, app_client, params, expect_file):
  543. """Test get next node on GPU."""
  544. gpu_debugger_client = MockDebuggerClient(backend='GPU', graph_num=2)
  545. original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
  546. settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
  547. try:
  548. with gpu_debugger_client.get_thread_instance():
  549. check_state(app_client)
  550. # send run command to get watchpoint hit
  551. url = 'control'
  552. body_data = {'mode': 'continue'}
  553. body_data.update(params)
  554. res = get_request_result(app_client, url, body_data)
  555. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  556. # get metadata
  557. check_state(app_client)
  558. url = 'retrieve'
  559. body_data = {'mode': 'all'}
  560. send_and_compare_result(app_client, url, body_data, expect_file)
  561. send_terminate_cmd(app_client)
  562. finally:
  563. settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value
  564. @pytest.mark.level0
  565. @pytest.mark.env_single
  566. @pytest.mark.platform_x86_cpu
  567. @pytest.mark.platform_arm_ascend_training
  568. @pytest.mark.platform_x86_gpu_training
  569. @pytest.mark.platform_x86_ascend_training
  570. @pytest.mark.parametrize("body_data, expect_file", [
  571. ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_hits-0.json'),
  572. ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
  573. 'retrieve_tensor_hits-1.json')
  574. ])
  575. def test_retrieve_tensor_hits(self, app_client, body_data, expect_file):
  576. """Test retrieve tensor graph."""
  577. url = 'tensor-hits'
  578. with self._debugger_client.get_thread_instance():
  579. check_state(app_client)
  580. send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
  581. send_terminate_cmd(app_client)
  582. def create_watchpoint(app_client, condition, expect_id):
  583. """Create watchpoint."""
  584. url = 'create-watchpoint'
  585. body_data = {'condition': condition,
  586. 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7',
  587. 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias',
  588. 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias',
  589. 'Default/TransData-op99']}
  590. res = get_request_result(app_client, url, body_data)
  591. assert res.get('id') == expect_id
  592. def create_watchpoint_and_wait(app_client):
  593. """Preparation for recheck."""
  594. check_state(app_client)
  595. create_watchpoint(app_client, condition={'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
  596. expect_id=1)
  597. # send run command to get watchpoint hit
  598. url = 'control'
  599. body_data = {'mode': 'continue',
  600. 'steps': 2}
  601. res = get_request_result(app_client, url, body_data)
  602. assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
  603. # wait for server has received watchpoint hit
  604. check_state(app_client)
  605. class TestMismatchDebugger:
  606. """Test debugger when Mindinsight and Mindspore is mismatched."""
  607. @classmethod
  608. def setup_class(cls):
  609. """Setup class."""
  610. cls._debugger_client = MockDebuggerClient(backend='Ascend', ms_version='1.0.0')
  611. @pytest.mark.level0
  612. @pytest.mark.env_single
  613. @pytest.mark.platform_x86_cpu
  614. @pytest.mark.platform_arm_ascend_training
  615. @pytest.mark.platform_x86_gpu_training
  616. @pytest.mark.platform_x86_ascend_training
  617. @pytest.mark.parametrize("body_data, expect_file", [
  618. ({'mode': 'all'}, 'version_mismatch.json')
  619. ])
  620. def test_retrieve_when_version_mismatch(self, app_client, body_data, expect_file):
  621. """Test retrieve when train_begin."""
  622. url = 'retrieve'
  623. with self._debugger_client.get_thread_instance():
  624. check_state(app_client, ServerStatus.MISMATCH.value)
  625. send_and_compare_result(app_client, url, body_data, expect_file)
  626. send_terminate_cmd(app_client)