Browse Source

Add a 'name' field to watchpoints to show a short description for recommended watchpoints

Add a recommended activation_range watchpoint.
Fix the feature that automatically chooses nodes for recommended watchpoints.
tags/v1.1.0
jiangshuqiang 5 years ago
parent
commit
16d694e281
6 changed files with 155 additions and 27 deletions
  1. +8
    -2
      mindinsight/backend/conditionmgr/conditionmgr_api.py
  2. +7
    -0
      mindinsight/debugger/conditionmgr/condition.py
  3. +112
    -7
      mindinsight/debugger/conditionmgr/recommender.py
  4. +5
    -1
      mindinsight/debugger/debugger_server.py
  5. +20
    -15
      mindinsight/debugger/stream_cache/watchpoint.py
  6. +3
    -2
      mindinsight/debugger/stream_handler/watchpoint_handler.py

+ 8
- 2
mindinsight/backend/conditionmgr/conditionmgr_api.py View File

@@ -19,6 +19,7 @@ from flask import Blueprint, request

from mindinsight.conf import settings
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.utils.exceptions import ParamMissError
from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER, _wrap_reply

BLUEPRINT = Blueprint("conditionmgr", __name__,
@@ -42,12 +43,17 @@ def get_condition_collections(train_id):
@BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/set-recommended-watch-points", methods=["POST"])
def set_recommended_watch_points(train_id):
    """
    Set recommended watch points.

    The JSON request body is expected to carry a 'requestBody' object whose
    'set_recommended' entry is forwarded to the backend server.

    Args:
        train_id (str): Id of the training job, taken from the request URL.

    Returns:
        The reply built by `_wrap_reply` around
        `BACKEND_SERVER.set_recommended_watch_points`.

    Raises:
        ParamValueError: If the request body is not valid JSON.
        ParamMissError: If the body has no 'requestBody' entry.
    """
    # The request stream is consumed by reading it, so read it exactly once.
    body = request.stream.read()
    try:
        body = json.loads(body if body else "{}")
    except json.JSONDecodeError as err:
        raise ParamValueError("Json data parse failed.") from err

    request_body = body.get('requestBody')
    if request_body is None:
        raise ParamMissError('requestBody')

    set_recommended = request_body.get('set_recommended')
    reply = _wrap_reply(BACKEND_SERVER.set_recommended_watch_points, set_recommended, train_id)
    return reply



+ 7
- 0
mindinsight/debugger/conditionmgr/condition.py View File

@@ -92,6 +92,13 @@ class ParamTypeEnum(Enum):
SUPPORT_PARAM = "SUPPORT_PARAM"


class ActivationFuncEnum(Enum):
    """Enumeration of activation function names."""

    # Hyperbolic tangent.
    TANH = "Tanh"
    # Logistic sigmoid.
    SIGMOID = "Sigmoid"
    # Rectified linear unit.
    RELU = "ReLU"


class ConditionContext:
"""
The class for condition context.


+ 112
- 7
mindinsight/debugger/conditionmgr/recommender.py View File

@@ -17,11 +17,13 @@ Predefined watchpoints.

This module predefine recommend watchpoints.
"""
import math
import queue as Queue

from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr
from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum
from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum
from mindinsight.debugger.conditionmgr.condition import ActivationFuncEnum
from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo
from mindinsight.debugger.conditionmgr.log import logger
from mindinsight.conf import settings
@@ -33,10 +35,18 @@ SELECTED_STATUS = 2


class _WatchPointData:
"""WatchPoint data container"""
def __init__(self, watch_condition, watch_nodes):
"""
WatchPoint data container

Args:
watch_condition (dict): The dict of watch conditions.
watch_nodes (list[NodeBasicInfo]): The list of node basic info.
name (str): The name of watchpoint.
"""
def __init__(self, watch_condition, watch_nodes, name):
self.watch_condition = watch_condition
self.watch_nodes = watch_nodes
self.name = name

def get_watch_condition_dict(self):
return {
@@ -99,6 +109,19 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c
_recommend_overflow_ascend_chip(merged_info, condition_mgr, watch_points, condition_context)
_recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context)
_recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context)

# add activation watch points
merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.TANH.value)
_recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
ActivationFuncEnum.TANH.value)

merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.SIGMOID.value)
_recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
ActivationFuncEnum.SIGMOID.value)

merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.RELU.value)
_recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context,
ActivationFuncEnum.RELU.value)
return watch_points


@@ -118,6 +141,7 @@ def _recommend_tensor_all_zero(basic_info_nodes, condition_mgr, watch_points, co
)]
},
watch_nodes=basic_info_nodes.copy(),
name='recommend_tensor_all_zero_watchpoint'
)
watch_points.append(tensor_all_zero_watchpoint)

@@ -136,6 +160,7 @@ def _recommend_tensor_overflow(basic_info_nodes, condition_mgr, watch_points, co
"params": []
},
watch_nodes=basic_info_nodes.copy(),
name='recommend_tensor_overflow_watchpoint'
)
watch_points.append(overflow_watchpoint)

@@ -154,6 +179,7 @@ def _recommend_overflow_ascend_chip(basic_info_nodes, condition_mgr, watch_point
"params": []
},
watch_nodes=basic_info_nodes.copy(),
name='recommend_overflow_ascend_chip_watchpoint'
)
watch_points.append(overflow_d_watchpoint)

@@ -175,6 +201,7 @@ def _recommend_gradient_vanishing(basic_info_nodes, condition_mgr, watch_points,
)]
},
watch_nodes=basic_info_nodes.copy(),
name='recommend_gradient_vanishing_watchpoint'
)
watch_points.append(gradient_vanishing_watchpoint)

@@ -198,6 +225,7 @@ def _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, wa
]
},
watch_nodes=trainable_weight_nodes,
name='recommend_weight_change_too_small_watchpoint'
)
watch_points.append(weight_change_too_small_watchpoint)

@@ -225,6 +253,7 @@ def _recommend_weight_not_changed(condition_mgr, trainable_weight_nodes, watch_p
]
},
watch_nodes=trainable_weight_nodes,
name='recommend_weight_not_changed_watchpoint'
)
watch_points.append(weight_no_change_watchpoint)

@@ -246,6 +275,7 @@ def _recommend_weight_change_too_large(basic_info_nodes, condition_mgr, watch_po
)]
},
watch_nodes=basic_info_nodes.copy(),
name='recommend_weight_change_too_large_watchpoint'
)
watch_points.append(weight_initialization_watchpoint)

@@ -267,21 +297,91 @@ def _recommend_weight_initialization(basic_info_nodes, condition_mgr, watch_poin
)]
},
watch_nodes=basic_info_nodes.copy(),
name='recommend_weight_initialization_watchpoint'
)
watch_points.append(weight_initialization_watchpoint)


def get_basic_node_info(node_category, graph_stream):
def _recommend_activation_range(basic_info_nodes, condition_mgr, watch_points, condition_context, activation_func):
"""Recommend activation range watchpoint."""
if not basic_info_nodes:
return
if not condition_mgr.has_condition(ConditionIdEnum.ACTIVATION_RANGE.value, condition_context):
return
condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.ACTIVATION_RANGE.value)
params = []
if activation_func == ActivationFuncEnum.TANH.value:
# The recommend params for Tanh: The percentage of value in range (tanh(-8.8), tanh(8.8)) is lower than 50.0%
params = [
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_percentage_lt"),
value=50.0
),
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_start_inclusive"),
value=math.tanh(-8.8)
),
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_end_inclusive"),
value=math.tanh(8.8)
)]
if activation_func == ActivationFuncEnum.SIGMOID.value:
# The recommend params for Sigmoid:
# The percentage of value in range (sigmoid(-16.2)), sigmoid(16.2)) is lower than 50.0%
params = [
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_percentage_lt"),
value=50.0
),
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_start_inclusive"),
value=_sigmoid(-16.2)
),
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_end_inclusive"),
value=_sigmoid(16.2)
)]
if activation_func == ActivationFuncEnum.RELU.value:
# The recommend params for ReLU:
# The percentage of value in range (float('-inf'), 0) is greater than 50.0%
params = [
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_percentage_gt"),
value=50.0
),
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_start_inclusive"),
value=float('-inf')
),
_ConditionParameterValue(
parameter=condition.get_parameter_definition("range_end_inclusive"),
value=0
)]
activation_range_watchpoint = _WatchPointData(
watch_condition={
"condition": condition.id,
"params": params
},
watch_nodes=basic_info_nodes.copy(),
name='recommend_{}_activation_range_watchpoint'.format(activation_func.lower())
)
watch_points.append(activation_range_watchpoint)


