Browse Source

!1080 fix sigmoid function and add reluv2 in recommended watchpoint condtions

From: @jiang-shuqiang
Reviewed-by: @wenkai_dist,@lilongfei15
Signed-off-by: @lilongfei15
pull/1080/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
7f69f76e4e
3 changed files with 8 additions and 3 deletions
  1. +1
    -0
      mindinsight/debugger/conditionmgr/condition.py
  2. +3
    -2
      mindinsight/debugger/conditionmgr/recommender.py
  3. +4
    -1
      mindinsight/debugger/debugger_server.py

+ 1
- 0
mindinsight/debugger/conditionmgr/condition.py View File

@@ -110,6 +110,7 @@ class ActivationFuncEnum(Enum):
TANH = 'tanh' TANH = 'tanh'
SIGMOID = 'sigmoid' SIGMOID = 'sigmoid'
RELU = 'relu' RELU = 'relu'
RELUV2 = 'reluv2'




class ConditionContext: class ConditionContext:


+ 3
- 2
mindinsight/debugger/conditionmgr/recommender.py View File

@@ -115,7 +115,8 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c
_recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
ActivationFuncEnum.SIGMOID.value) ActivationFuncEnum.SIGMOID.value)


merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.RELU.value)
merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream,
[ActivationFuncEnum.RELU.value, ActivationFuncEnum.RELUV2.value])
_recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
ActivationFuncEnum.RELU.value) ActivationFuncEnum.RELU.value)
return watch_points return watch_points
@@ -415,7 +416,7 @@ def _add_graph_name(nodes, graph_stream):


def _sigmoid(value): def _sigmoid(value):
"""Calculate the sigmoid of value.""" """Calculate the sigmoid of value."""
return 1.0 / (1.0 + math.exp(value))
return 1.0 / (1.0 + math.exp(-value))




def _get_recommend_activation_params(condition, activation_func): def _get_recommend_activation_params(condition, activation_func):


+ 4
- 1
mindinsight/debugger/debugger_server.py View File

@@ -79,16 +79,19 @@ class DebuggerServer:
if not isinstance(set_recommended, bool): if not isinstance(set_recommended, bool):
log.error("Bool param should be given for set_recommended") log.error("Bool param should be given for set_recommended")
raise DebuggerParamValueError("Bool param should be given.") raise DebuggerParamValueError("Bool param should be given.")

metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if metadata_stream.recommendation_confirmed: if metadata_stream.recommendation_confirmed:
log.error("User has confirmed setting recommended watchpoints") log.error("User has confirmed setting recommended watchpoints")
raise DebuggerSetRecommendWatchpointsError() raise DebuggerSetRecommendWatchpointsError()

metadata_stream.recommendation_confirmed = True
condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step)
log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend)
res = metadata_stream.get(['state', 'enable_recheck']) res = metadata_stream.get(['state', 'enable_recheck'])
if set_recommended: if set_recommended:
res['id'] = self._add_recommended_watchpoints(condition_context) res['id'] = self._add_recommended_watchpoints(condition_context)
metadata_stream.recommendation_confirmed = True
return res return res


def _add_recommended_watchpoints(self, condition_context): def _add_recommended_watchpoints(self, condition_context):


Loading…
Cancel
Save