| @@ -276,7 +276,7 @@ class Graph: | |||
| So in this scenario, we need to filter the const nodes. | |||
| """ | |||
| filtered_type = {NodeTypeEnum.CONST.value} if filtered_type is None else filtered_type | |||
| for method in ['input', 'output', 'proxy_input', 'proxy_output']: | |||
| for method in ['inputs', 'outputs', 'proxy_inputs', 'proxy_outputs']: | |||
| for node in subnode_list: | |||
| for item_name, item_attr in getattr(node, method).items(): | |||
| target_node = self._get_normal_node(node_name=item_name) | |||
| @@ -375,7 +375,7 @@ class Graph: | |||
| logger.info("Start to add %s nodes to each scope in graph.", node_type) | |||
| node_map = {} | |||
| for node in self._normal_node_map.values(): | |||
| for src_name, input_attr in dict(node.input).items(): | |||
| for src_name, input_attr in dict(node.inputs).items(): | |||
| if node_type == NodeTypeEnum.CONST.value and not self._const_node_temp_cache.get(src_name): | |||
| continue | |||
| @@ -393,11 +393,11 @@ class Graph: | |||
| Node.copy_node_without_input_output(cache_node, variable_node) | |||
| variable_node.scope = node.scope | |||
| variable_node.add_output(dst_name=node.name, output_attr=input_attr) | |||
| variable_node.add_outputs(dst_name=node.name, output_attr=input_attr) | |||
| node_map.update({variable_name: variable_node}) | |||
| node.delete_input(src_name) | |||
| node.add_input(variable_name, input_attr) | |||
| node.delete_inputs(src_name) | |||
| node.add_inputs(variable_name, input_attr) | |||
| for node in node_map.values(): | |||
| self._cache_node(node) | |||
| @@ -497,7 +497,7 @@ class Graph: | |||
| # Find all nodes that need to modify the input and input | |||
| update_node_map = {} | |||
| for method in ['input', 'output', 'proxy_input', 'proxy_output']: | |||
| for method in ['inputs', 'outputs', 'proxy_inputs', 'proxy_outputs']: | |||
| for target_name in getattr(node, method): | |||
| target_node = self._get_normal_node(node_name=target_name) | |||
| if target_node is None: | |||
| @@ -529,7 +529,7 @@ class Graph: | |||
| # Update the input and output of the nodes | |||
| for target_node in update_node_map.values(): | |||
| for method in ['input', 'output', 'proxy_input', 'proxy_output']: | |||
| for method in ['inputs', 'outputs', 'proxy_inputs', 'proxy_outputs']: | |||
| attr_temp = getattr(target_node, method).get(origin_name) | |||
| if attr_temp is None: | |||
| # This method does not have this node, so it is skipped | |||
| @@ -563,12 +563,12 @@ class Graph: | |||
| for scope_node in independent_layout_node_map.values(): | |||
| scope_node.independent_layout = True | |||
| method = 'output' | |||
| method = 'outputs' | |||
| for target_name, target_attr in dict(getattr(scope_node, method)).items(): | |||
| proxy_attr = dict(edge_type=target_attr['edge_type']) | |||
| target_node = self._get_normal_node(node_name=target_name) | |||
| getattr(target_node, 'add_proxy_input')(scope_node.name, proxy_attr) | |||
| getattr(target_node, 'add_proxy_inputs')(scope_node.name, proxy_attr) | |||
| # Note: | |||
| # If the source node and the destination node are not in the same scope, | |||
| @@ -581,7 +581,7 @@ class Graph: | |||
| else: | |||
| target_scope_node = self._get_normal_node(node_name=target_node.scope) | |||
| getattr(scope_node, f'add_proxy_{method}')(target_node.scope, proxy_attr) | |||
| getattr(target_scope_node, 'add_proxy_input')(scope_node.name, proxy_attr) | |||
| getattr(target_scope_node, 'add_proxy_inputs')(scope_node.name, proxy_attr) | |||
| for subnode in subnode_map[scope_node.name]: | |||
| subnode.independent_layout = True | |||
| @@ -593,6 +593,6 @@ class Graph: | |||
| else: | |||
| getattr(subnode, f'add_proxy_{method}')(target_node.scope, proxy_attr) | |||
| input_attr = getattr(target_node, 'input')[subnode.name] | |||
| input_attr = getattr(target_node, 'inputs')[subnode.name] | |||
| input_attr['independent_layout'] = True | |||
| target_node.add_input(subnode.name, input_attr) | |||
| target_node.add_inputs(subnode.name, input_attr) | |||
| @@ -225,7 +225,7 @@ class MSGraph(Graph): | |||
| 'data_type': '' | |||
| } | |||
| node.add_input(src_name=input_proto.name, input_attr=input_attr) | |||
| node.add_inputs(src_name=input_proto.name, input_attr=input_attr) | |||
| def _parse_attributes(self, attributes, node): | |||
| """ | |||
| @@ -246,8 +246,8 @@ class MSGraph(Graph): | |||
| def _update_input_after_create_node(self): | |||
| """Update the input of node after create node.""" | |||
| for node in self._normal_node_map.values(): | |||
| for src_node_id, input_attr in dict(node.input).items(): | |||
| node.delete_input(src_node_id) | |||
| for src_node_id, input_attr in dict(node.inputs).items(): | |||
| node.delete_inputs(src_node_id) | |||
| if not self._is_node_exist(node_id=src_node_id): | |||
| message = f"The input node could not be found by node id({src_node_id}) " \ | |||
| f"while updating the input of the node({node})" | |||
| @@ -258,19 +258,19 @@ class MSGraph(Graph): | |||
| src_node = self._get_normal_node(node_id=src_node_id) | |||
| input_attr['shape'] = src_node.output_shape | |||
| input_attr['data_type'] = src_node.output_data_type | |||
| node.add_input(src_name=src_node.name, input_attr=input_attr) | |||
| node.add_inputs(src_name=src_node.name, input_attr=input_attr) | |||
| def _update_output_after_create_node(self): | |||
| """Update the output of node after create node.""" | |||
| # Constants and parameter should not exist for input and output. | |||
| filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value} | |||
| for node in self._normal_node_map.values(): | |||
| for src_name, input_attr in node.input.items(): | |||
| for src_name, input_attr in node.inputs.items(): | |||
| src_node = self._get_normal_node(node_name=src_name) | |||
| if src_node.type in filtered_node: | |||
| continue | |||
| src_node.add_output(node.name, input_attr) | |||
| src_node.add_outputs(node.name, input_attr) | |||
| @staticmethod | |||
| def _get_data_type_name_by_value(data_type, value, field_name='data_type'): | |||
| @@ -103,7 +103,7 @@ class Node: | |||
| self._attr.update(attr_dict) | |||
| @property | |||
| def input(self): | |||
| def inputs(self): | |||
| """ | |||
| Get all input of current node. | |||
| @@ -112,7 +112,7 @@ class Node: | |||
| """ | |||
| return self._input | |||
| def add_input(self, src_name, input_attr): | |||
| def add_inputs(self, src_name, input_attr): | |||
| """ | |||
| Update input. | |||
| @@ -127,7 +127,7 @@ class Node: | |||
| """ | |||
| self._input.update({src_name: input_attr}) | |||
| def delete_input(self, src_name): | |||
| def delete_inputs(self, src_name): | |||
| """ | |||
| Delete input attribute by the given source name. | |||
| @@ -137,11 +137,11 @@ class Node: | |||
| self._input.pop(src_name) | |||
| @property | |||
| def output(self): | |||
| def outputs(self): | |||
| """The output node of this node.""" | |||
| return self._output | |||
| def add_output(self, dst_name, output_attr): | |||
| def add_outputs(self, dst_name, output_attr): | |||
| """ | |||
| Add a output node to this node. | |||
| @@ -151,7 +151,7 @@ class Node: | |||
| """ | |||
| self._output.update({dst_name: output_attr}) | |||
| def delete_output(self, dst_name): | |||
| def delete_outputs(self, dst_name): | |||
| """ | |||
| Delete a output node. | |||
| @@ -161,11 +161,11 @@ class Node: | |||
| self._output.pop(dst_name) | |||
| @property | |||
| def proxy_input(self): | |||
| def proxy_inputs(self): | |||
| """Return proxy input, type is dict.""" | |||
| return self._proxy_input | |||
| def add_proxy_input(self, src_name, attr): | |||
| def add_proxy_inputs(self, src_name, attr): | |||
| """ | |||
| Add a proxy input to node. | |||
| @@ -177,16 +177,16 @@ class Node: | |||
| """ | |||
| self._proxy_input.update({src_name: attr}) | |||
| def delete_proxy_input(self, src_name): | |||
| def delete_proxy_inputs(self, src_name): | |||
| """Delete a proxy input by the src name.""" | |||
| self._proxy_input.pop(src_name) | |||
| @property | |||
| def proxy_output(self): | |||
| def proxy_outputs(self): | |||
| """Get proxy output, data type is dict.""" | |||
| return self._proxy_output | |||
| def add_proxy_output(self, dst_name, attr): | |||
| def add_proxy_outputs(self, dst_name, attr): | |||
| """ | |||
| Add a proxy output to node. | |||
| @@ -198,7 +198,7 @@ class Node: | |||
| """ | |||
| self._proxy_output.update({dst_name: attr}) | |||
| def delete_proxy_output(self, dst_name): | |||
| def delete_proxy_outputs(self, dst_name): | |||
| """Delete a proxy output by dst name.""" | |||
| self._proxy_output.pop(dst_name) | |||
| @@ -179,7 +179,7 @@ class DebuggerGraph(MSGraph): | |||
| if tensors_info: | |||
| tensor_history.extend(tensors_info) | |||
| if cur_depth < depth: | |||
| for name in cur_node.input.keys(): | |||
| for name in cur_node.inputs.keys(): | |||
| trace_list.append((self._leaf_nodes[name], cur_depth + 1)) | |||
| return tensor_history, cur_outputs_nums | |||
| @@ -208,7 +208,7 @@ class DebuggerGraph(MSGraph): | |||
| def _get_input_tensors_of_node(self, cur_node): | |||
| """Get input tensors of node.""" | |||
| tensors_info = [] | |||
| for name in cur_node.input.keys(): | |||
| for name in cur_node.inputs.keys(): | |||
| node = self._leaf_nodes.get(name) | |||
| tensor_info = self._get_tensor_infos_of_node(node) | |||
| tensors_info.extend(tensor_info) | |||
| @@ -258,12 +258,12 @@ class DebuggerGraph(MSGraph): | |||
| continue | |||
| bfs_order.append(node_name) | |||
| if node.input: | |||
| for name in node.input.keys(): | |||
| if node.inputs: | |||
| for name in node.inputs.keys(): | |||
| if name not in temp_list and name not in bfs_order: | |||
| temp_list.append(name) | |||
| if node.output: | |||
| for name in node.output.keys(): | |||
| if node.outputs: | |||
| for name in node.outputs.keys(): | |||
| if name not in temp_list and name not in bfs_order: | |||
| temp_list.append(name) | |||