
change names of input and output

tags/v1.0.0
Li Hongzhang, 5 years ago
commit 8fdce0280e
4 changed files with 36 additions and 36 deletions
  1. +12 -12  mindinsight/datavisual/data_transform/graph/graph.py
  2. +6  -6   mindinsight/datavisual/data_transform/graph/msgraph.py
  3. +12 -12  mindinsight/datavisual/data_transform/graph/node.py
  4. +6  -6   mindinsight/debugger/stream_cache/debugger_graph.py

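The hunks below consistently pluralise the edge accessors on graph nodes. As a purely illustrative summary (derived from the diff that follows, not part of the MindInsight code), the renames are:

RENAMED_ACCESSORS = {
    'input': 'inputs',
    'output': 'outputs',
    'proxy_input': 'proxy_inputs',
    'proxy_output': 'proxy_outputs',
    'add_input': 'add_inputs',
    'delete_input': 'delete_inputs',
    'add_output': 'add_outputs',
    'delete_output': 'delete_outputs',
    'add_proxy_input': 'add_proxy_inputs',
    'delete_proxy_input': 'delete_proxy_inputs',
    'add_proxy_output': 'add_proxy_outputs',
    'delete_proxy_output': 'delete_proxy_outputs',
}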
+12 -12  mindinsight/datavisual/data_transform/graph/graph.py

@@ -276,7 +276,7 @@ class Graph:
So in this scenario, we need to filter the const nodes.
"""
filtered_type = {NodeTypeEnum.CONST.value} if filtered_type is None else filtered_type
for method in ['input', 'output', 'proxy_input', 'proxy_output']:
for method in ['inputs', 'outputs', 'proxy_inputs', 'proxy_outputs']:
for node in subnode_list:
for item_name, item_attr in getattr(node, method).items():
target_node = self._get_normal_node(node_name=item_name)
@@ -375,7 +375,7 @@ class Graph:
logger.info("Start to add %s nodes to each scope in graph.", node_type)
node_map = {}
for node in self._normal_node_map.values():
for src_name, input_attr in dict(node.input).items():
for src_name, input_attr in dict(node.inputs).items():

if node_type == NodeTypeEnum.CONST.value and not self._const_node_temp_cache.get(src_name):
continue
@@ -393,11 +393,11 @@ class Graph:
Node.copy_node_without_input_output(cache_node, variable_node)
variable_node.scope = node.scope

variable_node.add_output(dst_name=node.name, output_attr=input_attr)
variable_node.add_outputs(dst_name=node.name, output_attr=input_attr)
node_map.update({variable_name: variable_node})

node.delete_input(src_name)
node.add_input(variable_name, input_attr)
node.delete_inputs(src_name)
node.add_inputs(variable_name, input_attr)

for node in node_map.values():
self._cache_node(node)
@@ -497,7 +497,7 @@ class Graph:

# Find all nodes that need to modify the input and output
update_node_map = {}
for method in ['input', 'output', 'proxy_input', 'proxy_output']:
for method in ['inputs', 'outputs', 'proxy_inputs', 'proxy_outputs']:
for target_name in getattr(node, method):
target_node = self._get_normal_node(node_name=target_name)
if target_node is None:
@@ -529,7 +529,7 @@ class Graph:

# Update the input and output of the nodes
for target_node in update_node_map.values():
for method in ['input', 'output', 'proxy_input', 'proxy_output']:
for method in ['inputs', 'outputs', 'proxy_inputs', 'proxy_outputs']:
attr_temp = getattr(target_node, method).get(origin_name)
if attr_temp is None:
# This method does not have this node, so it is skipped
@@ -563,12 +563,12 @@ class Graph:
for scope_node in independent_layout_node_map.values():
scope_node.independent_layout = True

method = 'output'
method = 'outputs'
for target_name, target_attr in dict(getattr(scope_node, method)).items():
proxy_attr = dict(edge_type=target_attr['edge_type'])

target_node = self._get_normal_node(node_name=target_name)
getattr(target_node, 'add_proxy_input')(scope_node.name, proxy_attr)
getattr(target_node, 'add_proxy_inputs')(scope_node.name, proxy_attr)

# Note:
# If the source node and the destination node are not in the same scope,
@@ -581,7 +581,7 @@ class Graph:
else:
target_scope_node = self._get_normal_node(node_name=target_node.scope)
getattr(scope_node, f'add_proxy_{method}')(target_node.scope, proxy_attr)
getattr(target_scope_node, 'add_proxy_input')(scope_node.name, proxy_attr)
getattr(target_scope_node, 'add_proxy_inputs')(scope_node.name, proxy_attr)

for subnode in subnode_map[scope_node.name]:
subnode.independent_layout = True
@@ -593,6 +593,6 @@ class Graph:
else:
getattr(subnode, f'add_proxy_{method}')(target_node.scope, proxy_attr)

input_attr = getattr(target_node, 'input')[subnode.name]
input_attr = getattr(target_node, 'inputs')[subnode.name]
input_attr['independent_layout'] = True
target_node.add_input(subnode.name, input_attr)
target_node.add_inputs(subnode.name, input_attr)

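graph.py repeatedly walks all four edge dictionaries of a node by attribute name via getattr. A minimal sketch of that pattern with the new names (illustrative only; `node` stands in for a Node instance as defined in node.py further down):

def iter_all_edges(node):
    """Yield (edge_kind, peer_name, attr) for every edge recorded on the node."""
    for method in ['inputs', 'outputs', 'proxy_inputs', 'proxy_outputs']:
        for peer_name, attr in getattr(node, method).items():
            yield method, peer_name, attr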
+6 -6  mindinsight/datavisual/data_transform/graph/msgraph.py

@@ -225,7 +225,7 @@ class MSGraph(Graph):
'data_type': ''
}

node.add_input(src_name=input_proto.name, input_attr=input_attr)
node.add_inputs(src_name=input_proto.name, input_attr=input_attr)

def _parse_attributes(self, attributes, node):
"""
@@ -246,8 +246,8 @@ class MSGraph(Graph):
def _update_input_after_create_node(self):
"""Update the input of node after create node."""
for node in self._normal_node_map.values():
for src_node_id, input_attr in dict(node.input).items():
node.delete_input(src_node_id)
for src_node_id, input_attr in dict(node.inputs).items():
node.delete_inputs(src_node_id)
if not self._is_node_exist(node_id=src_node_id):
message = f"The input node could not be found by node id({src_node_id}) " \
f"while updating the input of the node({node})"
@@ -258,19 +258,19 @@ class MSGraph(Graph):
src_node = self._get_normal_node(node_id=src_node_id)
input_attr['shape'] = src_node.output_shape
input_attr['data_type'] = src_node.output_data_type
node.add_input(src_name=src_node.name, input_attr=input_attr)
node.add_inputs(src_name=src_node.name, input_attr=input_attr)

def _update_output_after_create_node(self):
"""Update the output of node after create node."""
# Constants and parameters should not exist in inputs and outputs.
filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
for node in self._normal_node_map.values():
for src_name, input_attr in node.input.items():
for src_name, input_attr in node.inputs.items():
src_node = self._get_normal_node(node_name=src_name)
if src_node.type in filtered_node:
continue

src_node.add_output(node.name, input_attr)
src_node.add_outputs(node.name, input_attr)

@staticmethod
def _get_data_type_name_by_value(data_type, value, field_name='data_type'):


+12 -12  mindinsight/datavisual/data_transform/graph/node.py

@@ -103,7 +103,7 @@ class Node:
self._attr.update(attr_dict)

@property
def input(self):
def inputs(self):
"""
Get all inputs of the current node.

@@ -112,7 +112,7 @@ class Node:
"""
return self._input

def add_input(self, src_name, input_attr):
def add_inputs(self, src_name, input_attr):
"""
Update input.

@@ -127,7 +127,7 @@ class Node:
"""
self._input.update({src_name: input_attr})

def delete_input(self, src_name):
def delete_inputs(self, src_name):
"""
Delete input attribute by the given source name.

@@ -137,11 +137,11 @@ class Node:
self._input.pop(src_name)

@property
def output(self):
def outputs(self):
"""The output node of this node."""
return self._output

def add_output(self, dst_name, output_attr):
def add_outputs(self, dst_name, output_attr):
"""
Add an output node to this node.

@@ -151,7 +151,7 @@ class Node:
"""
self._output.update({dst_name: output_attr})

def delete_output(self, dst_name):
def delete_outputs(self, dst_name):
"""
Delete an output node.

@@ -161,11 +161,11 @@ class Node:
self._output.pop(dst_name)

@property
def proxy_input(self):
def proxy_inputs(self):
"""Return proxy input, type is dict."""
return self._proxy_input

def add_proxy_input(self, src_name, attr):
def add_proxy_inputs(self, src_name, attr):
"""
Add a proxy input to node.

@@ -177,16 +177,16 @@ class Node:
"""
self._proxy_input.update({src_name: attr})

def delete_proxy_input(self, src_name):
def delete_proxy_inputs(self, src_name):
"""Delete a proxy input by the src name."""
self._proxy_input.pop(src_name)

@property
def proxy_output(self):
def proxy_outputs(self):
"""Get proxy output, data type is dict."""
return self._proxy_output

def add_proxy_output(self, dst_name, attr):
def add_proxy_outputs(self, dst_name, attr):
"""
Add a proxy output to node.

@@ -198,7 +198,7 @@ class Node:
"""
self._proxy_output.update({dst_name: attr})

def delete_proxy_output(self, dst_name):
def delete_proxy_outputs(self, dst_name):
"""Delete a proxy output by dst name."""
self._proxy_output.pop(dst_name)

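A self-contained stand-in showing how the renamed Node accessors are used by callers; it mirrors only the dict-backed behaviour visible in the hunks above and is not the real Node class:

class DemoNode:
    """Toy stand-in with the pluralised accessor names."""
    def __init__(self, name):
        self.name = name
        self._input = {}
        self._output = {}

    @property
    def inputs(self):
        return self._input

    def add_inputs(self, src_name, input_attr):
        self._input.update({src_name: input_attr})

    def delete_inputs(self, src_name):
        self._input.pop(src_name)

    @property
    def outputs(self):
        return self._output

    def add_outputs(self, dst_name, output_attr):
        self._output.update({dst_name: output_attr})

conv = DemoNode('conv1')
relu = DemoNode('relu1')
edge_attr = {'shape': [1, 3, 224, 224], 'edge_type': 'data'}
relu.add_inputs(src_name=conv.name, input_attr=edge_attr)
conv.add_outputs(dst_name=relu.name, output_attr=edge_attr)
assert 'conv1' in relu.inputs and 'relu1' in conv.outputs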


+6 -6  mindinsight/debugger/stream_cache/debugger_graph.py

@@ -179,7 +179,7 @@ class DebuggerGraph(MSGraph):
if tensors_info:
tensor_history.extend(tensors_info)
if cur_depth < depth:
for name in cur_node.input.keys():
for name in cur_node.inputs.keys():
trace_list.append((self._leaf_nodes[name], cur_depth + 1))

return tensor_history, cur_outputs_nums
@@ -208,7 +208,7 @@ class DebuggerGraph(MSGraph):
def _get_input_tensors_of_node(self, cur_node):
"""Get input tensors of node."""
tensors_info = []
for name in cur_node.input.keys():
for name in cur_node.inputs.keys():
node = self._leaf_nodes.get(name)
tensor_info = self._get_tensor_infos_of_node(node)
tensors_info.extend(tensor_info)
@@ -258,12 +258,12 @@ class DebuggerGraph(MSGraph):
continue

bfs_order.append(node_name)
if node.input:
for name in node.input.keys():
if node.inputs:
for name in node.inputs.keys():
if name not in temp_list and name not in bfs_order:
temp_list.append(name)
if node.output:
for name in node.output.keys():
if node.outputs:
for name in node.outputs.keys():
if name not in temp_list and name not in bfs_order:
temp_list.append(name)

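The last hunk above builds a breadth-first ordering by pulling neighbours from both node.inputs and node.outputs. A condensed, self-contained sketch of that traversal (the `get_node` lookup is a hypothetical stand-in for the graph's node accessor):

from collections import deque

def bfs_order(start_name, get_node):
    """Return node names in BFS order, following inputs and outputs."""
    order, seen, queue = [], {start_name}, deque([start_name])
    while queue:
        name = queue.popleft()
        order.append(name)
        node = get_node(name)
        for neighbour in list(node.inputs) + list(node.outputs):
            if neighbour not in seen:
                seen.add(neighbour)
                queue.append(neighbour)
    return order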

