Browse Source

!954 Fix the bug where updated parameter values were not shown

From: @yelihua
Reviewed-by: @wangyue01,@lilongfei15
Signed-off-by: @lilongfei15
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
f8e3634723
6 changed files with 63 additions and 5 deletions
  1. +2
    -1
      mindinsight/debugger/conditionmgr/condition.py
  2. +13
    -1
      mindinsight/debugger/debugger_grpc_server.py
  3. +13
    -0
      mindinsight/debugger/stream_cache/node_type_identifier.py
  4. +33
    -1
      mindinsight/debugger/stream_handler/tensor_handler.py
  5. +1
    -2
      tests/st/func/debugger/mock_ms_client.py
  6. +1
    -0
      tests/st/func/debugger/test_restful_api.py

+ 2
- 1
mindinsight/debugger/conditionmgr/condition.py View File

@@ -68,9 +68,10 @@ class PlatformEnum(Enum):
class TargetTypeEnum(Enum): class TargetTypeEnum(Enum):
"""Target types.""" """Target types."""
TENSOR = 'tensor' TENSOR = 'tensor'
WEIGHT = 'weight'
ACTIVATION = 'activation' ACTIVATION = 'activation'
GRADIENT = 'gradient' GRADIENT = 'gradient'
PARAMETER = 'parameter'
WEIGHT = 'weight'




class ParamTypeEnum(Enum): class ParamTypeEnum(Enum):


+ 13
- 1
mindinsight/debugger/debugger_grpc_server.py View File

@@ -21,6 +21,7 @@ import mindinsight
from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
Streams, RunLevel Streams, RunLevel
from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto


@@ -117,9 +118,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
# clean cache data at the beginning of new step or node has been changed. # clean cache data at the beginning of new step or node has been changed.
if is_new_step or is_new_node: if is_new_step or is_new_node:
self._cache_store.clean_data() self._cache_store.clean_data()
self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
if is_new_step: if is_new_step:
self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
# receive graph at the beginning of the training # receive graph at the beginning of the training
if self._status == ServerStatus.RECEIVE_GRAPH: if self._status == ServerStatus.RECEIVE_GRAPH:
self._send_graph_flag(metadata_stream) self._send_graph_flag(metadata_stream)
@@ -397,6 +398,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals) self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
self._record_parameter_names()
self._status = ServerStatus.RECEIVE_GRAPH self._status = ServerStatus.RECEIVE_GRAPH
log.debug("Send the reply for graph.") log.debug("Send the reply for graph.")
return reply return reply
@@ -423,10 +425,20 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
sub_graph.const_vals) sub_graph.const_vals)


self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict) self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
self._record_parameter_names()
self._status = ServerStatus.RECEIVE_GRAPH self._status = ServerStatus.RECEIVE_GRAPH
log.debug("Send the reply for graph.") log.debug("Send the reply for graph.")
return reply return reply


def _record_parameter_names(self):
    """Record the full names of all parameter nodes in the tensor handler."""
    # Search the graph stream for every node categorized as a parameter.
    graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
    searched = graph_stream.get_searched_nodes(
        pattern={'node_category': TargetTypeEnum.PARAMETER.value})
    tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
    # Tensor names are derived from node full names with output slot 0 appended.
    for node_list in searched.values():
        tensor_stream.record_parameter_names(
            [node.full_name + ':0' for node in node_list])

@debugger_wrap @debugger_wrap
def SendTensors(self, request_iterator, context): def SendTensors(self, request_iterator, context):
"""Send tensors into DebuggerCache.""" """Send tensors into DebuggerCache."""


+ 13
- 0
mindinsight/debugger/stream_cache/node_type_identifier.py View File

@@ -66,6 +66,19 @@ class NodeTypeIdentifier:
return self.identify_func(*args, **kwargs) return self.identify_func(*args, **kwargs)




def is_parameter_node(node):
    """
    Check if the node is parameter type.

    Args:
        node (Node): The node object.

    Returns:
        bool, True if the node is parameter type.
    """
    # Fix: the docstring previously said "weight type" (copy-paste from
    # is_weight_node); the code actually tests for the PARAMETER node type.
    return bool(node.type == NodeTypeEnum.PARAMETER.value)


def is_weight_node(node): def is_weight_node(node):
""" """
Check if the node is weight type. Check if the node is weight type.


