
test_restful_api.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
    Test query debugger restful api.
Usage:
    pytest tests/st/func/debugger/test_restful_api.py
"""

import os
from urllib.parse import quote

import pytest

from mindinsight.conf import settings
from mindinsight.debugger.common.utils import ServerStatus

from tests.st.func.debugger.conftest import DEBUGGER_BASE_URL
from tests.st.func.debugger.mock_ms_client import MockDebuggerClient
from tests.st.func.debugger.utils import check_state, get_request_result, \
    send_and_compare_result


def send_terminate_cmd(app_client):
    """Send terminate command to debugger client."""
    url = os.path.join(DEBUGGER_BASE_URL, 'control')
    body_data = {'mode': 'terminate'}
    send_and_compare_result(app_client, url, body_data)
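
# Most tests below follow the same flow: start a MockDebuggerClient thread, wait for the
# backend via check_state(), send restful requests to the debugger server, compare the
# responses with the JSON files named by `expect_file`, and finally send the terminate
# command so the mock training session ends cleanly.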


class TestAscendDebugger:
    """Test debugger on Ascend backend."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='Ascend')

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_before_train_begin(self, app_client):
        """Test retrieve all."""
        url = 'retrieve'
        body_data = {'mode': 'all'}
        expect_file = 'before_train_begin.json'
        send_and_compare_result(app_client, url, body_data, expect_file)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'Default/optimizer-Momentum/Parameter[18]_7'}},
         'retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True}}, 'retrieve_single_node.json')
    ])
    def test_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve when train_begin."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    def test_get_conditions(self, app_client):
        """Test get conditions for Ascend."""
        url = '/v1/mindinsight/debugger/sessions/0/condition-collections'
        body_data = {}
        expect_file = 'get_conditions_for_ascend.json'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'multi_retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
         'multi_retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'graph_0/Default/TransData-op99',
            'single_node': True}}, 'multi_retrieve_single_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
    ])
    def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve with multiple graphs when train_begin."""
        url = 'retrieve'
        debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)
        with debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_create_and_delete_watchpoint(self, app_client):
        """Test create and delete watchpoint."""
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            conditions = [
                {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
                {'id': 'tensor_too_small', 'params': [{'name': 'max_lt', 'value': -1.0}]},
                {'id': 'tensor_too_large', 'params': [{'name': 'min_gt', 'value': 1e+32}]},
                {'id': 'tensor_too_small', 'params': [{'name': 'min_lt', 'value': -1e+32}]},
                {'id': 'tensor_too_large', 'params': [{'name': 'mean_gt', 'value': 0}]},
                {'id': 'tensor_too_small', 'params': [{'name': 'mean_lt', 'value': 0}]}
            ]
            for idx, condition in enumerate(conditions):
                create_watchpoint(app_client, condition, idx + 1)
            # delete the 4th watchpoint
            url = 'delete-watchpoint'
            body_data = {'watch_point_id': 4}
            get_request_result(app_client, url, body_data)
            # check the watchpoint list
            url = 'retrieve'
            body_data = {'mode': 'watchpoint'}
            expect_file = 'create_and_delete_watchpoint.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_update_watchpoint(self, app_client):
        """Test update watchpoint."""
        watch_point_id = 1
        leaf_node_name = 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            condition = {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}
            create_watchpoint(app_client, condition, watch_point_id)
            # update the watch nodes of the watchpoint
            url = 'update-watchpoint'
            body_data = {'watch_point_id': watch_point_id,
                         'watch_nodes': [leaf_node_name],
                         'mode': 1}
            get_request_result(app_client, url, body_data)
            # get updated nodes
            url = 'search'
            body_data = {'name': leaf_node_name, 'watch_point_id': watch_point_id}
            expect_file = 'search_unwatched_leaf_node.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_retrieve_tensor_value(self, app_client):
        """Test retrieve tensor value."""
        node_name = 'Default/TransData-op99'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            # prepare tensor value
            url = 'tensor-history'
            body_data = {'name': node_name, 'rank_id': 0}
            expect_file = 'retrieve_empty_tensor_history.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check full tensor history from poll data
            res = get_request_result(
                app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('node_name') == node_name
            expect_file = 'retrieve_full_tensor_history.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check tensor value
            url = 'tensors'
            body_data = {
                'name': node_name + ':0',
                'detail': 'data',
                'shape': quote('[1, 1:3]')
            }
            expect_file = 'retrieve_tensor_value.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_compare_tensor_value(self, app_client):
        """Test compare tensor value."""
        node_name = 'Default/args0'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            # prepare tensor values
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': 2}
            get_request_result(app_client, url, body_data)
            check_state(app_client)
            get_request_result(
                app_client=app_client, url='tensor-history', body_data={'name': node_name, 'rank_id': 0})
            res = get_request_result(
                app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('node_name') == node_name
            # get compare results
            url = 'tensor-comparisons'
            body_data = {
                'name': node_name + ':0',
                'detail': 'data',
                'shape': quote('[:, :]'),
                'tolerance': 1,
                'rank_id': 0}
            expect_file = 'compare_tensors.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_pause(self, app_client):
        """Test pause the training."""
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            # send run command to resume the training
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': -1}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
            # send pause command
            check_state(app_client, 'running')
            url = 'control'
            body_data = {'mode': 'pause'}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("url, body_data, enable_recheck", [
        ('create-watchpoint',
         {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['Default']}, True),
        ('update-watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'mode': 1}, True),
        ('update-watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
          'mode': 1}, True),
        ('delete-watchpoint', {}, True)
    ])
    def test_recheck(self, app_client, url, body_data, enable_recheck):
        """Test recheck."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            # modify the watchpoints and check whether recheck is enabled
            res = get_request_result(app_client, url, body_data, method='post')
            assert res['metadata']['enable_recheck'] is enable_recheck
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_recommend_watchpoints(self, app_client):
        """Test generating recommended watchpoints."""
        original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
        settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
        try:
            with self._debugger_client.get_thread_instance():
                check_state(app_client)
                url = 'retrieve'
                body_data = {'mode': 'watchpoint'}
                expect_file = 'recommended_watchpoints_at_startup.json'
                send_and_compare_result(app_client, url, body_data, expect_file, method='post')
                send_terminate_cmd(app_client)
        finally:
            settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_graph-0.json'),
        ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
         'retrieve_tensor_graph-1.json')
    ])
    def test_retrieve_tensor_graph(self, app_client, body_data, expect_file):
        """Test retrieve tensor graph."""
        url = 'tensor-graphs'
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            get_request_result(app_client, url, body_data, method='GET')
            # check the tensor value has been received from poll data
            res = get_request_result(
                app_client=app_client, url='poll-data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('tensor_name') == body_data.get('tensor_name')
            send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
            send_terminate_cmd(app_client)


class TestGPUDebugger:
    """Test debugger on GPU backend."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='GPU')

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_next_node_on_gpu(self, app_client):
        """Test get next node on GPU."""
        gpu_debugger_client = MockDebuggerClient(backend='GPU')
        with gpu_debugger_client.get_thread_instance():
            check_state(app_client)
            # send run command to execute to the specified node
            url = 'control'
            body_data = {'mode': 'continue',
                         'level': 'node',
                         'name': 'Default/TransData-op99'}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
            # get metadata
            check_state(app_client)
            url = 'retrieve'
            body_data = {'mode': 'all'}
            expect_file = 'retrieve_next_node_on_gpu.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("url, body_data, enable_recheck", [
        ('create-watchpoint',
         {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['Default']}, True),
        ('create-watchpoint',
         {'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['Default/TransData-op99']}, True),
        ('update-watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'mode': 1}, True),
        ('update-watchpoint',
         {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
          'mode': 1}, True),
        ('update-watchpoint',
         [{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
           'mode': 1},
          {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'],
           'mode': 0}
          ], True),
        ('update-watchpoint',
         [{'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
           'mode': 1},
          {'watch_point_id': 1, 'watch_nodes': ['Default/TransData-op99'],
           'mode': 0}
          ], True),
        ('delete-watchpoint', {'watch_point_id': 1}, True)
    ])
    def test_recheck_state(self, app_client, url, body_data, enable_recheck):
        """Test update watchpoint and check the value of enable_recheck."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            if not isinstance(body_data, list):
                body_data = [body_data]
            for sub_body_data in body_data:
                res = get_request_result(app_client, url, sub_body_data, method='post')
                assert res['metadata']['enable_recheck'] is enable_recheck
            send_terminate_cmd(app_client)

    def test_get_conditions(self, app_client):
        """Test get conditions for GPU."""
        url = '/v1/mindinsight/debugger/sessions/0/condition-collections'
        body_data = {}
        expect_file = 'get_conditions_for_gpu.json'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_recheck(self, app_client):
        """Test recheck request."""
        with self._debugger_client.get_thread_instance():
            create_watchpoint_and_wait(app_client)
            # send recheck while recheck is disabled; expect a 400 response
            get_request_result(app_client, 'recheck', {}, method='post', expect_code=400)
            # create another watchpoint so recheck becomes enabled, then send recheck
            create_watchpoint(app_client, {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]}, 2)
            res = get_request_result(app_client, 'recheck', {}, method='post')
            assert res['metadata']['enable_recheck'] is False
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_file", [
        ({'name': 'fc', 'node_category': 'weight'}, 'search_weight.json'),
        ({'name': 'fc', 'node_category': 'gradient'}, 'search_gradient.json'),
        ({'node_category': 'activation'}, 'search_activation.json')
    ])
    def test_search_by_category(self, app_client, filter_condition, expect_file):
        """Test search nodes by category."""
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, 'search', filter_condition, expect_file,
                                    method='get')
            send_terminate_cmd(app_client)


class TestMultiGraphDebugger:
    """Test debugger on Ascend backend with multiple graphs."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'multi_retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default', 'graph_name': 'graph_1'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0'}}, 'multi_retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'graph_0/Default/optimizer-Momentum/Parameter[18]_7'}},
         'multi_retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'graph_0/Default/TransData-op99',
            'single_node': True}}, 'multi_retrieve_single_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True, 'graph_name': 'graph_0'}}, 'retrieve_single_node.json')
    ])
    def test_multi_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve with multiple graphs when train_begin."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_file", [
        ({'name': '', 'node_category': 'weight'}, 'search_weight_multi_graph.json'),
        ({'node_category': 'activation'}, 'search_activation_multi_graph.json'),
        ({'node_category': 'gradient'}, 'search_gradient_multi_graph.json')
    ])
    def test_search_by_category_with_multi_graph(self, app_client, filter_condition, expect_file):
        """Test search by category request."""
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, 'search', filter_condition, expect_file, method='get')
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("filter_condition, expect_id", [
        ({'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7'],
          'graph_name': 'graph_0'}, 1),
        ({'condition': {'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
          'watch_nodes': ['graph_0/Default/optimizer-Momentum/ApplyMomentum[8]_1'],
          'graph_name': None}, 1)
    ])
    def test_create_watchpoint(self, app_client, filter_condition, expect_id):
        """Test create watchpoint with multiple graphs."""
        url = 'create-watchpoint'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            res = get_request_result(app_client, url, filter_condition)
            assert res.get('id') == expect_id
            send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("params, expect_file", [
        ({'level': 'node'}, 'multi_next_node.json'),
        ({'level': 'node', 'node_name': 'graph_0/Default/TransData-op99'}, 'multi_next_node.json'),
        ({'level': 'node', 'node_name': 'Default/TransData-op99', 'graph_name': 'graph_0'},
         'multi_next_node.json')
    ])
    def test_continue_on_gpu(self, app_client, params, expect_file):
        """Test the continue command on GPU with multiple graphs."""
        gpu_debugger_client = MockDebuggerClient(backend='GPU', graph_num=2)
        original_value = settings.ENABLE_RECOMMENDED_WATCHPOINTS
        settings.ENABLE_RECOMMENDED_WATCHPOINTS = True
        try:
            with gpu_debugger_client.get_thread_instance():
                check_state(app_client)
                # send run command to execute to the next node
                url = 'control'
                body_data = {'mode': 'continue'}
                body_data.update(params)
                res = get_request_result(app_client, url, body_data)
                assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
                # get metadata
                check_state(app_client)
                url = 'retrieve'
                body_data = {'mode': 'all'}
                send_and_compare_result(app_client, url, body_data, expect_file)
                send_terminate_cmd(app_client)
        finally:
            settings.ENABLE_RECOMMENDED_WATCHPOINTS = original_value

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'tensor_name': 'Default/TransData-op99:0', 'graph_name': 'graph_0'}, 'retrieve_tensor_hits-0.json'),
        ({'tensor_name': 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias:0', 'graph_name': 'graph_0'},
         'retrieve_tensor_hits-1.json')
    ])
    def test_retrieve_tensor_hits(self, app_client, body_data, expect_file):
        """Test retrieve tensor hits."""
        url = 'tensor-hits'
        with self._debugger_client.get_thread_instance():
            check_state(app_client)
            send_and_compare_result(app_client, url, body_data, expect_file, method='GET')
            send_terminate_cmd(app_client)


def create_watchpoint(app_client, condition, expect_id):
    """Create watchpoint."""
    url = 'create-watchpoint'
    body_data = {'condition': condition,
                 'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7',
                                 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias',
                                 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc1.bias',
                                 'Default/TransData-op99']}
    res = get_request_result(app_client, url, body_data)
    assert res.get('id') == expect_id


def create_watchpoint_and_wait(app_client):
    """Preparation for recheck: create a watchpoint, run the training and wait for a hit."""
    check_state(app_client)
    create_watchpoint(app_client, condition={'id': 'tensor_too_large', 'params': [{'name': 'max_gt', 'value': 1.0}]},
                      expect_id=1)
    # send run command to get a watchpoint hit
    url = 'control'
    body_data = {'mode': 'continue',
                 'steps': 2}
    res = get_request_result(app_client, url, body_data)
    assert res == {'metadata': {'state': 'sending', 'enable_recheck': False}}
    # wait until the server has received the watchpoint hit
    check_state(app_client)
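
# A minimal sketch of how a test method in one of the classes above could reuse these
# helpers (hypothetical example, not part of the suite; the test name is a placeholder):
#
#     def test_recheck_example(self, app_client):
#         """Check that enable_recheck is reported after a watchpoint hit."""
#         with self._debugger_client.get_thread_instance():
#             create_watchpoint_and_wait(app_client)
#             res = get_request_result(app_client, 'recheck', {}, method='post')
#             assert 'enable_recheck' in res['metadata']
#             send_terminate_cmd(app_client)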


class TestMismatchDebugger:
    """Test debugger when the MindInsight and MindSpore versions are mismatched."""

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='Ascend', ms_version='1.0.0')

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'version_mismatch.json')
    ])
    def test_retrieve_when_version_mismatch(self, app_client, body_data, expect_file):
        """Test retrieve when the MindSpore version is mismatched."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            check_state(app_client, ServerStatus.MISMATCH.value)
            send_and_compare_result(app_client, url, body_data, expect_file)
            send_terminate_cmd(app_client)