add a recommend activation_range watchpoint fix the feature that auto-choose node for recommend watchpointstags/v1.1.0
| @@ -19,6 +19,7 @@ from flask import Blueprint, request | |||||
| from mindinsight.conf import settings | from mindinsight.conf import settings | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from mindinsight.utils.exceptions import ParamMissError | |||||
| from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER, _wrap_reply | from mindinsight.backend.debugger.debugger_api import BACKEND_SERVER, _wrap_reply | ||||
| BLUEPRINT = Blueprint("conditionmgr", __name__, | BLUEPRINT = Blueprint("conditionmgr", __name__, | ||||
| @@ -42,12 +43,17 @@ def get_condition_collections(train_id): | |||||
| @BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/set-recommended-watch-points", methods=["POST"]) | @BLUEPRINT.route("/conditionmgr/train-jobs/<train_id>/set-recommended-watch-points", methods=["POST"]) | ||||
| def set_recommended_watch_points(train_id): | def set_recommended_watch_points(train_id): | ||||
| """set recommended watch points.""" | """set recommended watch points.""" | ||||
| set_recommended = request.stream.read() | |||||
| body = request.stream.read() | |||||
| try: | try: | ||||
| set_recommended = json.loads(set_recommended if set_recommended else "{}") | |||||
| body = json.loads(body if body else "{}") | |||||
| except json.JSONDecodeError: | except json.JSONDecodeError: | ||||
| raise ParamValueError("Json data parse failed.") | raise ParamValueError("Json data parse failed.") | ||||
| request_body = body.get('requestBody') | |||||
| if request_body is None: | |||||
| raise ParamMissError('requestBody') | |||||
| set_recommended = request_body.get('set_recommended') | |||||
| reply = _wrap_reply(BACKEND_SERVER.set_recommended_watch_points, set_recommended, train_id) | reply = _wrap_reply(BACKEND_SERVER.set_recommended_watch_points, set_recommended, train_id) | ||||
| return reply | return reply | ||||
| @@ -92,6 +92,13 @@ class ParamTypeEnum(Enum): | |||||
| SUPPORT_PARAM = "SUPPORT_PARAM" | SUPPORT_PARAM = "SUPPORT_PARAM" | ||||
| class ActivationFuncEnum(Enum): | |||||
| """Activation functions.""" | |||||
| TANH = 'Tanh' | |||||
| SIGMOID = 'Sigmoid' | |||||
| RELU = 'ReLU' | |||||
| class ConditionContext: | class ConditionContext: | ||||
| """ | """ | ||||
| The class for condition context. | The class for condition context. | ||||
| @@ -17,11 +17,13 @@ Predefined watchpoints. | |||||
| This module predefine recommend watchpoints. | This module predefine recommend watchpoints. | ||||
| """ | """ | ||||
| import math | |||||
| import queue as Queue | import queue as Queue | ||||
| from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | from mindinsight.debugger.conditionmgr.conditionmgr import ConditionMgr | ||||
| from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum | from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum | ||||
| from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum | from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum | ||||
| from mindinsight.debugger.conditionmgr.condition import ActivationFuncEnum | |||||
| from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo | from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo | ||||
| from mindinsight.debugger.conditionmgr.log import logger | from mindinsight.debugger.conditionmgr.log import logger | ||||
| from mindinsight.conf import settings | from mindinsight.conf import settings | ||||
| @@ -33,10 +35,18 @@ SELECTED_STATUS = 2 | |||||
| class _WatchPointData: | class _WatchPointData: | ||||
| """WatchPoint data container""" | |||||
| def __init__(self, watch_condition, watch_nodes): | |||||
| """ | |||||
| WatchPoint data container | |||||
| Args: | |||||
| watch_condition (dict): The dict of watch conditions. | |||||
| watch_nodes (list[NodeBasicInfo]): The list of node basic info. | |||||
| name (str): The name of watchpoint. | |||||
| """ | |||||
| def __init__(self, watch_condition, watch_nodes, name): | |||||
| self.watch_condition = watch_condition | self.watch_condition = watch_condition | ||||
| self.watch_nodes = watch_nodes | self.watch_nodes = watch_nodes | ||||
| self.name = name | |||||
| def get_watch_condition_dict(self): | def get_watch_condition_dict(self): | ||||
| return { | return { | ||||
| @@ -99,6 +109,19 @@ def recommend_watchpoints(condition_mgr: ConditionMgr, graph_stream, condition_c | |||||
| _recommend_overflow_ascend_chip(merged_info, condition_mgr, watch_points, condition_context) | _recommend_overflow_ascend_chip(merged_info, condition_mgr, watch_points, condition_context) | ||||
| _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context) | _recommend_tensor_overflow(merged_info, condition_mgr, watch_points, condition_context) | ||||
| _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context) | _recommend_tensor_all_zero(merged_info, condition_mgr, watch_points, condition_context) | ||||
| # add activation watch points | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.TANH.value) | |||||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | |||||
| ActivationFuncEnum.TANH.value) | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.SIGMOID.value) | |||||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | |||||
| ActivationFuncEnum.SIGMOID.value) | |||||
| merged_info = get_basic_node_info(TargetTypeEnum.ACTIVATION.value, graph_stream, ActivationFuncEnum.RELU.value) | |||||
| _recommend_activation_range(merged_info, condition_mgr, watch_points, condition_context, | |||||
| ActivationFuncEnum.RELU.value) | |||||
| return watch_points | return watch_points | ||||
| @@ -118,6 +141,7 @@ def _recommend_tensor_all_zero(basic_info_nodes, condition_mgr, watch_points, co | |||||
| )] | )] | ||||
| }, | }, | ||||
| watch_nodes=basic_info_nodes.copy(), | watch_nodes=basic_info_nodes.copy(), | ||||
| name='recommend_tensor_all_zero_watchpoint' | |||||
| ) | ) | ||||
| watch_points.append(tensor_all_zero_watchpoint) | watch_points.append(tensor_all_zero_watchpoint) | ||||
| @@ -136,6 +160,7 @@ def _recommend_tensor_overflow(basic_info_nodes, condition_mgr, watch_points, co | |||||
| "params": [] | "params": [] | ||||
| }, | }, | ||||
| watch_nodes=basic_info_nodes.copy(), | watch_nodes=basic_info_nodes.copy(), | ||||
| name='recommend_tensor_overflow_watchpoint' | |||||
| ) | ) | ||||
| watch_points.append(overflow_watchpoint) | watch_points.append(overflow_watchpoint) | ||||
| @@ -154,6 +179,7 @@ def _recommend_overflow_ascend_chip(basic_info_nodes, condition_mgr, watch_point | |||||
| "params": [] | "params": [] | ||||
| }, | }, | ||||
| watch_nodes=basic_info_nodes.copy(), | watch_nodes=basic_info_nodes.copy(), | ||||
| name='recommend_overflow_ascend_chip_watchpoint' | |||||
| ) | ) | ||||
| watch_points.append(overflow_d_watchpoint) | watch_points.append(overflow_d_watchpoint) | ||||
| @@ -175,6 +201,7 @@ def _recommend_gradient_vanishing(basic_info_nodes, condition_mgr, watch_points, | |||||
| )] | )] | ||||
| }, | }, | ||||
| watch_nodes=basic_info_nodes.copy(), | watch_nodes=basic_info_nodes.copy(), | ||||
| name='recommend_gradient_vanishing_watchpoint' | |||||
| ) | ) | ||||
| watch_points.append(gradient_vanishing_watchpoint) | watch_points.append(gradient_vanishing_watchpoint) | ||||
| @@ -198,6 +225,7 @@ def _recommend_weight_change_too_small(condition_mgr, trainable_weight_nodes, wa | |||||
| ] | ] | ||||
| }, | }, | ||||
| watch_nodes=trainable_weight_nodes, | watch_nodes=trainable_weight_nodes, | ||||
| name='recommend_weight_change_too_small_watchpoint' | |||||
| ) | ) | ||||
| watch_points.append(weight_change_too_small_watchpoint) | watch_points.append(weight_change_too_small_watchpoint) | ||||
| @@ -225,6 +253,7 @@ def _recommend_weight_not_changed(condition_mgr, trainable_weight_nodes, watch_p | |||||
| ] | ] | ||||
| }, | }, | ||||
| watch_nodes=trainable_weight_nodes, | watch_nodes=trainable_weight_nodes, | ||||
| name='recommend_weight_not_changed_watchpoint' | |||||
| ) | ) | ||||
| watch_points.append(weight_no_change_watchpoint) | watch_points.append(weight_no_change_watchpoint) | ||||
| @@ -246,6 +275,7 @@ def _recommend_weight_change_too_large(basic_info_nodes, condition_mgr, watch_po | |||||
| )] | )] | ||||
| }, | }, | ||||
| watch_nodes=basic_info_nodes.copy(), | watch_nodes=basic_info_nodes.copy(), | ||||
| name='recommend_weight_change_too_large_watchpoint' | |||||
| ) | ) | ||||
| watch_points.append(weight_initialization_watchpoint) | watch_points.append(weight_initialization_watchpoint) | ||||
| @@ -267,21 +297,91 @@ def _recommend_weight_initialization(basic_info_nodes, condition_mgr, watch_poin | |||||
| )] | )] | ||||
| }, | }, | ||||
| watch_nodes=basic_info_nodes.copy(), | watch_nodes=basic_info_nodes.copy(), | ||||
| name='recommend_weight_initialization_watchpoint' | |||||
| ) | ) | ||||
| watch_points.append(weight_initialization_watchpoint) | watch_points.append(weight_initialization_watchpoint) | ||||
| def get_basic_node_info(node_category, graph_stream): | |||||
| def _recommend_activation_range(basic_info_nodes, condition_mgr, watch_points, condition_context, activation_func): | |||||
| """Recommend activation range watchpoint.""" | |||||
| if not basic_info_nodes: | |||||
| return | |||||
| if not condition_mgr.has_condition(ConditionIdEnum.ACTIVATION_RANGE.value, condition_context): | |||||
| return | |||||
| condition = condition_mgr.get_condition(condition_id=ConditionIdEnum.ACTIVATION_RANGE.value) | |||||
| params = [] | |||||
| if activation_func == ActivationFuncEnum.TANH.value: | |||||
| # The recommend params for Tanh: The percentage of value in range (tanh(-8.8), tanh(8.8)) is lower than 50.0% | |||||
| params = [ | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_percentage_lt"), | |||||
| value=50.0 | |||||
| ), | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_start_inclusive"), | |||||
| value=math.tanh(-8.8) | |||||
| ), | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_end_inclusive"), | |||||
| value=math.tanh(8.8) | |||||
| )] | |||||
| if activation_func == ActivationFuncEnum.SIGMOID.value: | |||||
| # The recommend params for Sigmoid: | |||||
| # The percentage of value in range (sigmoid(-16.2)), sigmoid(16.2)) is lower than 50.0% | |||||
| params = [ | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_percentage_lt"), | |||||
| value=50.0 | |||||
| ), | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_start_inclusive"), | |||||
| value=_sigmoid(-16.2) | |||||
| ), | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_end_inclusive"), | |||||
| value=_sigmoid(16.2) | |||||
| )] | |||||
| if activation_func == ActivationFuncEnum.RELU.value: | |||||
| # The recommend params for ReLU: | |||||
| # The percentage of value in range (float('-inf'), 0) is greater than 50.0% | |||||
| params = [ | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_percentage_gt"), | |||||
| value=50.0 | |||||
| ), | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_start_inclusive"), | |||||
| value=float('-inf') | |||||
| ), | |||||
| _ConditionParameterValue( | |||||
| parameter=condition.get_parameter_definition("range_end_inclusive"), | |||||
| value=0 | |||||
| )] | |||||
| activation_range_watchpoint = _WatchPointData( | |||||
| watch_condition={ | |||||
| "condition": condition.id, | |||||
| "params": params | |||||
| }, | |||||
| watch_nodes=basic_info_nodes.copy(), | |||||
| name='recommend_{}_activation_range_watchpoint'.format(activation_func.lower()) | |||||
| ) | |||||
| watch_points.append(activation_range_watchpoint) | |||||
| def get_basic_node_info(node_category, graph_stream, activation_func=None): | |||||
| """Get node merged info.""" | """Get node merged info.""" | ||||
| basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream) | |||||
| basic_info_nodes = _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func) | |||||
| merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph) | merged_info = _merge_nodes(basic_info_nodes, graph_stream.whole_graph) | ||||
| merged_info = _add_graph_name(merged_info, graph_stream) | merged_info = _add_graph_name(merged_info, graph_stream) | ||||
| return merged_info | return merged_info | ||||
| def _get_basic_node_info_by_node_category(node_category, graph_stream): | |||||
| def _get_basic_node_info_by_node_category(node_category, graph_stream, activation_func=None): | |||||
| """Get node basic info by node category.""" | """Get node basic info by node category.""" | ||||
| all_graph_nodes = graph_stream.get_searched_nodes(pattern={'node_category': node_category}) | |||||
| pattern = {'node_category': node_category} | |||||
| if activation_func: | |||||
| pattern['condition'] = {'activation_func': activation_func} | |||||
| all_graph_nodes = graph_stream.get_searched_nodes(pattern) | |||||
| basic_info_nodes = [] | basic_info_nodes = [] | ||||
| for graph_name, nodes in all_graph_nodes.items(): | for graph_name, nodes in all_graph_nodes.items(): | ||||
| if len(all_graph_nodes) == 1: | if len(all_graph_nodes) == 1: | ||||
| @@ -329,7 +429,7 @@ def _merge_nodes(leaf_nodes, graph): | |||||
| cur_node = watch_nodes.pop() | cur_node = watch_nodes.pop() | ||||
| node_name = cur_node["name"] | node_name = cur_node["name"] | ||||
| sub_count = graph.normal_node_map.get(node_name).subnode_count | sub_count = graph.normal_node_map.get(node_name).subnode_count | ||||
| if len(cur_node["nodes"]) < sub_count or not cur_node["nodes"]: | |||||
| if len(cur_node["nodes"]) < sub_count: | |||||
| continue | continue | ||||
| is_all_chosen = True | is_all_chosen = True | ||||
| for sub_node in cur_node["nodes"]: | for sub_node in cur_node["nodes"]: | ||||
| @@ -362,3 +462,8 @@ def _add_graph_name(nodes, graph_stream): | |||||
| full_name=node.name, graph_name=graph_name, node_name=node.name, node_type=node.type) | full_name=node.name, graph_name=graph_name, node_name=node.name, node_type=node.type) | ||||
| output_nodes.append(node_basic_info) | output_nodes.append(node_basic_info) | ||||
| return output_nodes | return output_nodes | ||||
| def _sigmoid(value): | |||||
| """return sigmoid value""" | |||||
| return 1.0 / (1.0 + math.exp(value)) | |||||
| @@ -84,11 +84,14 @@ class DebuggerServer: | |||||
| def set_recommended_watch_points(self, set_recommended, train_id): | def set_recommended_watch_points(self, set_recommended, train_id): | ||||
| """set recommended watch points.""" | """set recommended watch points.""" | ||||
| if not isinstance(set_recommended, bool): | |||||
| log.error("Bool param should be given for set_recommended") | |||||
| raise DebuggerParamValueError("Bool param should be given.") | |||||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | ||||
| condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step, (1, 0)) | condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step, (1, 0)) | ||||
| log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) | log.debug("Train_id: %s, backend: %s", train_id, condition_context.backend) | ||||
| res = metadata_stream.get(['state', 'enable_recheck']) | res = metadata_stream.get(['state', 'enable_recheck']) | ||||
| if set_recommended: | |||||
| if set_recommended and not metadata_stream.recommendation_confirmed: | |||||
| res['id'] = self._add_recommended_watchpoints(condition_context) | res['id'] = self._add_recommended_watchpoints(condition_context) | ||||
| metadata_stream.recommendation_confirmed = True | metadata_stream.recommendation_confirmed = True | ||||
| return res | return res | ||||
| @@ -104,6 +107,7 @@ class DebuggerServer: | |||||
| watch_points_id = watch_point_stream_handler.create_watchpoint( | watch_points_id = watch_point_stream_handler.create_watchpoint( | ||||
| watch_condition=watchpoint.get_watch_condition_dict(), | watch_condition=watchpoint.get_watch_condition_dict(), | ||||
| watch_nodes=watchpoint.watch_nodes, | watch_nodes=watchpoint.watch_nodes, | ||||
| name=watchpoint.name, | |||||
| condition_mgr=self.condition_mgr | condition_mgr=self.condition_mgr | ||||
| ) | ) | ||||
| watch_points_ids.append(watch_points_id) | watch_points_ids.append(watch_points_id) | ||||
| @@ -24,33 +24,35 @@ from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition | from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition | ||||
| WATCHPOINT_CONDITION_MAPPING = { | WATCHPOINT_CONDITION_MAPPING = { | ||||
| ConditionIdEnum.NAN.value: WatchCondition.Condition.nan, | |||||
| ConditionIdEnum.ACTIVATION_RANGE.value: WatchCondition.Condition.tensor_range, | |||||
| ConditionIdEnum.GRADIENT_EXPLODING.value: WatchCondition.Condition.tensor_general_overflow, | |||||
| ConditionIdEnum.GRADIENT_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large, | |||||
| ConditionIdEnum.GRADIENT_VANISHING.value: WatchCondition.Condition.tensor_too_small, | |||||
| ConditionIdEnum.INF.value: WatchCondition.Condition.inf, | ConditionIdEnum.INF.value: WatchCondition.Condition.inf, | ||||
| ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value: WatchCondition.Condition.overflow, | |||||
| ConditionIdEnum.MAX_GT.value: WatchCondition.Condition.max_gt, | ConditionIdEnum.MAX_GT.value: WatchCondition.Condition.max_gt, | ||||
| ConditionIdEnum.MAX_LT.value: WatchCondition.Condition.max_lt, | ConditionIdEnum.MAX_LT.value: WatchCondition.Condition.max_lt, | ||||
| ConditionIdEnum.MIN_GT.value: WatchCondition.Condition.min_gt, | |||||
| ConditionIdEnum.MIN_LT.value: WatchCondition.Condition.min_lt, | |||||
| ConditionIdEnum.MAX_MIN_GT.value: WatchCondition.Condition.max_min_gt, | ConditionIdEnum.MAX_MIN_GT.value: WatchCondition.Condition.max_min_gt, | ||||
| ConditionIdEnum.MAX_MIN_LT.value: WatchCondition.Condition.max_min_lt, | ConditionIdEnum.MAX_MIN_LT.value: WatchCondition.Condition.max_min_lt, | ||||
| ConditionIdEnum.MEAN_GT.value: WatchCondition.Condition.mean_gt, | ConditionIdEnum.MEAN_GT.value: WatchCondition.Condition.mean_gt, | ||||
| ConditionIdEnum.MEAN_LT.value: WatchCondition.Condition.mean_lt, | ConditionIdEnum.MEAN_LT.value: WatchCondition.Condition.mean_lt, | ||||
| ConditionIdEnum.TENSOR_OVERFLOW.value: WatchCondition.Condition.tensor_general_overflow, | |||||
| ConditionIdEnum.WEIGHT_OVERFLOW.value: WatchCondition.Condition.tensor_general_overflow, | |||||
| ConditionIdEnum.MIN_GT.value: WatchCondition.Condition.min_gt, | |||||
| ConditionIdEnum.MIN_LT.value: WatchCondition.Condition.min_lt, | |||||
| ConditionIdEnum.NAN.value: WatchCondition.Condition.nan, | |||||
| ConditionIdEnum.OPERATOR_OVERFLOW.value: WatchCondition.Condition.overflow, | ConditionIdEnum.OPERATOR_OVERFLOW.value: WatchCondition.Condition.overflow, | ||||
| ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value: WatchCondition.Condition.overflow, | |||||
| ConditionIdEnum.TENSOR_ALL_ZERO.value: WatchCondition.Condition.tensor_all_zero, | |||||
| ConditionIdEnum.TENSOR_INITIALIZATION.value: WatchCondition.Condition.tensor_initialization, | ConditionIdEnum.TENSOR_INITIALIZATION.value: WatchCondition.Condition.tensor_initialization, | ||||
| ConditionIdEnum.WEIGHT_INITIALIZATION.value: WatchCondition.Condition.tensor_initialization, | |||||
| ConditionIdEnum.TENSOR_OVERFLOW.value: WatchCondition.Condition.tensor_general_overflow, | |||||
| ConditionIdEnum.TENSOR_RANGE.value: WatchCondition.Condition.tensor_range, | |||||
| ConditionIdEnum.TENSOR_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large, | ConditionIdEnum.TENSOR_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large, | ||||
| ConditionIdEnum.WEIGHT_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large, | |||||
| ConditionIdEnum.GRADIENT_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large, | |||||
| ConditionIdEnum.GRADIENT_EXPLODING.value: WatchCondition.Condition.tensor_general_overflow, | |||||
| ConditionIdEnum.TENSOR_TOO_SMALL.value: WatchCondition.Condition.tensor_too_small, | ConditionIdEnum.TENSOR_TOO_SMALL.value: WatchCondition.Condition.tensor_too_small, | ||||
| ConditionIdEnum.WEIGHT_TOO_SMALL.value: WatchCondition.Condition.tensor_too_small, | |||||
| ConditionIdEnum.GRADIENT_VANISHING.value: WatchCondition.Condition.tensor_too_small, | |||||
| ConditionIdEnum.TENSOR_ALL_ZERO.value: WatchCondition.Condition.tensor_all_zero, | |||||
| ConditionIdEnum.WEIGHT_CHANGE_TOO_LARGE.value: WatchCondition.Condition.tensor_change_too_large, | ConditionIdEnum.WEIGHT_CHANGE_TOO_LARGE.value: WatchCondition.Condition.tensor_change_too_large, | ||||
| ConditionIdEnum.WEIGHT_CHANGE_TOO_SMALL.value: WatchCondition.Condition.tensor_change_too_small, | ConditionIdEnum.WEIGHT_CHANGE_TOO_SMALL.value: WatchCondition.Condition.tensor_change_too_small, | ||||
| ConditionIdEnum.WEIGHT_NOT_CHANGED.value: WatchCondition.Condition.tensor_not_changed | |||||
| ConditionIdEnum.WEIGHT_INITIALIZATION.value: WatchCondition.Condition.tensor_initialization, | |||||
| ConditionIdEnum.WEIGHT_NOT_CHANGED.value: WatchCondition.Condition.tensor_not_changed, | |||||
| ConditionIdEnum.WEIGHT_OVERFLOW.value: WatchCondition.Condition.tensor_general_overflow, | |||||
| ConditionIdEnum.WEIGHT_TOO_LARGE.value: WatchCondition.Condition.tensor_too_large, | |||||
| ConditionIdEnum.WEIGHT_TOO_SMALL.value: WatchCondition.Condition.tensor_too_small | |||||
| } | } | ||||
| @@ -180,10 +182,11 @@ class Watchpoint: | |||||
| - param (list[float]): Not defined yet. | - param (list[float]): Not defined yet. | ||||
| """ | """ | ||||
| def __init__(self, watchpoint_id, watch_condition): | |||||
| def __init__(self, watchpoint_id, watch_condition, name=None): | |||||
| self._id = watchpoint_id | self._id = watchpoint_id | ||||
| self._condition = watch_condition | self._condition = watch_condition | ||||
| self._watch_node = WatchNodeTree() | self._watch_node = WatchNodeTree() | ||||
| self.name = name | |||||
| @property | @property | ||||
| def watchpoint_id(self): | def watchpoint_id(self): | ||||
| @@ -308,6 +311,8 @@ class Watchpoint: | |||||
| 'id': self._id, | 'id': self._id, | ||||
| 'watch_condition': self._condition | 'watch_condition': self._condition | ||||
| } | } | ||||
| if self.name: | |||||
| watchpoint_info['name'] = self.name | |||||
| return watchpoint_info | return watchpoint_info | ||||
| @@ -215,7 +215,7 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| return state | return state | ||||
| def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None): | |||||
| def create_watchpoint(self, condition_mgr, watch_condition, watch_nodes=None, watch_point_id=None, name=None): | |||||
| """ | """ | ||||
| Create watchpoint. | Create watchpoint. | ||||
| Args: | Args: | ||||
| @@ -234,6 +234,7 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| - param (list[dict]): The list of param for this condition. | - param (list[dict]): The list of param for this condition. | ||||
| watch_nodes (list[NodeBasicInfo]): The list of node basic info. | watch_nodes (list[NodeBasicInfo]): The list of node basic info. | ||||
| watch_point_id (int): The id of watchpoint. | watch_point_id (int): The id of watchpoint. | ||||
| name (str): The name of watchpoint. | |||||
| Returns: | Returns: | ||||
| int, the new id of watchpoint. | int, the new id of watchpoint. | ||||
| @@ -241,7 +242,7 @@ class WatchpointHandler(StreamHandlerBase): | |||||
| validate_watch_condition(condition_mgr, watch_condition) | validate_watch_condition(condition_mgr, watch_condition) | ||||
| watch_condition = set_default_param(condition_mgr, watch_condition) | watch_condition = set_default_param(condition_mgr, watch_condition) | ||||
| new_id = self._latest_id + 1 | new_id = self._latest_id + 1 | ||||
| watchpoint = Watchpoint(new_id, watch_condition) | |||||
| watchpoint = Watchpoint(new_id, watch_condition, name) | |||||
| if watch_nodes: | if watch_nodes: | ||||
| watchpoint.add_nodes(watch_nodes) | watchpoint.add_nodes(watch_nodes) | ||||
| elif watch_point_id: | elif watch_point_id: | ||||