
test_restful_api.py 15 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
    Test query debugger restful api.
Usage:
    pytest tests/st/func/debugger/test_restful_api.py
"""
import os

import pytest

from tests.st.func.debugger.conftest import DEBUGGER_BASE_URL
from tests.st.func.debugger.mock_ms_client import MockDebuggerClient
from tests.st.func.debugger.utils import check_waiting_state, get_request_result, \
    send_and_compare_result


class TestAscendDebugger:
    """Test debugger on Ascend backend."""
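    # Each test case starts a mock debugger client in a background thread,
    # waits until the backend reaches the 'waiting' state, exercises the
    # restful api, and finally sends a terminate command to shut it down.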

    @classmethod
    def setup_class(cls):
        """Setup class."""
        cls._debugger_client = MockDebuggerClient(backend='Ascend')

    @staticmethod
    def _send_terminate_cmd(app_client):
        """Send terminate command to debugger client."""
        url = os.path.join(DEBUGGER_BASE_URL, 'control')
        body_data = {'mode': 'terminate'}
        send_and_compare_result(app_client, url, body_data)

    @staticmethod
    def _create_watchpoint(app_client, condition, expect_id):
        """Create watchpoint."""
        url = 'create_watchpoint'
        body_data = {'condition': condition,
                     'watch_nodes': ['Default/optimizer-Momentum/Parameter[18]_7',
                                     'Default/TransData-op99']}
        res = get_request_result(app_client, url, body_data)
        assert res.get('id') == expect_id

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_before_train_begin(self, app_client):
        """Test retrieve all."""
        url = 'retrieve'
        body_data = {'mode': 'all'}
        expect_file = 'before_train_begin.json'
        send_and_compare_result(app_client, url, body_data, expect_file)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'mode': 'all'}, 'retrieve_all.json'),
        ({'mode': 'node', 'params': {'name': 'Default'}}, 'retrieve_scope_node.json'),
        ({'mode': 'node', 'params': {'name': 'Default/optimizer-Momentum/Parameter[18]_7'}},
         'retrieve_aggregation_scope_node.json'),
        ({'mode': 'node', 'params': {
            'name': 'Default/TransData-op99',
            'single_node': True}}, 'retrieve_single_node.json'),
        ({'mode': 'watchpoint_hit'}, 'retrieve_empty_watchpoint_hit_list')
    ])
    def test_retrieve_when_train_begin(self, app_client, body_data, expect_file):
        """Test retrieve when train_begin."""
        url = 'retrieve'
        with self._debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
            send_and_compare_result(app_client, url, body_data, expect_file)
            self._send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_create_and_delete_watchpoint(self, app_client):
        """Test create and delete watchpoint."""
        with self._debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
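            # create one watchpoint for each supported condition;
            # watchpoint ids are assigned sequentially starting from 1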
            conditions = [
                {'condition': 'MAX_GT', 'param': 1.0},
                {'condition': 'MAX_LT', 'param': -1.0},
                {'condition': 'MIN_GT', 'param': 1e+32},
                {'condition': 'MIN_LT', 'param': -1e+32},
                {'condition': 'MAX_MIN_GT', 'param': 0},
                {'condition': 'MAX_MIN_LT', 'param': 0},
                {'condition': 'MEAN_GT', 'param': 0},
                {'condition': 'MEAN_LT', 'param': 0},
                {'condition': 'INF'},
                {'condition': 'OVERFLOW'},
            ]
            for idx, condition in enumerate(conditions):
                self._create_watchpoint(app_client, condition, idx + 1)
            # delete the 4th watchpoint
            url = 'delete_watchpoint'
            body_data = {'watch_point_id': 4}
            get_request_result(app_client, url, body_data)
            # check the watchpoint list
            url = 'retrieve'
            body_data = {'mode': 'watchpoint'}
            expect_file = 'create_and_delete_watchpoint.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            self._send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_update_watchpoint(self, app_client):
        """Test update watchpoint."""
        watch_point_id = 1
        leaf_node_name = 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias'
        with self._debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
            condition = {'condition': 'INF'}
            self._create_watchpoint(app_client, condition, watch_point_id)
            # update the watch nodes of the watchpoint
            url = 'update_watchpoint'
            body_data = {'watch_point_id': watch_point_id,
                         'watch_nodes': [leaf_node_name],
                         'mode': 0}
            get_request_result(app_client, url, body_data)
            # get updated nodes
            url = 'search'
            body_data = {'name': leaf_node_name, 'watch_point_id': watch_point_id}
            expect_file = 'search_unwatched_leaf_node.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            self._send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_watchpoint_hit(self, app_client):
        """Test retrieve watchpoint hit."""
        with self._debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
            self._create_watchpoint(app_client, condition={'condition': 'INF'}, expect_id=1)
            # send run command to trigger a watchpoint hit
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': 2}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'running'}}
            # wait until the server has received the watchpoint hit
            flag = check_waiting_state(app_client)
            assert flag is True
            # check the watchpoint hit list
            url = 'retrieve'
            body_data = {'mode': 'watchpoint_hit'}
            expect_file = 'retrieve_watchpoint_hit.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check a single watchpoint hit
            body_data = {
                'mode': 'watchpoint_hit',
                'params': {
                    'name': 'Default/TransData-op99',
                    'single_node': True,
                    'watch_point_id': 1
                }
            }
            expect_file = 'retrieve_single_watchpoint_hit.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            self._send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_retrieve_tensor_value(self, app_client):
        """Test retrieve tensor value."""
        node_name = 'Default/TransData-op99'
        with self._debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
            # prepare tensor value
            url = 'retrieve_tensor_history'
            body_data = {'name': node_name}
            expect_file = 'retrieve_empty_tensor_history.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check full tensor history from poll data
            res = get_request_result(
                app_client=app_client, url='poll_data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('node_name') == node_name
            expect_file = 'retrieve_full_tensor_history.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            # check tensor value
            url = 'tensors'
            body_data = {
                'name': node_name + ':0',
                'detail': 'data',
                'shape': '[1, 1:3]'
            }
            expect_file = 'retrieve_tensor_value.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            self._send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_compare_tensor_value(self, app_client):
        """Test compare tensor value."""
        node_name = 'Default/args0'
        with self._debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
            # prepare tensor values
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': 2}
            get_request_result(app_client, url, body_data)
            flag = check_waiting_state(app_client)
            assert flag is True
            get_request_result(
                app_client=app_client, url='retrieve_tensor_history', body_data={'name': node_name})
            res = get_request_result(
                app_client=app_client, url='poll_data', body_data={'pos': 0}, method='get')
            assert res.get('receive_tensor', {}).get('node_name') == node_name
            # get compare results
            url = 'tensor-comparisons'
            body_data = {
                'name': node_name + ':0',
                'detail': 'data',
                'shape': '[:, :]',
                'tolerance': 1
            }
            expect_file = 'compare_tensors.json'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            self._send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    @pytest.mark.parametrize("body_data, expect_file", [
        ({'ascend': True}, 'retrieve_node_by_bfs_ascend.json'),
        ({'name': 'Default/args0', 'ascend': False}, 'retrieve_node_by_bfs.json')
    ])
    def test_retrieve_bfs_node(self, app_client, body_data, expect_file):
        """Test retrieve bfs node."""
        with self._debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
            # retrieve the node in BFS order
            url = 'retrieve_node_by_bfs'
            send_and_compare_result(app_client, url, body_data, expect_file, method='get')
            self._send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_next_node_on_gpu(self, app_client):
        """Test get next node on GPU."""
        gpu_debugger_client = MockDebuggerClient(backend='GPU')
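        # node-level stepping ('level': 'node') is exercised against the GPU backend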
        with gpu_debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
            # send run command to execute to the next node
            url = 'control'
            body_data = {'mode': 'continue',
                         'level': 'node',
                         'name': 'Default/TransData-op99'}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'running'}}
            # get metadata
            flag = check_waiting_state(app_client)
            assert flag is True
            url = 'retrieve'
            body_data = {'mode': 'all'}
            expect_file = 'retrieve_next_node_on_gpu.json'
            send_and_compare_result(app_client, url, body_data, expect_file)
            self._send_terminate_cmd(app_client)

    @pytest.mark.level0
    @pytest.mark.env_single
    @pytest.mark.platform_x86_cpu
    @pytest.mark.platform_arm_ascend_training
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.platform_x86_ascend_training
    def test_pause(self, app_client):
        """Test pause the training."""
        with self._debugger_client.get_thread_instance():
            flag = check_waiting_state(app_client)
            assert flag is True
            # send run command to continue the training
            url = 'control'
            body_data = {'mode': 'continue',
                         'steps': -1}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'running'}}
            # send pause command
            url = 'control'
            body_data = {'mode': 'pause'}
            res = get_request_result(app_client, url, body_data)
            assert res == {'metadata': {'state': 'waiting'}}
            self._send_terminate_cmd(app_client)