def get_basic_node_info(node_category, graph_stream, activation_func=None):
    """
    Get node merged info.

    Args:
        node_category (str): The node category used to search the graph.
        graph_stream: The graph stream; it is searched for nodes and provides `whole_graph`.
        activation_func (str): Optional activation function name used to
            narrow the search. Default: None.

    Returns:
        The merged node info, as returned by `_add_graph_name`.
    """
    # Search once, with the optional activation function filter applied.
    basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func)
    merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph)
    merged_info = _add_graph_name(merged_info, graph_stream)
    return merged_info


def _get_basic_node_info_by_node_category(node_category, graph_stream):
def _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func=None):
"""Get node basic info by node category."""
all_graph_nodes = graph_stream.get_searched_nodes(pattern={'node_category': node_category})
pattern = {'node_category': node_category}
if activation_func:
pattern['condition'] = {'activation_func': activation_func}
all_graph_nodes = graph_stream.get_searched_nodes(pattern)
basic_info_nodes = []
for graph_name, nodes in all_graph_nodes.items():
if len(all_graph_nodes) == 1:
@@ -329,7 +429,7 @@ def _merge_nodes(leaf_nodes, graph):
cur_node = watch_nodes.pop()
node_name = cur_node["name"]
sub_count = graph.normal_node_map.get(node_name).subnode_count
if len(cur_node["nodes"]) < sub_count or not cur_node["nodes"]:
if len(cur_node["nodes"]) < sub_count:
continue
is_all_chosen = True
for sub_node in cur_node["nodes"]:
@@ -362,3 +462,8 @@ def _add_graph_name(nodes, graph_stream):
full_name=node.name, graph_name=graph_name, node_name=node.name, node_type=node.type)
output_nodes.append(node_basic_info)
return output_nodes


def _sigmoid(value):
"""return sigmoid value"""
return 1.0 / (1.0 + math.exp(value))

+ 5
- 1
mindinsight/debugger/debugger_server.py View File

@@ -84,11 +84,14 @@ class DebuggerServer:

def set_recommended_watch_points(self, set_recommended, train_id):
    """
    Set recommended watch points.

    Args:
        set_recommended (bool): Whether to add the recommended watchpoints.
        train_id (str): The id of the training job; only logged here.

    Returns:
        The 'state' and 'enable_recheck' metadata, with an extra 'id' entry
        holding the result of `_add_recommended_watchpoints` when the
        recommended watchpoints are added by this call.

    Raises:
        DebuggerParamValueError: If `set_recommended` is not a bool.
    """
    if not isinstance(set_recommended, bool):
        log.error("Bool param should be given for set_recommended")
        raise DebuggerParamValueError("Bool param should be given.")
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step, (1, 0))
    log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend)
    res = metadata_stream.get(['state', 'enable_recheck'])
    # Skip the recommendation once it has been confirmed, so the recommended
    # watchpoints are not created a second time.
    if set_recommended and not metadata_stream.recommendation_confirmed:
        res['id'] = self._add_recommended_watchpoints(condition_context)
    # NOTE(review): indentation was lost in this view; the flag is assumed to be
    # set whether the recommendation was accepted or declined -- confirm.
    metadata_stream.recommendation_confirmed = True
    return res
@@ -104,6 +107,7 @@ class DebuggerServer:
watch_points_id = watch_point_stream_handler.create_watchpoint(
watch_condition=watchpoint.get_watch_condition_dict(),
watch_nodes=watchpoint.watch_nodes,
name=watchpoint.name,
condition_mgr=self.condition_mgr
)
watch_points_ids.append(watch_points_id)


+ 20
- 15
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -24,33 +24,35 @@ from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum
from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition

