
Fix the bug that stale parameter values were shown when debugging on GPU

tags/v1.1.0
yelihua committed 5 years ago
commit 4dc2774295
6 changed files with 63 additions and 5 deletions:
  1. mindinsight/debugger/conditionmgr/condition.py (+2, -1)
  2. mindinsight/debugger/debugger_grpc_server.py (+13, -1)
  3. mindinsight/debugger/stream_cache/node_type_identifier.py (+13, -0)
  4. mindinsight/debugger/stream_handler/tensor_handler.py (+33, -1)
  5. tests/st/func/debugger/mock_ms_client.py (+1, -2)
  6. tests/st/func/debugger/test_restful_api.py (+1, -0)

mindinsight/debugger/conditionmgr/condition.py (+2, -1)

@@ -68,9 +68,10 @@ class PlatformEnum(Enum):
 class TargetTypeEnum(Enum):
     """Target types."""
     TENSOR = 'tensor'
-    WEIGHT = 'weight'
     ACTIVATION = 'activation'
     GRADIENT = 'gradient'
+    PARAMETER = 'parameter'
+    WEIGHT = 'weight'


 class ParamTypeEnum(Enum):


mindinsight/debugger/debugger_grpc_server.py (+13, -1)

@@ -21,6 +21,7 @@ import mindinsight
 from mindinsight.debugger.common.log import LOGGER as log
 from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
     Streams, RunLevel
+from mindinsight.debugger.conditionmgr.condition import TargetTypeEnum
 from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
 from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto

@@ -117,9 +118,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         # clean cache data at the beginning of a new step or when the node has changed.
         if is_new_step or is_new_node:
             self._cache_store.clean_data()
+            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
         if is_new_step:
             self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
-            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
         # receive graph at the beginning of the training
         if self._status == ServerStatus.RECEIVE_GRAPH:
             self._send_graph_flag(metadata_stream)
@@ -397,6 +398,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
         self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
         self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
         self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
+        self._record_parameter_names()
         self._status = ServerStatus.RECEIVE_GRAPH
         log.debug("Send the reply for graph.")
         return reply
@@ -423,10 +425,20 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
                                                                 sub_graph.const_vals)

         self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
+        self._record_parameter_names()
         self._status = ServerStatus.RECEIVE_GRAPH
         log.debug("Send the reply for graph.")
         return reply

+    def _record_parameter_names(self):
+        """Record parameter full names in tensor handler."""
+        parameter_nodes = self._cache_store.get_stream_handler(Streams.GRAPH).get_searched_nodes(
+            pattern={'node_category': TargetTypeEnum.PARAMETER.value})
+        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
+        for nodes in parameter_nodes.values():
+            tensor_names = [node.full_name + ':0' for node in nodes]
+            tensor_stream.record_parameter_names(tensor_names)
+
     @debugger_wrap
     def SendTensors(self, request_iterator, context):
         """Send tensors into DebuggerCache."""


mindinsight/debugger/stream_cache/node_type_identifier.py (+13, -0)

@@ -66,6 +66,19 @@ class NodeTypeIdentifier:
         return self.identify_func(*args, **kwargs)


+def is_parameter_node(node):
+    """
+    Check if the node is parameter type.
+
+    Args:
+        node (Node): The node object.
+
+    Returns:
+        bool, if the node is parameter type.
+    """
+    return bool(node.type == NodeTypeEnum.PARAMETER.value)
+
+
 def is_weight_node(node):
     """
     Check if the node is weight type.

mindinsight/debugger/stream_handler/tensor_handler.py (+33, -1)

@@ -27,11 +27,17 @@ from mindinsight.utils.tensor import TensorUtils, TensorComparison

 TensorBasicInfo = namedtuple('tensor_basic_info', ['full_name', 'node_type', 'iter'])


 class TensorHandler(StreamHandlerBase):
     """Tensor handler."""

     def __init__(self):
+        # the collection of parameter full names
+        self._param_names = set()
+        # const value objects, the format is like: dict[<const name>, <OpTensor object>]
         self._const_vals = {}
+        # tensor values, the format is like:
+        # dict[<tensor full name>, dict[<step_num>, <OpTensor object>]]
         self._tensors = {}
         self._cur_step = 0

@@ -147,6 +153,19 @@
         const_tensor = ConstTensor(const_val)
         self._const_vals[const_tensor.name] = const_tensor

+    def record_parameter_names(self, names):
+        """
+        Record parameter names.
+
+        Note:
+            Parameter values may change within an iteration step, so cached parameter tensors must be cleaned after each node step.
+
+        Args:
+            names (list[str]): List of tensor full names.
+        """
+        self._param_names.update(names)
+        log.debug("Record %d parameters in cache. Total parameter number: %d", len(names), len(self._param_names))
+
     def get(self, filter_condition=None):
         """
         Get full tensor value.
def get(self, filter_condition=None):
"""
Get full tensor value.
@@ -293,7 +312,13 @@

     def clean_tensors(self, cur_step):
         """Clean the tensor cache."""
-        self._cur_step = cur_step
+        if cur_step != self._cur_step:
+            self._cur_step = cur_step
+            self._clean_expired_tensors(cur_step)
+        self._clean_parameters()
+
+    def _clean_expired_tensors(self, cur_step):
+        """Clean expired tensors older than the current step."""
         expired_tensor = []
         for tensor_name, tensor in self._tensors.items():
             expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
@@ -304,6 +329,13 @@
         for tensor_name in expired_tensor:
             self._tensors.pop(tensor_name)

+    def _clean_parameters(self):
+        """Clean parameter cache."""
+        for param in self._param_names:
+            if param in self._tensors:
+                self._tensors.pop(param)
+                log.debug("Clean param %s in cache.", param)
+
     def get_tensors_diff(self, tensor_name, shape, tolerance=0):
         """
         Get tensor comparisons data for given name, detail, shape and tolerance.
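Taken together, clean_tensors now distinguishes step changes from node steps. A runnable sketch of the flow, simplified from the real handler (names and values below are illustrative):

    class MiniTensorHandler:
        """Simplified stand-in: tensors are kept as dict[name, dict[step, value]]."""

        def __init__(self):
            self._param_names = set()
            self._tensors = {}
            self._cur_step = 0

        def clean_tensors(self, cur_step):
            # Expired tensors are dropped only when the step advances; parameters
            # may change within a step, so they are dropped on every call.
            if cur_step != self._cur_step:
                self._cur_step = cur_step
                self._clean_expired_tensors(cur_step)
            self._clean_parameters()

        def _clean_expired_tensors(self, cur_step):
            for name in list(self._tensors):
                for step in [s for s in self._tensors[name] if s <= cur_step - 2]:
                    self._tensors[name].pop(step)
                if not self._tensors[name]:
                    self._tensors.pop(name)

        def _clean_parameters(self):
            for param in self._param_names:
                self._tensors.pop(param, None)

    handler = MiniTensorHandler()
    handler._param_names.add('conv1.weight:0')
    handler._tensors = {'conv1.weight:0': {1: 'stale value'}, 'relu:0': {1: 'activation'}}
    handler._cur_step = 1
    handler.clean_tensors(1)  # same step, new node: only the parameter cache is dropped
    assert 'conv1.weight:0' not in handler._tensors and 'relu:0' in handler._tensors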


tests/st/func/debugger/mock_ms_client.py (+1, -2)

@@ -123,11 +123,10 @@ class MockDebuggerClient:
         metadata = self.get_metadata_cmd(training_done)
         response = self.stub.SendMetadata(metadata)
         assert response.status == EventReply.Status.OK
-        if response.version_matched is False:
+        if response.HasField('version_matched') and response.version_matched is False:
             self.command_loop()
         if training_done is False:
             self.send_graph_cmd()
-        print("finish")

     def send_graph_cmd(self):
         """Send graph to debugger server."""


tests/st/func/debugger/test_restful_api.py (+1, -0)

@@ -672,6 +672,7 @@ def create_watchpoint_and_wait(app_client):
     # wait until the server has received the watchpoint hit
     check_state(app_client)

+
 class TestMismatchDebugger:
     """Test debugger when MindInsight and MindSpore are mismatched."""


