From bf852b1cdc1f3097bdad2b721c37e6357bcffeae Mon Sep 17 00:00:00 2001 From: jiangshuqiang <962978787@qq.com> Date: Sat, 26 Dec 2020 16:27:59 +0800 Subject: [PATCH] fix sigmoid function and add reluv2 in recommended watchpoint condtions --- mindinsight/debugger/conditionmgr/condition.py | 1 + mindinsight/debugger/conditionmgr/recommender.py | 5 +++-- mindinsight/debugger/debugger_server.py | 5 ++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mindinsight/debugger/conditionmgr/condition.py b/mindinsight/debugger/conditionmgr/condition.py index 936375b3..2f58ed84 100644 --- a/mindinsight/debugger/conditionmgr/condition.py +++ b/mindinsight/debugger/conditionmgr/condition.py @@ -110,6 +110,7 @@ class ActivationFuncEnum(Enum): TANH = 'tanh' SIGMOID = 'sigmoid' RELU = 'relu' + RELUV2 = 'reluv2' class ConditionContext: diff --git a/mindinsight/debugger/conditionmgr/recommender.py b/mindinsight/debugger/conditionmgr/recommender.py index b9952c0f..7f82d0a4 100644 --- a/mindinsight/debugger/conditionmgr/recommender.py +++ b/mindinsight/debugger/conditionmgr/recommender.py @@ -115,7 +115,8 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, ActivationFuncEnum.SIGMOID.value) - merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.RELU.value) + merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, + [ActivationFuncEnum.RELU.value, ActivationFuncEnum.RELUV2.value]) _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, ActivationFuncEnum.RELU.value) return watch_points @@ -415,7 +416,7 @@ def _add_graph_name(nodes, graph_stream): def _sigmoid(value): """Calculate the sigmoid of value.""" - return 1.0 / (1.0 + math.exp(value)) + return 1.0 / (1.0 + math.exp(-value)) def _get_recommend_activation_params(condition, activation_func): diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py index 5b794045..53009c19 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_server.py @@ -79,16 +79,19 @@ class DebuggerServer: if not isinstance(set_recommended, bool): log.error("Bool param should be given for set_recommended") raise DebuggerParamValueError("Bool param should be given.") + metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) if metadata_stream.recommendation_confirmed: log.error("User has confirmed setting recommended watchpoints") raise DebuggerSetRecommendWatchpointsError() + + metadata_stream.recommendation_confirmed = True condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step) log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) res = metadata_stream.get(['state', 'enable_recheck']) if set_recommended: res['id'] = self._add_recommended_watchpoints(condition_context) - metadata_stream.recommendation_confirmed = True + return res def _add_recommended_watchpoints(self, condition_context):