# Maps each condition id to the WatchCondition.Condition value used for it.
# Entries are kept in alphabetical order of the condition id, one per id.
WATCHPOINT_CONDITION_MAPPING = {
    ConditionIdEnum.ACTIVATION_RANGE.value: WatchCondition.Condition.tensor_range,
    ConditionIdEnum.GRADIENT_EXPLODING.value: WatchCondition.Condition.tensor_general_overflow,
    ConditionIdEnum.GRADIENT_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large,
    ConditionIdEnum.GRADIENT_VANISHING.value: WatchCondition.Condition.tensor_too_small,
    ConditionIdEnum.INF.value: WatchCondition.Condition.inf,
    ConditionIdEnum.MAX_GT.value: WatchCondition.Condition.max_gt,
    ConditionIdEnum.MAX_LT.value: WatchCondition.Condition.max_lt,
    ConditionIdEnum.MAX_MIN_GT.value: WatchCondition.Condition.max_min_gt,
    ConditionIdEnum.MAX_MIN_LT.value: WatchCondition.Condition.max_min_lt,
    ConditionIdEnum.MEAN_GT.value: WatchCondition.Condition.mean_gt,
    ConditionIdEnum.MEAN_LT.value: WatchCondition.Condition.mean_lt,
    ConditionIdEnum.MIN_GT.value: WatchCondition.Condition.min_gt,
    ConditionIdEnum.MIN_LT.value: WatchCondition.Condition.min_lt,
    ConditionIdEnum.NAN.value: WatchCondition.Condition.nan,
    ConditionIdEnum.OPERATOR_OVERFLOW.value: WatchCondition.Condition.overflow,
    ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value: WatchCondition.Condition.overflow,
    ConditionIdEnum.TENSOR_ALL_ZERO.value: WatchCondition.Condition.tensor_all_zero,
    ConditionIdEnum.TENSOR_INITIALIZATION.value: WatchCondition.Condition.tensor_initialization,
    ConditionIdEnum.TENSOR_OVERFLOW.value: WatchCondition.Condition.tensor_general_overflow,
    ConditionIdEnum.TENSOR_RANGE.value: WatchCondition.Condition.tensor_range,
    ConditionIdEnum.TENSOR_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large,
    ConditionIdEnum.TENSOR_TOO_SMALL.value: WatchCondition.Condition.tensor_too_small,
    ConditionIdEnum.WEIGHT_CHANGE_TOO_LARGE.value: WatchCondition.Condition.tensor_change_too_large,
    ConditionIdEnum.WEIGHT_CHANGE_TOO_SMALL.value: WatchCondition.Condition.tensor_change_too_small,
    ConditionIdEnum.WEIGHT_INITIALIZATION.value: WatchCondition.Condition.tensor_initialization,
    ConditionIdEnum.WEIGHT_NOT_CHANGED.value: WatchCondition.Condition.tensor_not_changed,
    ConditionIdEnum.WEIGHT_OVERFLOW.value: WatchCondition.Condition.tensor_general_overflow,
    ConditionIdEnum.WEIGHT_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large,
    ConditionIdEnum.WEIGHT_TOO_SMALL.value: WatchCondition.Condition.tensor_too_small
}


@@ -180,10 +182,11 @@ class Watchpoint:
- param (list[float]): Not defined yet.
"""

def __init__(self, watchpoint_id, watch_condition, name=None):
    """
    Initialize the watchpoint with an empty watch node tree.

    Args:
        watchpoint_id (int): The id of the watchpoint.
        watch_condition (dict): The condition of the watchpoint.
        name (str): Optional name of the watchpoint. Default: None.
    """
    self._id = watchpoint_id
    self._condition = watch_condition
    self._watch_node = WatchNodeTree()
    # Public attribute; None means the watchpoint has no name.
    self.name = name

@property
def watchpoint_id(self):
@@ -308,6 +311,8 @@ class Watchpoint:
'id': self._id,
'watch_condition': self._condition
}
if self.name:
watchpoint_info['name'] = self.name
return watchpoint_info




+ 3
- 2
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -215,7 +215,7 @@ class WatchpointHandler(StreamHandlerBase):

return state

def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None):
def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None):
"""
Create watchpoint.
Args:
@@ -234,6 +234,7 @@ class WatchpointHandler(StreamHandlerBase):
- param (list[dict]): The list of param for this condition.
watch_nodes (list[NodeBasicInfo]): The list of node basic info.
watch_point_id (int): The id of watchpoint.
name (str): The name of watchpoint.

Returns:
int, the new id of watchpoint.
@@ -241,7 +242,7 @@ class WatchpointHandler(StreamHandlerBase):
validate_watch_condition(condition_mgr, watch_condition)
watch_condition = set_default_param(condition_mgr, watch_condition)
new_id = self._latest_id + 1
watchpoint = Watchpoint(new_id, watch_condition)
watchpoint = Watchpoint(new_id, watch_condition, name)
if watch_nodes:
watchpoint.add_nodes(watch_nodes)
elif watch_point_id:


Loading…
Cancel
Save