Browse Source

!1001 enable recheck on step 0

From: @yelihua
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
6610462fa4
5 changed files with 31 additions and 14 deletions
  1. +3
    -4
      mindinsight/debugger/stream_cache/tensor.py
  2. +1
    -1
      mindinsight/debugger/stream_handler/metadata_handler.py
  3. +20
    -9
      mindinsight/debugger/stream_handler/watchpoint_handler.py
  4. +5
    -0
      mindinsight/debugger/stream_operator/training_control_operator.py
  5. +2
    -0
      tests/st/func/debugger/mock_ms_client.py

+ 3
- 4
mindinsight/debugger/stream_cache/tensor.py View File

@@ -247,8 +247,7 @@ class OpTensor(BaseTensor):

class ConstTensor(BaseTensor):
"""Tensor data structure for Const Node."""
STRING_TYPE = 'DT_STRING'
BOOL_TYPE = 'DT_BOOL'
_STRING_TYPE = 'DT_STRING'

def __init__(self, const_proto):
# the type of const_proto is NamedValueProto
@@ -298,7 +297,7 @@ class ConstTensor(BaseTensor):
if field_name != 'dtype':
tensor_value = field_value
break
if tensor_value and self.dtype != self.STRING_TYPE:
if tensor_value and self.dtype != self._STRING_TYPE:
tensor_value = np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(self.dtype))
return tensor_value

@@ -323,7 +322,7 @@ class ConstTensor(BaseTensor):
Returns:
dict, overall statistics.
"""
if self.empty or self.dtype == self.STRING_TYPE:
if self.empty or self.dtype == self._STRING_TYPE:
return {}
stats = TensorUtils.get_statistics_from_tensor(self.value)
statistics = TensorUtils.get_overall_statistic_dict(stats)


+ 1
- 1
mindinsight/debugger/stream_handler/metadata_handler.py View File

@@ -109,7 +109,7 @@ class MetadataHandler(StreamHandlerBase):
@property
def enable_recheck(self):
"""The property of enable_recheck."""
return self._enable_recheck and self._state == ServerStatus.WAITING and self._step > 0
return self._enable_recheck and self._state == ServerStatus.WAITING

@enable_recheck.setter
def enable_recheck(self, value):


+ 20
- 9
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -557,15 +557,7 @@ def validate_watch_condition_params(condition_mgr, watch_condition):
raise DebuggerParamValueError("Invalid name of parameter.")

condition_param = condition.get_parameter_definition(condition_param_name)
if condition_param.type.name in (ValueTypeEnum.FLOAT64.name, ValueTypeEnum.INT64.name) \
and not isinstance(param.get("value"), (float, int)):
log.error("Number param should be given for condition: %s", condition_id)
raise DebuggerParamValueError("Number param should be given.")

if condition_param.type.name == ValueTypeEnum.BOOL.name \
and not isinstance(param.get("value"), bool):
log.error("Bool param should be given for condition: %s", condition_id)
raise DebuggerParamValueError("Bool param should be given.")
validate_param_type(condition_id, condition_param, param)

if not condition_param.is_valid(param.get("value")):
log.error("Param %s out of range for condition: %s", condition_param_name, condition_id)
@@ -596,6 +588,25 @@ def validate_watch_condition_params(condition_mgr, watch_condition):
raise DebuggerParamValueError("Invalid support params.")


def validate_param_type(condition_id, condition_param, param):
    """
    Validate parameter type.

    Args:
        condition_id (str): Condition id. Should be in WATCHPOINT_CONDITION_MAPPING.
        condition_param (ConditionParameter): Condition Parameter object.
        param (dict): Condition parameter value.

    Raises:
        DebuggerParamValueError: If the given value does not match the type
            declared by `condition_param` (number or bool).
    """
    value = param.get("value")
    type_name = condition_param.type.name
    number_type_names = (ValueTypeEnum.FLOAT64.name, ValueTypeEnum.INT64.name)
    # `bool` is a subclass of `int`, so `isinstance(True, (float, int))` is True;
    # exclude it explicitly so that True/False are not accepted as numbers.
    if type_name in number_type_names \
            and (isinstance(value, bool) or not isinstance(value, (float, int))):
        log.error("Number param should be given for condition: %s", condition_id)
        raise DebuggerParamValueError("Number param should be given.")
    if type_name == ValueTypeEnum.BOOL.name and not isinstance(value, bool):
        log.error("Bool param should be given for condition: %s", condition_id)
        raise DebuggerParamValueError("Bool param should be given.")


def set_default_param(condition_mgr, watch_condition):
"""
Set default param.


+ 5
- 0
mindinsight/debugger/stream_operator/training_control_operator.py View File

@@ -99,7 +99,12 @@ class TrainingControlOperator:
try:
self._validate_continue_params(params)
event = self._construct_run_event(params)
# whether a recheck needs to be sent before continuing, especially for initialization watchpoints
recheck_flag = bool(self._metadata_stream.step == 0 and self._watchpoint_stream.is_recheckable())
self._send_watchpoints()
if recheck_flag:
self._cache_store.put_command(self._construct_run_event({'level': 'recheck'}))
log.info("Send recheck command for initialization watchpoints before continue command.")
self._cache_store.put_command(event)
except MindInsightException as err:
log.error("Failed to send run event.")


+ 2
- 0
tests/st/func/debugger/mock_ms_client.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""Mocked MindSpore debugger client."""
from threading import Thread
from time import sleep

import grpc
import numpy as np
@@ -77,6 +78,7 @@ class MockDebuggerClient:
wait_flag = True
while self.flag and wait_flag:
if self._step > total_steps:
sleep(0.5)
self.send_metadata_cmd(training_done=True)
return
wait_flag = self._wait_cmd()


Loading…
Cancel
Save