|
|
|
@@ -23,9 +23,10 @@ import os |
|
|
|
import pytest |
|
|
|
|
|
|
|
from mindinsight.conf import settings |
|
|
|
from mindinsight.debugger.common.utils import ServerStatus |
|
|
|
from tests.st.func.debugger.conftest import DEBUGGER_BASE_URL |
|
|
|
from tests.st.func.debugger.mock_ms_client import MockDebuggerClient |
|
|
|
from tests.st.func.debugger.utils import check_waiting_state, get_request_result, \ |
|
|
|
from tests.st.func.debugger.utils import check_state, get_request_result, \ |
|
|
|
send_and_compare_result |
|
|
|
|
|
|
|
|
|
|
|
@@ -77,7 +78,7 @@ class TestAscendDebugger: |
|
|
|
"""Test retrieve when train_begin.""" |
|
|
|
url = 'retrieve' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file) |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
|
|
|
|
@@ -87,7 +88,7 @@ class TestAscendDebugger: |
|
|
|
body_data = {} |
|
|
|
expect_file = 'get_conditions_for_ascend.json' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True) |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
|
|
|
|
@@ -115,7 +116,7 @@ class TestAscendDebugger: |
|
|
|
url = 'retrieve' |
|
|
|
debugger_client = MockDebuggerClient(backend='Ascend', graph_num=2) |
|
|
|
with debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file) |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
|
|
|
|
@@ -128,7 +129,7 @@ class TestAscendDebugger: |
|
|
|
def test_create_and_delete_watchpoint(self, app_client): |
|
|
|
"""Test create and delete watchpoint.""" |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
conditions = [ |
|
|
|
{'id': 'max_gt', 'params': [{'name': 'param', 'value': 1.0}]}, |
|
|
|
{'id': 'max_lt', 'params': [{'name': 'param', 'value': -1.0}]}, |
|
|
|
@@ -165,7 +166,7 @@ class TestAscendDebugger: |
|
|
|
watch_point_id = 1 |
|
|
|
leaf_node_name = 'Default/optimizer-Momentum/Parameter[18]_7/moments.fc3.bias' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
condition = {'id': 'inf', 'params': []} |
|
|
|
create_watchpoint(app_client, condition, watch_point_id) |
|
|
|
# update watchpoint watchpoint list |
|
|
|
@@ -219,7 +220,7 @@ class TestAscendDebugger: |
|
|
|
"""Test retrieve tensor value.""" |
|
|
|
node_name = 'Default/TransData-op99' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
# prepare tensor value |
|
|
|
url = 'retrieve_tensor_history' |
|
|
|
body_data = {'name': node_name} |
|
|
|
@@ -252,13 +253,13 @@ class TestAscendDebugger: |
|
|
|
"""Test compare tensor value.""" |
|
|
|
node_name = 'Default/args0' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
# prepare tensor values |
|
|
|
url = 'control' |
|
|
|
body_data = {'mode': 'continue', |
|
|
|
'steps': 2} |
|
|
|
get_request_result(app_client, url, body_data) |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
get_request_result( |
|
|
|
app_client=app_client, url='retrieve_tensor_history', body_data={'name': node_name}) |
|
|
|
res = get_request_result( |
|
|
|
@@ -289,7 +290,7 @@ class TestAscendDebugger: |
|
|
|
def test_retrieve_bfs_node(self, app_client, body_data, expect_file): |
|
|
|
"""Test retrieve bfs node.""" |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
# prepare tensor values |
|
|
|
url = 'retrieve_node_by_bfs' |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file, method='get') |
|
|
|
@@ -304,7 +305,7 @@ class TestAscendDebugger: |
|
|
|
def test_pause(self, app_client): |
|
|
|
"""Test pause the training.""" |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
# send run command to execute to next node |
|
|
|
url = 'control' |
|
|
|
body_data = {'mode': 'continue', |
|
|
|
@@ -357,7 +358,7 @@ class TestAscendDebugger: |
|
|
|
settings.ENABLE_RECOMMENDED_WATCHPOINTS = True |
|
|
|
try: |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
url = 'retrieve' |
|
|
|
body_data = {'mode': 'watchpoint'} |
|
|
|
expect_file = 'recommended_watchpoints_at_startup.json' |
|
|
|
@@ -409,7 +410,7 @@ class TestGPUDebugger: |
|
|
|
"""Test get next node on GPU.""" |
|
|
|
gpu_debugger_client = MockDebuggerClient(backend='GPU') |
|
|
|
with gpu_debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
# send run command to get watchpoint hit |
|
|
|
url = 'control' |
|
|
|
body_data = {'mode': 'continue', |
|
|
|
@@ -418,7 +419,7 @@ class TestGPUDebugger: |
|
|
|
res = get_request_result(app_client, url, body_data) |
|
|
|
assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} |
|
|
|
# get metadata |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
url = 'retrieve' |
|
|
|
body_data = {'mode': 'all'} |
|
|
|
expect_file = 'retrieve_next_node_on_gpu.json' |
|
|
|
@@ -475,7 +476,7 @@ class TestGPUDebugger: |
|
|
|
body_data = {} |
|
|
|
expect_file = 'get_conditions_for_gpu.json' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file, method='get', full_url=True) |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
|
|
|
|
@@ -512,7 +513,7 @@ class TestGPUDebugger: |
|
|
|
def test_search_by_category(self, app_client, filter_condition, expect_file): |
|
|
|
"""Test recheck request.""" |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
send_and_compare_result(app_client, 'search', filter_condition, expect_file, |
|
|
|
method='get') |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
@@ -549,7 +550,7 @@ class TestMultiGraphDebugger: |
|
|
|
"""Test retrieve when train_begin.""" |
|
|
|
url = 'retrieve' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file) |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
|
|
|
|
@@ -567,7 +568,7 @@ class TestMultiGraphDebugger: |
|
|
|
def test_search_by_category_with_multi_graph(self, app_client, filter_condition, expect_file): |
|
|
|
"""Test search by category request.""" |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
send_and_compare_result(app_client, 'search', filter_condition, expect_file, method='get') |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
|
|
|
|
@@ -589,7 +590,7 @@ class TestMultiGraphDebugger: |
|
|
|
"""Test create watchpoint with multiple graphs.""" |
|
|
|
url = 'create_watchpoint' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
res = get_request_result(app_client, url, filter_condition) |
|
|
|
assert res.get('id') == expect_id |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
@@ -613,7 +614,7 @@ class TestMultiGraphDebugger: |
|
|
|
settings.ENABLE_RECOMMENDED_WATCHPOINTS = True |
|
|
|
try: |
|
|
|
with gpu_debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
# send run command to get watchpoint hit |
|
|
|
url = 'control' |
|
|
|
body_data = {'mode': 'continue'} |
|
|
|
@@ -621,7 +622,7 @@ class TestMultiGraphDebugger: |
|
|
|
res = get_request_result(app_client, url, body_data) |
|
|
|
assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} |
|
|
|
# get metadata |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
url = 'retrieve' |
|
|
|
body_data = {'mode': 'all'} |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file) |
|
|
|
@@ -644,7 +645,7 @@ class TestMultiGraphDebugger: |
|
|
|
"""Test retrieve tensor graph.""" |
|
|
|
url = 'tensor-hits' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file, method='GET') |
|
|
|
send_terminate_cmd(app_client) |
|
|
|
|
|
|
|
@@ -663,7 +664,7 @@ def create_watchpoint(app_client, condition, expect_id): |
|
|
|
|
|
|
|
def create_watchpoint_and_wait(app_client): |
|
|
|
"""Preparation for recheck.""" |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
create_watchpoint(app_client, condition={'id': 'inf', 'params': []}, expect_id=1) |
|
|
|
# send run command to get watchpoint hit |
|
|
|
url = 'control' |
|
|
|
@@ -672,7 +673,7 @@ def create_watchpoint_and_wait(app_client): |
|
|
|
res = get_request_result(app_client, url, body_data) |
|
|
|
assert res == {'metadata': {'state': 'running', 'enable_recheck': False}} |
|
|
|
# wait for server has received watchpoint hit |
|
|
|
check_waiting_state(app_client) |
|
|
|
check_state(app_client) |
|
|
|
|
|
|
|
class TestMismatchDebugger: |
|
|
|
"""Test debugger when Mindinsight and Mindspore is mismatched.""" |
|
|
|
@@ -695,5 +696,6 @@ class TestMismatchDebugger: |
|
|
|
"""Test retrieve when train_begin.""" |
|
|
|
url = 'retrieve' |
|
|
|
with self._debugger_client.get_thread_instance(): |
|
|
|
check_state(app_client, ServerStatus.MISMATCH.value) |
|
|
|
send_and_compare_result(app_client, url, body_data, expect_file) |
|
|
|
send_terminate_cmd(app_client) |