+ 33
- 1
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -27,11 +27,17 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison


TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter']) TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter'])



class TensorHandler(StreamHandlerBase): class TensorHandler(StreamHandlerBase):
"""Metadata Handler.""" """Metadata Handler."""


def __init__(self): def __init__(self):
# the collection of parameter full names
self._param_names = set()
# const value objects, the format is like: dict[<const name>, <OpTensor object>]
self._const_vals = {} self._const_vals = {}
# tensor values, the format is like:
# dict[<tensor full name>, dict[<step_num>, <OpTensor object>]]
self._tensors = {} self._tensors = {}
self._cur_step = 0 self._cur_step = 0


@@ -147,6 +153,19 @@ class TensorHandler(StreamHandlerBase):
const_tensor = ConstTensor(const_val) const_tensor = ConstTensor(const_val)
self._const_vals[const_tensor.name] = const_tensor self._const_vals[const_tensor.name] = const_tensor


def record_parameter_names(self, names):
    """
    Record parameter full names in the cache.

    Note:
        Parameter values could change during an iteration step, so the
        cached parameter tensors must be cleaned after each step.

    Args:
        names (list[str]): List of tensor full names.
    """
    self._param_names |= set(names)
    log.debug("Record %d parameters in cache. Total parameter number: %d",
              len(names), len(self._param_names))

def get(self, filter_condition=None): def get(self, filter_condition=None):
""" """
Get full tensor value. Get full tensor value.
@@ -293,7 +312,13 @@ class TensorHandler(StreamHandlerBase):


def clean_tensors(self, cur_step): def clean_tensors(self, cur_step):
"""Clean the tensor cache.""" """Clean the tensor cache."""
self._cur_step = cur_step
if cur_step != self._cur_step:
self._cur_step = cur_step
self._clean_expired_tensors(cur_step)
self._clean_parameters()

def _clean_expired_tensors(self, cur_step):
"""Clean expired tensors less than current steps."""
expired_tensor = [] expired_tensor = []
for tensor_name, tensor in self._tensors.items(): for tensor_name, tensor in self._tensors.items():
expired_step = [step for step in tensor.keys() if step <= cur_step - 2] expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
@@ -304,6 +329,13 @@ class TensorHandler(StreamHandlerBase):
for tensor_name in expired_tensor: for tensor_name in expired_tensor:
self._tensors.pop(tensor_name) self._tensors.pop(tensor_name)


def _clean_parameters(self):
    """Remove all cached parameter tensors from the tensor cache."""
    # Only names that actually have cached values need removal.
    cached_params = [name for name in self._param_names if name in self._tensors]
    for name in cached_params:
        del self._tensors[name]
        log.debug("Clean param %s in cache.", name)

def get_tensors_diff(self, tensor_name, shape, tolerance=0): def get_tensors_diff(self, tensor_name, shape, tolerance=0):
""" """
Get tensor comparisons data for given name, detail, shape and tolerance. Get tensor comparisons data for given name, detail, shape and tolerance.


+ 1
- 2
tests/st/func/debugger/mock_ms_client.py View File

@@ -123,11 +123,10 @@ class MockDebuggerClient:
metadata = self.get_metadata_cmd(training_done) metadata = self.get_metadata_cmd(training_done)
response = self.stub.SendMetadata(metadata) response = self.stub.SendMetadata(metadata)
assert response.status == EventReply.Status.OK assert response.status == EventReply.Status.OK
if response.version_matched is False:
if response.HasField('version_matched') and response.version_matched is False:
self.command_loop() self.command_loop()
if training_done is False: if training_done is False:
self.send_graph_cmd() self.send_graph_cmd()
print("finish")


def send_graph_cmd(self): def send_graph_cmd(self):
"""Send graph to debugger server.""" """Send graph to debugger server."""


+ 1
- 0
tests/st/func/debugger/test_restful_api.py View File

@@ -672,6 +672,7 @@ def create_watchpoint_and_wait(app_client):
# wait for server has received watchpoint hit # wait for server has received watchpoint hit
check_state(app_client) check_state(app_client)



class TestMismatchDebugger: class TestMismatchDebugger:
"""Test debugger when Mindinsight and Mindspore is mismatched.""" """Test debugger when Mindinsight and Mindspore is mismatched."""




Loading…
Cancel
Save