test_restful_api.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
    Test query debugger restful api.
Usage:
    pytest tests/st/func/debugger/test_restful_api.py
"""
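# Illustrative invocations (standard pytest selection syntax; these examples are
# assumptions built from the classes and marks defined below, not part of any
# documented workflow -- adjust to the local environment):
#   pytest tests/st/func/debugger/test_restful_api.py::TestAscendDebugger -v
#   pytest tests/st/func/debugger/test_restful_api.py -m level0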
import os
from urllib.parse import quote

import pytest

from mindinsight.conf import settings
from mindinsight.debugger.common.utils import ServerStatus

from tests.st.func.debugger.conftest import DEBUGGER_BASE_URL
from tests.st.func.debugger.mock_ms_client import MockDebuggerClient
from tests.st.func.debugger.utils import check_state, get_request_result, \
    send_and_compare_result, send_and_save_result


def send_terminate_cmd(app_client):
    """Send terminate command to debugger client."""
    url = os.path.join(DEBUGGER_BASE_URL, 'control')
    body_data = {'mode': 'terminate'}
    send_and_compare_result(app_client, url, body_data)


class TestAscendDebugger:
    """Test debugger on Ascend backend."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls.save_results = False
        cls._debugger_client = MockDebuggerClient(backend='Ascend')

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_before_train_begin(self, app_client):
        """Test retrieve all."""
        url = 'retrieve'
        body_data = {'mode': 'all'}
        expect_file = 'before_train_begin.json'
        send_and_compare_result(app_client, url, body_data, expect_file)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'Default/optimizer-Momentum/Parameter[18]_7'}},
         'retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True}}, 'retrieve_single_node.json')
    ])
    def test_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve when train_begin."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    def test_get_conditions(self, app_client):
        """Test get conditions for ascend."""
        url = '/v1/mindinsight/debugger/sessions/0/condition-collections'
        body_data = {}
        expect_file = 'get_conditions_for_ascend.json'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file, method='get', full_url=True)
            send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'multi_retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
         'multi_retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'graph_0/Default/TransData-op99',
            'single_node': True}}, 'multi_retrieve_single_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
    ])
    def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve with multiple graphs when train begins."""
        url = 'retrieve'
        debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)
        with debugger_client.get_thread_instance():
            check_state(app_client)
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_create_and_delete_watchpoint(self, app_client):
        """Test create and delete watchpoint."""
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            conditions = [
                {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
                {'id': 'tensor_too_small', 'params': [{'name': 'max_lt', 'value': -1.0}]},
                {'id': 'tensor_too_large', 'params': [{'name': 'min_gt', 'value': 1e+32}]},
                {'id': 'tensor_too_small', 'params': [{'name': 'min_lt', 'value': -1e+32}]},
                {'id': 'tensor_too_large', 'params': [{'name': 'mean_gt', 'value': 0}]},
                {'id': 'tensor_too_small', 'params': [{'name': 'mean_lt', 'value': 0}]}
            ]
            for idx, condition in enumerate(conditions):
                create_watchpoint(app_client, condition, idx + 1)
            # delete the 4th watchpoint
            url = 'delete-watchpoint'
            body_data = {'watch_point_id': 4}
            get_request_result(app_client, url, body_data)
            # check the watchpoint list
            url = 'retrieve'
            body_data = {'mode': 'watchpoint'}
            expect_file = 'create_and_delete_watchpoint.json'
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_update_watchpoint(self, app_client):
        """Test update watchpoint."""
        watch_point_id = 1
        leaf_node_name = 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            condition = {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}
            create_watchpoint(app_client, condition, watch_point_id)
            # update the watch nodes of the watchpoint
            url = 'update-watchpoint'
            body_data = {'watch_point_id': watch_point_id,
                         'watch_nodes': [leaf_node_name],
                         'mode': 1}
            get_request_result(app_client, url, body_data)
            # get updated nodes
            url = 'search'
            body_data = {'name': leaf_node_name, 'watch_point_id': watch_point_id}
            expect_file = 'search_unwatched_leaf_node.json'
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file, method='get')
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_retrieve_tensor_value(self, app_client):
        """Test retrieve tensor value."""
        node_name = 'Default/TransData-op99'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            # prepare tensor value
            url = 'tensor-history'
            body_data = {'name': node_name, 'rank_id': 0}
            expect_file = 'retrieve_empty_tensor_history.json'
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file)
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check full tensor history from poll data
            res = get_request_result(
                app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('node_name') == node_name
            expect_file = 'retrieve_full_tensor_history.json'
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file)
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check tensor value
            url = 'tensors'
            body_data = {
                'name': node_name + ':0',
                'detail': 'data',
                'shape': quote('[1, 1:3]')
            }
            expect_file = 'retrieve_tensor_value.json'
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file, method='get')
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_compare_tensor_value(self, app_client):
        """Test compare tensor value."""
        node_name = 'Default/args0'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            # prepare tensor values
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': 2}
            get_request_result(app_client, url, body_data)
            check_state(app_client)
            get_request_result(
                app_client=app_client, url='tensor-history', body_data={'name': node_name, 'rank_id': 0})
            res = get_request_result(
                app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('node_name') == node_name
            # get compare results
            url = 'tensor-comparisons'
            body_data = {
                'name': node_name + ':0',
                'detail': 'data',
                'shape': quote('[:, :]'),
                'tolerance': 1,
                'rank_id': 0}
            expect_file = 'compare_tensors.json'
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file, method='get')
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_pause(self, app_client):
        """Test pause the training."""
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            # send run command to continue the training
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': -1}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
            # send pause command
            check_state(app_client, 'running')
            url = 'control'
            body_data = {'mode': 'pause'}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("url, body_data, enable_recheck", [
        ('create-watchpoint',
         {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['Default']}, True),
        ('update-watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'mode': 1}, True),
        ('update-watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
          'mode': 1}, True),
        ('delete-watchpoint', {}, True)
    ])
    def test_recheck(self, app_client, url, body_data, enable_recheck):
        """Test recheck."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            # send the watchpoint edit request and check whether recheck is enabled
            res = get_request_result(app_client, url, body_data, method='post')
            assert res['metadata']['enable_recheck'] is enable_recheck
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_recommend_watchpoints(self, app_client):
        """Test generating recommended watchpoints."""
        original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
        settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
        try:
            with self._debugger_client.get_thread_instance():
                check_state(app_client)
                url = 'retrieve'
                body_data = {'mode': 'watchpoint'}
                expect_file = 'recommended_watchpoints_at_startup.json'
                send_and_compare_result(app_client, url, body_data, expect_file, method='post')
                send_terminate_cmd(app_client)
        finally:
            settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_graph-0.json'),
        ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
         'retrieve_tensor_graph-1.json')
    ])
    def test_retrieve_tensor_graph(self, app_client, body_data, expect_file):
        """Test retrieve tensor graph."""
        url = 'tensor-graphs'
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            get_request_result(app_client, url, body_data, method='GET')
            # check that the tensor data has been received from poll data
            res = get_request_result(
                app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('tensor_name') == body_data.get('tensor_name')
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file, method='GET')
            send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
            send_terminate_cmd(app_client)


class TestGPUDebugger:
    """Test debugger on GPU backend."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls.save_results = False
        cls._debugger_client = MockDebuggerClient(backend='GPU')

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_next_node_on_gpu(self, app_client):
        """Test get next node on GPU."""
        gpu_debugger_client = MockDebuggerClient(backend='GPU')
        check_state(app_client, 'pending')
        with gpu_debugger_client.get_thread_instance():
            check_state(app_client)
            # send run command to execute to the next node
            url = 'control'
            body_data = {'mode': 'continue',
                         'level': 'node',
                         'name': 'Default/TransData-op99'}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
            # get metadata
            check_state(app_client)
            url = 'retrieve'
            body_data = {'mode': 'all'}
            expect_file = 'retrieve_next_node_on_gpu.json'
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("url, body_data, enable_recheck", [
        ('create-watchpoint',
         {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['Default']}, True),
        ('create-watchpoint',
         {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['Default/TransData-op99']}, True),
        ('update-watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'mode': 1}, True),
        ('update-watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
          'mode': 1}, True),
        ('update-watchpoint',
         [{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
           'mode': 1},
          {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
           'mode': 0}
          ], True),
        ('update-watchpoint',
         [{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
           'mode': 1},
          {'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
           'mode': 0}
          ], True),
        ('delete-watchpoint', {'watch_point_id': 1}, True)
    ])
    def test_recheck_state(self, app_client, url, body_data, enable_recheck):
        """Test update watchpoint and check the value of enable_recheck."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            if not isinstance(body_data, list):
                body_data = [body_data]
            for sub_body_data in body_data:
                res = get_request_result(app_client, url, sub_body_data, method='post')
            assert res['metadata']['enable_recheck'] is enable_recheck
            send_terminate_cmd(app_client)

    def test_get_conditions(self, app_client):
        """Test get conditions for gpu."""
        url = '/v1/mindinsight/debugger/sessions/0/condition-collections'
        body_data = {}
        expect_file = 'get_conditions_for_gpu.json'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file, method='get', full_url=True)
            send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_recheck(self, app_client):
        """Test recheck request."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            # send recheck while recheck is disabled
            get_request_result(app_client, 'recheck', {}, method='post', expect_code=400)
            # send recheck after recheck has been enabled
            create_watchpoint(app_client, {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}, 2)
            res = get_request_result(app_client, 'recheck', {}, method='post')
            assert res['metadata']['enable_recheck'] is False
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_file", [
        ({'name': 'fc', 'node_category': 'weight'}, 'search_weight.json'),
        ({'name': 'fc', 'node_category': 'gradient'}, 'search_gradient.json'),
        ({'node_category': 'activation'}, 'search_activation.json')
    ])
    def test_search_by_category(self, app_client, filter_condition, expect_file):
        """Test search by category request."""
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, 'search', filter_condition, expect_file,
                                    method='get')
            send_terminate_cmd(app_client)


class TestMultiGraphDebugger:
    """Test debugger on Ascend backend for multi_graph."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls.save_results = False
        cls._debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'multi_retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
         'multi_retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'graph_0/Default/TransData-op99',
            'single_node': True}}, 'multi_retrieve_single_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
    ])
    def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve with multiple graphs when train begins."""
        url = 'retrieve'
        check_state(app_client, 'pending')
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_file", [
        ({'name': '', 'node_category': 'weight'}, 'search_weight_multi_graph.json'),
        ({'node_category': 'activation'}, 'search_activation_multi_graph.json'),
        ({'node_category': 'gradient'}, 'search_gradient_multi_graph.json')
    ])
    def test_search_by_category_with_multi_graph(self, app_client, filter_condition, expect_file):
        """Test search by category request."""
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            if self.save_results:
                send_and_save_result(app_client, 'search', filter_condition, expect_file, method='get')
            send_and_compare_result(app_client, 'search', filter_condition, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_id", [
        ({'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'graph_name': 'graph_0'}, 1),
        ({'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['graph_0/Default/optimizer-Momentum/ApplyMomentum[8]_1'],
          'graph_name': None}, 1)
    ])
    def test_create_watchpoint(self, app_client, filter_condition, expect_id):
        """Test create watchpoint with multiple graphs."""
        url = 'create-watchpoint'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            res = get_request_result(app_client, url, filter_condition)
            assert res.get('id') == expect_id
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("params, expect_file", [
        ({'level': 'node'}, 'multi_next_node.json'),
        ({'level': 'node', 'node_name': 'graph_0/Default/TransData-op99'}, 'multi_next_node.json'),
        ({'level': 'node', 'node_name': 'Default/TransData-op99', 'graph_name': 'graph_0'},
         'multi_next_node.json')
    ])
    def test_continue_on_gpu(self, app_client, params, expect_file):
        """Test continue command on GPU with multiple graphs."""
        gpu_debugger_client = MockDebuggerClient(backend='GPU', graph_num=2)
        original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
        settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
        try:
            with gpu_debugger_client.get_thread_instance():
                check_state(app_client)
                # send run command to execute to the next node
                url = 'control'
                body_data = {'mode': 'continue'}
                body_data.update(params)
                res = get_request_result(app_client, url, body_data)
                assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
                # get metadata
                check_state(app_client)
                url = 'retrieve'
                body_data = {'mode': 'all'}
                if self.save_results:
                    send_and_save_result(app_client, url, body_data, expect_file)
                send_and_compare_result(app_client, url, body_data, expect_file)
                send_terminate_cmd(app_client)
        finally:
            settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_hits-0.json'),
        ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
         'retrieve_tensor_hits-1.json')
    ])
    def test_retrieve_tensor_hits(self, app_client, body_data, expect_file):
        """Test retrieve tensor hits."""
        url = 'tensor-hits'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            if self.save_results:
                send_and_save_result(app_client, url, body_data, expect_file, method='GET')
            send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
            send_terminate_cmd(app_client)


def create_watchpoint(app_client, condition, expect_id):
    """Create watchpoint."""
    url = 'create-watchpoint'
    body_data = {'condition': condition,
                 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7',
                                 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias',
                                 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias',
                                 'Default/TransData-op99']}
    res = get_request_result(app_client, url, body_data)
    assert res.get('id') == expect_id


def create_watchpoint_and_wait(app_client):
    """Preparation for recheck."""
    check_state(app_client)
    create_watchpoint(app_client, condition={'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
                      expect_id=1)
    # send run command to get watchpoint hit
    url = 'control'
    body_data = {'mode': 'continue',
                 'steps': 2}
    res = get_request_result(app_client, url, body_data)
    assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
    # wait until the server has received the watchpoint hit
    check_state(app_client)


class TestMismatchDebugger:
    """Test debugger when the MindInsight and MindSpore versions are mismatched."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='Ascend', ms_version='1.0.0')

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'version_mismatch.json')
    ])
    def test_retrieve_when_version_mismatch(self, app_client, body_data, expect_file):
        """Test retrieve when the MindSpore version is mismatched."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            check_state(app_client, ServerStatus.MISMATCH.value)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)