Browse Source

MindConverter adds generator and its required dependencies.

Add global context to exchange information among multiple classes to reduce passing arguments through multiple procedures.
Add Node and Module Struct to store all information converted for MindSpore script generation
Add Args translator and scope utils to help process scope information and operators' arguments
Add generator to generate the MindSpore script from information stored in Node and Module struct.
tags/v1.1.0
liangtianshu 5 years ago
parent
commit
03b2867978
11 changed files with 2395 additions and 10 deletions
  1. +40
    -7
      mindinsight/mindconverter/graph_based_converter/common/global_context.py
  2. +15
    -0
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  3. +111
    -0
      mindinsight/mindconverter/graph_based_converter/generator/__init__.py
  4. +248
    -0
      mindinsight/mindconverter/graph_based_converter/generator/args_translator.py
  5. +630
    -0
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  6. +710
    -0
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  7. +423
    -0
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  8. +157
    -0
      mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py
  9. +48
    -0
      mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py
  10. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  11. +12
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+ 40
- 7
mindinsight/mindconverter/graph_based_converter/common/global_context.py View File

@@ -40,8 +40,10 @@ class GlobalContext(metaclass=Singleton):
# Define data stored from onnx_utils
# Key as Onnx Name
self._onnx_nodes_collection = OrderedDict()
# key is topo_idx, value is onnx_node_name.
# key is topo_idx, value is onnx_node_name
self._onnx_nodes_topo_index = dict()
self.onnx_node_name_to_topo_idx = dict()
self.onnx_node_inputs = dict()
self._onnx_tensors_collection = dict()

# Define data stored from generator
@@ -50,7 +52,7 @@ class GlobalContext(metaclass=Singleton):
self.node_struct_adder_counter = 0
# Define onnx_utils <---> generator mapping
self.node_struct_to_onnx_node_map = dict()
self.onnx_node_to_node_struct_map = dict()
self.onnx_node_name_to_node_struct_map = dict()

# Define Module pattern to customize name mapping
self.module_customized_name = dict()
@@ -59,6 +61,8 @@ class GlobalContext(metaclass=Singleton):
self.node_fragments = OrderedDict()
self.module_fragments = OrderedDict()

# Define Known module mapping
self.known_module_name = dict()
# Define Structs
# key is pattern_id, value is [ModuleStructs]
self.module_structs = dict()
@@ -83,7 +87,7 @@ class GlobalContext(metaclass=Singleton):

def get_identifier_from_onnx_node_name(self, node_name):
"""Return the node identifier by Onnx Node name."""
identifier = self.onnx_node_to_node_struct_map.get(node_name)
identifier = self.onnx_node_name_to_node_struct_map.get(node_name)
return identifier

@property
@@ -98,9 +102,7 @@ class GlobalContext(metaclass=Singleton):

@onnx_nodes_collection.setter
def onnx_nodes_collection(self, arg):
"""
Set the onnx nodes collection.
"""
"""Set the onnx nodes collection."""
if isinstance(arg, OrderedDict):
self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader
else:
@@ -108,11 +110,18 @@ class GlobalContext(metaclass=Singleton):

@property
def onnx_nodes_topo_index(self) -> dict:
"Return the onnx nodes and topological index."
"""Return the onnx nodes and topological index."""
return self._onnx_nodes_topo_index

@onnx_nodes_topo_index.setter
def onnx_nodes_topo_index(self, index_list):
"""
Set the onnx nodes and topological index.

Args:
index_list (list[tuple[int, str]]): a list of tuple contains the topological index and onnx node name.

"""
if not isinstance(index_list, list):
raise TypeError("The argument index_list must be a list of tuple (index, onnx_node_name).")
if not isinstance(index_list[0], tuple):
@@ -122,10 +131,17 @@ class GlobalContext(metaclass=Singleton):

@property
def onnx_tensors_collection(self):
"""Return the onnx tensors collection."""
return self.onnx_tensors_collection

@onnx_tensors_collection.setter
def onnx_tensors_collection(self, arg):
"""
Set the onnx tensors collection by OnnxDataLoader.

Args:
arg (dict): The OnnxDataLoader generated tensors_dict.
"""
if isinstance(arg, dict):
self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader
else:
@@ -133,6 +149,12 @@ class GlobalContext(metaclass=Singleton):

@property
def latest_node_struct_count(self):
"""
Return the latest node struct count.

Note:
The counter will increase by 1 to tracking the number of nodes added.
"""
ret = self.node_struct_adder_counter
self.node_struct_adder_counter += 1
return ret
@@ -184,18 +206,29 @@ class GlobalContext(metaclass=Singleton):
self.module_customized_name[pattern_id] = customized_name

def get_node_fragment(self, identifier):
"""Return the node fragment by identifier."""
return self.node_fragments.get(identifier)

def add_code_fragment(self, identifier, frag):
"""Add the node fragment by identifier."""
self.node_fragments[identifier] = frag

def get_module_fragment(self, identifier):
"""Return the module fragment by identifier."""
return self.module_fragments.get(identifier)

def add_module_fragment(self, identifier, frag):
"""Add the module fragment by identifier."""
self.module_fragments[identifier] = frag

def add_module_struct(self, pattern_id, module_struct):
"""
Add module struct by its pattern_id.

Args:
pattern_id (int): The pattern which represents the structure of the module.
module_struct (ModuleStruct): The ModuleStruct instance.
"""
if self.module_structs.get(pattern_id) is None:
self.module_structs[pattern_id] = [module_struct]
else:


+ 15
- 0
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -135,3 +135,18 @@ def lib_version_satisfied(current_ver: str, mini_ver_limited: str,
if current_ver < mini_ver_limited or (newest_ver_limited and current_ver > newest_ver_limited):
return False
return True
def get_dict_key_by_value(val, dic):
"""
Return the first appeared key of a dictionay by given value.

Args:
val (Any): Value of the key.
dic (dict): Dictionary to be checked.

Returns:
Any, key of the given value.
"""
for d_key, d_val in dic.items():
if d_val == val:
return d_key
return None

+ 111
- 0
mindinsight/mindconverter/graph_based_converter/generator/__init__.py View File

@@ -0,0 +1,111 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generator module."""
__all__ = ["batch_add_nodes"]

import re
import copy

from .generator import Generator, CodeStruct
from ..common.code_fragment import CodeFragment


def _tf_model_node_name_reformat(node, node_name):
"""
Rename the node name by combining scope name and its original name.

Args:
node (OnnxGraphNode): OnnxGraphNode instance.
node_name (str): node name saved in Graph.

Returns:
str, re-formatted node name.
"""
scope_name = node.scope_name
new_name = None
regex = r"(?P<parent>.+/)(?P<op>\w+)"
match = re.match(regex, scope_name)
parent = match.group("parent")
node_name = '$' + node_name.replace('/', '::') + '$'

if scope_name:
new_name = parent + node_name
return new_name
return node_name


def batch_add_nodes(graph_obj, mapper) -> Generator:
"""
Add nodes to Generator in batch mode.

Args:
graph_obj (Graph): Graph obj.
mapper (Mapper): Mapper of third party framework and MindSpore.

"""
generator_inst = Generator()
for node_name in graph_obj.nodes_in_topological_order:
node_inst = graph_obj.get_node(node_name)
node_input = graph_obj.get_input_shape(node_name)
node_output = graph_obj.get_output_shape(node_name)
if not node_input:
raise ValueError("Unable to get the node's inputs from Graph object.")
node_name_with_scope = _tf_model_node_name_reformat(node_inst, node_name)
node_name = node_name_with_scope

node_inst.add_input_and_output_shape(node_input, node_output)
op_name, params, settings, weights = _convert_params(node_inst, mapper)
generator_inst.add_node(
node_name,
node_instance=node_inst,
node_fragment=CodeFragment(op_name, params,
settings,
node_inst.input_shape,
node_inst.output_shape,
weights)
)
return generator_inst


def _convert_params(node, mapper):
"""
Call mapper to convert node's params from ONNX to MindSpore.

Args:
node (GraphNode): Our defined GraphNode instance.
mapper (Mapper): The mapper instance which indicating conversion method.

Returns:
str, op name in MindSpore
dict, MindSpore parameters
dict, MindSpore settings
dict, weights of the node
"""
params = copy.deepcopy(node.node_params)
params.update({"input_shape": node.input_shape,
"output_shape": node.output_shape})

op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name,
params=params,
weights=node.weight)
if "input_shape" in ms_params:
ms_params.pop("input_shape")
if "output_shape" in ms_params:
ms_params.pop("output_shape")

if op_in_ms:
return op_in_ms, ms_params, ms_settings, weights

return node.op_name, node.node_params, dict(), dict()

+ 248
- 0
mindinsight/mindconverter/graph_based_converter/generator/args_translator.py View File

@@ -0,0 +1,248 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define arguments translation related operations for params name changing."""
import copy


class ArgsTranslation:
"""Define a universal arguments translation manager."""

def __init__(self, original_actual_args: dict, var_name: str, translated_args: list):
"""
Init the ArgsTranslation.

Args:
original_actual_args (dict): The full original args from fragments.
var_name (str): The var name for current Node / Module.
translated_args (list): The list of args need to translate to formal args.
"""
if not var_name:
raise ValueError("Initialize ArgsTranslation requires the var_name.")

self.var_name = var_name
self.actual_args = dict() # e.g. key is 'num_features', value is 2048
self.formal_args = dict() # e.g. key is 'num_features', value is 'var_name_num_features'}
self.formal_args_values = dict() # e.g. key 'var_name_num_features', value 2048. Value use for up-level
self.actual_args_backup = dict() # backup actual args before translation

self.actual_args_to_str_list = list()
self.formal_args_to_str_list = list()
self.formal_args_values_to_str_list = list()
self.actual_args_backup_to_str_list = list()

if all([original_actual_args, translated_args]):
# MUST ensure only one var_name in a scope.
for arg_name, arg_value in original_actual_args.items():
if arg_name in translated_args:
formal_arg_name = '_'.join([var_name, arg_name])
self.formal_args[arg_name] = formal_arg_name
self.formal_args_values[formal_arg_name] = arg_value
else:
self.actual_args[arg_name] = arg_value

self.make_str()

@staticmethod
def dict_data_to_args_str_list(any_dict):
"""
Output a list of string of dict data by "key=value" format.

Args:
any_dict (dict): Any dictionary

Returns:
list, the list of strings showing dictionary as "key=value" format.
"""
ret = []
for key, val in any_dict.items():
ret.append('='.join([key, str(val)]))
return ret

def make_str(self):
"""Make string used in code generation."""
self.actual_args_to_str_list = list()
self.formal_args_to_str_list = list()
self.formal_args_values_to_str_list = list()
self.actual_args_backup_to_str_list = list()

if self.actual_args:
self.actual_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args)

if self.formal_args:
self.formal_args_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args)

if self.formal_args_values:
self.formal_args_values_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.formal_args_values)

if self.actual_args_backup:
self.actual_args_backup_to_str_list = ArgsTranslation.dict_data_to_args_str_list(self.actual_args_backup)

def __repr__(self):
return str({
"address": hex(id(self)),
"var_name": self.var_name,
"actual_args": self.actual_args,
"actual_bak": self.actual_args_backup,
"formal_args": self.formal_args,
"formal_val ": self.formal_args_values
})

def set_actual_args_backup(self):
"""Backup the actual args before translating to formal."""
self.actual_args_backup = copy.deepcopy(self.actual_args)

def deepcopy(self):
"""Return a deepcopy of self."""
return copy.deepcopy(self)

def make_actual_arg_to_formal(self, actual_arg_name):
"""
Make the actual arg to a formal arg.

Args:
actual_arg_name (str): The name of the actual arg to be formal.
"""
val = self.actual_args.get(actual_arg_name)
if val is None:
raise ValueError("Unable to convert the actual arg to formal due to missing arg.")
formal_arg_name = ('_').join([self.var_name, actual_arg_name])
self.actual_args.pop(actual_arg_name)
self.formal_args[actual_arg_name] = formal_arg_name
self.formal_args_values[formal_arg_name] = val
self.make_str()

def _update_dict_for_upper_level(self, d, upper_level_var_name):
"""Add upper level var name to key name of selected dictionary."""
new_d = dict()
for arg_name, val in d.items():
new_arg_name = '_'.join([upper_level_var_name, arg_name]) # e.g. conv2d_0_in_channels_Module_3_0
new_d[new_arg_name] = val
return new_d

def escalate_to_upper_level(self, upper_level_var_name):
"""
Escalate this args translator for upper level module use.

Note:
You MUST deepcopy this translator first to avoid editing values in the original translator.
"""
# update all args name by adding upper_level_var_name.
tmp_actual_args = self._update_dict_for_upper_level(self.actual_args, upper_level_var_name)
tmp_formal_args = self._update_dict_for_upper_level(self.formal_args, upper_level_var_name)
tmp_formal_args_values = self._update_dict_for_upper_level(self.formal_args_values, upper_level_var_name)

self.actual_args = tmp_actual_args
self.formal_args = tmp_formal_args
self.formal_args_values = tmp_formal_args_values

self.make_str()

def make_formal_args_back_to_actual(self, formal_arg):
"""
Move the formal arg back to actual arg.

Note:
This does not reset the formal arg name back,
Only used for module init statement.

Args:
formal_arg (str): formal argument name.
"""
if isinstance(formal_arg, str):
val = self.formal_args_values.pop(formal_arg)
self.actual_args[formal_arg] = val
if isinstance(formal_arg, list):
for arg in formal_arg:
val = self.formal_args_values.pop(arg)
self.actual_args[formal_arg] = val

self.make_str()

def take_formal_args_from_args_translator(self, args_translator, escalate_sub=False):
"""
Add submodule's or node's args translator to this translator.

Args:
args_translator (ArgsTranslation): submodule's or node's args translator.
"""
if escalate_sub:
sub_args_translator = args_translator.deepcopy()
sub_args_translator.escalate_to_upper_level(self.var_name)
else:
sub_args_translator = args_translator

original_actual_args = sub_args_translator.formal_args_values
self.actual_args.update(original_actual_args)
self.make_str()

def take_formal_args_from_nodes_and_submodules(self, args_translators: list, escalate_sub=False):
"""
Take all formal args from nodes and submodules from passed in args_translators.

Args:
args_translators (ArgsTranslation): A list of ArgsTranslation instances.
escalate_sub (Bool): should escalate all formal args. Default: False
"""
for arg_t in args_translators:
self.take_formal_args_from_args_translator(arg_t, escalate_sub=escalate_sub)


class ArgsTranslationHelper:
"""Define operations related to ArgsTranslation instances."""
@staticmethod
def find_formal_args_in_modules(args_translators):
"""
Find formal args among multiple args translators.

Args:
args_translators(list[ArgsTranslation]): a list of args translator to be checked.

Returns:
list, name of args to be formal.
"""
if len(args_translators) < 2:
# only one args_translator provided, no formal args.
return None
ret = []
base_args_t = args_translators[0]
for arg_name, arg_val in base_args_t.actual_args.items():
for args_t in args_translators[1:]:
val = args_t.actual_args.get(arg_name)

if val is None:
raise ValueError("Unable to find the given args as the args translator is not consistent.")
if val != arg_val: # val not equal
ret.append(arg_name)
break
return ret

@staticmethod
def change_args_to_formal_for_all_translators(args_name, args_translators):
"""
Change args to formal for all translators provided.

Args:
args_name (str): The name of args to be changing.
args_translators (ArgsTranslation): The args to be changed in args translators.
"""
if isinstance(args_name, str):
args_name = [args_name]
if isinstance(args_translators, ArgsTranslation):
args_translators = [args_translators]

for arg in args_name:
for args_t in args_translators:
args_t.set_actual_args_backup()
args_t.make_actual_arg_to_formal(arg)

+ 630
- 0
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -0,0 +1,630 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main Generator module."""
import copy
from collections import OrderedDict

from .scope_utils import Scope
from .node_struct import NodeStruct
from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext
from ..hierarchical_tree.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT


class Singleton(type):
"""Metaclass to make the generator to be single instance."""
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


class CodeStruct:
"""
Define the Code template for each module generated in the final output.
Each module has only one CodeStruct to its pattern.
"""
GLOBAL_CONTEXT = GlobalContext()
NOT_IN_SCOPE_OPT = dict()

def __init__(self, struct, repeated_submodules=None):
"""Initialize the CodeStruct."""
self.output_order = None # output order
self.input = None # opt_var_name for prev. node
self.extra_input = list() # extra_input(s) at construct method args
self.output = None # opt_var_name for next node
self.extra_output = list() # extra_output(s)
self.extra_comment = None # comments for this code line / block.
self.code_line_list = list() # list of code line, a item is a line.
self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module

self.formal_args_collections = None

if isinstance(struct, NodeStruct):
self.output_order = struct.topo_idx
if isinstance(struct, ModuleStruct):
self.output_order = struct.head_nd_struct_index
self._generate_from_module_struct(struct, repeated_submodules)

def _add_line(self, s):
"""Add line of code."""
self.code_line_list.append(s)

@property
def new_line(self):
"""Return last generated line."""
try:
return self.code_line_list[-1]
except IndexError:
return ""

@new_line.setter
def new_line(self, s):
"""Make a new line."""
self._add_line(s)

def _generate_from_module_struct(self, md_struct, repeated_submodules):
"""
Generate the code of current Module Struct, collecting data from submodules.

Args:
md_struct (ModuleStruct): The ModuleStruct which generates codes.
repeated_submodules (dict): The dict contains all submodules which use repeatedly.
Can get this dict from generator.
"""
# Define tmp var for code generation.
opt_var_name_records = dict() # now only support multiple outputs within same scope.
return_value_records = dict() # save returned values for successor nodes/modules use.
# Define Module header code line below
if md_struct.pattern_id != -1:
class_name = f"Module{md_struct.pattern_id}"
else:
class_name = "Model"
# define a class declaration
self.new_line = f"class {class_name}(nn.Cell):"

# Get all formal args from nodes
module_def_args = ['self']
if md_struct.args_translator.actual_args:
for actual in md_struct.args_translator.actual_args.keys():
module_def_args.append(actual)
if md_struct.args_translator.formal_args:
for formal in md_struct.args_translator.formal_args.keys():
module_def_args.append(formal)

# Collect extra inputs and outputs

# For code line in init & construct blocks
init_lines = list()
cons_lines = list()
for (_, struct) in md_struct.get_generate_order():
if isinstance(struct, NodeStruct): # Generate code line for Node.
code_line_init = struct.code_line_in_init()
code_line_construct = struct.code_line_in_construct(in_module_returns=return_value_records)
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")
# add extra tensor
if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor:
code_extra_tensor = struct.add_extra_tensor()
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_extra_tensor)}")

# record opt_var_name for succ nodes input in same scope.
target_onnx_name = struct.graph_node_ref.successor_nodes
for name in target_onnx_name:
if opt_var_name_records.get(name):
opt_var_name_records.get(name).append(code_line_construct[0])
else:
opt_var_name_records[name] = [code_line_construct[0]]

if struct.successor_nodes_names_external:
for ret_user in struct.successor_nodes_names_external:
if return_value_records.get(ret_user) is not None:
return_value_records[ret_user].append((struct.onnx_name, code_line_construct[0]))
else:
return_value_records[ret_user] = [(struct.onnx_name, code_line_construct[0])]

elif isinstance(struct, ModuleStruct):
# check if this instance generated CodeStruct
if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None:
CodeStruct(struct, repeated_submodules)

code_line_init = struct.code_line_in_init()
code_line_construct = struct.code_line_in_construct(inputs=struct.matched_inputs)
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")

# record opt_var_name for succ nodes input in same scope.
target_onnx_name = struct.tail_nd_struct.graph_node_ref.successor_nodes
for name in target_onnx_name:
if opt_var_name_records.get(name):
opt_var_name_records.get(name).append(code_line_construct[0])
else:
opt_var_name_records[name] = [code_line_construct[0]]

# record submodule's local return map for following nodes / submodules use
if struct.external_successor_local_returns_map:
for ret_user, _ in struct.external_successor_local_returns_map.items():
if return_value_records.get(ret_user) is not None:
# mulitple returns of a node may need modifiy the index.
return_value_records[ret_user].append((struct.identifier, code_line_construct[0]))
else:
return_value_records[ret_user] = [(struct.identifier, code_line_construct[0])]
else:
raise TypeError("Unable to generate code from args are not ModuleStruct or NodeStruct.")

# define header of init block
self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):"
self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()"
# add init code lines to code line list.
self.code_line_list += init_lines
self.new_line = f"{NEW_LINE * 2}"

# define header of construct block
inputs = ['self'] + list(md_struct.construct_header_x.keys())
self.new_line = f"{FIRST_LEVEL_INDENT}def construct({', '.join(inputs)}):"
# add construct code lines to code line list.
self.code_line_list += cons_lines
# define returns
returns = []
if md_struct.external_successor_local_returns_map:
ret = list(md_struct.external_successor_local_returns_map.values())
for r in ret:
if isinstance(r, tuple): # results return with index nth output
returns.append(r[0])
else:
returns.append(r)
returns = list(set(returns))
else:
returns = [code_line_construct[0]]
self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}"
self.new_line = f"{NEW_LINE * 2}"
self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self


class Generator(metaclass=Singleton):
"""The generator controls all routines of code generation."""

def __init__(self):
"""Init the generator."""
# define basic attributes
self.framework = None

# define MUST have params
self._node_struct_collections = OrderedDict()
self._module_struct_collections = OrderedDict()
self._module_depth_max = 0
self._module_depth_min = 0

# define intermediate var. during conversion
self._module_map = OrderedDict()
self._global_context = GlobalContext()
self._global_context.node_struct_collections = self._node_struct_collections
self._repeated_submodules = set()

def _form_bottom_submodule(self):
"""Form the basic submodules, which only contains nodes."""
# Form module map
curr_scope_path = None
nd_struct_list_in_submodule = []
for nd_struct in self.node_structs.values():
idx = nd_struct.topo_idx
if curr_scope_path is None:
curr_scope_path = nd_struct.scope.path
nd_struct_list_in_submodule.append((idx, nd_struct))
elif curr_scope_path == nd_struct.scope.path:
nd_struct_list_in_submodule.append((idx, nd_struct))
else: # curr_scope_path changed
# save this submodule
if self._module_map.get(str(curr_scope_path)) is not None:
self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule
else:
self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule

# create a new one
curr_scope_path = nd_struct.scope.path
nd_struct_list_in_submodule = [(idx, nd_struct)]

# save last submodule
if self._module_map.get(str(curr_scope_path)) is not None:
self._module_map[str(curr_scope_path)] += nd_struct_list_in_submodule
else:
self._module_map[str(curr_scope_path)] = nd_struct_list_in_submodule

# Form bottom modules' ModuleStruct
for scope_path_str, nd_struct_list in self._module_map.items():
self._module_struct_collections[scope_path_str] = ModuleStruct(nd_struct_list)

def _list_repeated_submodules(self) -> OrderedDict:
"""
Return the repeated submodules by its depth and num.
For example, "Model/Module3_3" will return {1:(3)}

Return:
OrderedDict, a dict contains collections of repeated submodules.
"""
ret = OrderedDict()
for depth_control in range(self._module_depth_max, 0, -1):
repeated_submodules_at_this_depth = set()
for scope_path in self._module_map.keys():
path = Scope.path_str_to_list(scope_path)
if len(path) < depth_control:
continue
else: # depth control within path length
module_num = path[depth_control - 1][0]
repeated_submodules_at_this_depth.add(module_num)
ret[depth_control] = repeated_submodules_at_this_depth

self._repeated_submodules = ret
return ret

def _compare_with_base_parameters(self, nd_struct_list):
"""
Compare the parameter to check if it should be a formal args.

Args:
nd_struct_list (list): A list of NodeStructs which contains
all same nodes in repeated submodules.

Return:
set, a set of all formal args in this node.
"""

formal_args = set()
if len(nd_struct_list) < 2:
return formal_args
(_, base_nd_struct) = nd_struct_list[0]
for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param
for (_, nd_struct) in nd_struct_list[1:]:
compared_value = nd_struct.fragment.actual_args.get(base_parameter)
if compared_value == base_value:
continue
else:
formal_args.add(base_parameter)
break

return formal_args

def _list_formal_parameters_in_a_module(self, module_filter_return):
"""
Find all formal args / params from nodes in a module.

Args:
module_filter_return (dict): The filtered results from the module_map_filter.

Return:
list, a list of sets or None indicates all formal args of each node in the module in order.
"""
formal_params_list = list()
transposed = [list(e) for e in zip(*module_filter_return)]
for operation in transposed:
formal_parameters = self._compare_with_base_parameters(operation)
if formal_parameters:
formal_params_list.append(formal_parameters)
else:
formal_params_list.append(None)
return formal_params_list

def _list_formal_parameters(self, repeated_submodules) -> dict:
"""
Return a list of formal parameters in each submodule.

Args:
repeated_submodules (dict): A dict which contains repeated submodules,
acquire this dict from _list_repeated_submodules()

Return:
OrderedDict, a dict with each submodule's formal args.

Example:
A return for ResNet50 could be:

{0: # submoodule 0
[set('stride', 'in_channels', 'out_channels'), # args of the first node to be set as formal
set('num_features'), # args of the second node to be set as formal
None, # args of third node to be set as formal, which does not have
set('in_channels', 'out_channels'),
set('num_features'),
None
]},
{3: # submodule 3
[...],
{5: # submodule 5
[]} # empty returns means no nodes or it's a parent module of submodules.
}
"""
formal_args_in_each_submodule = OrderedDict()
checked_module = set()
# filter module_map by submodule_num (without depth)
for _, module_nums in repeated_submodules.items():
for module_num in module_nums:
if module_num in checked_module: # module already checked
continue
else:
checked_module.add(module_num)
map_filtered = self.module_map_filter(module_num=module_num)
formal_args_in_this_module = self._list_formal_parameters_in_a_module(map_filtered)
formal_args_in_each_submodule[module_num] = formal_args_in_this_module
return formal_args_in_each_submodule

def _add_submodule_to_parent(self):
"""
Recursively add all submodule to its parent module until Main module.

Note:
This function deepcopy the first node of the submodule, and reset its params as parent module.
"""
depth = self._module_depth_max
while depth > 0:
for (scope_path_str, md_struct) in self.module_structs.copy().items():
if scope_path_str == '[]':
continue # is main module, skip
if md_struct.scope_depth != depth:
continue # skip all submodules not at current depth
md_struct_scope = copy.deepcopy(md_struct.identifier)
md_struct_scope.pop()
parent_scope = md_struct_scope
# 1. check if this module has parent module
parent_md_struct = self.module_structs.get(str(parent_scope))
if parent_md_struct is not None:
# 1A. has parent, directly add md_struct to its parent ModuleStruct.
parent_md_struct.add_submodule(md_struct)
self.module_structs[str(parent_scope)] = parent_md_struct
else:
# 1B. not has parent, generate a new ModuleStruct
parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module
# rewrite parent md struct
parent_md_struct.reset_as_parent()
parent_md_struct.add_submodule(md_struct)
self.module_structs[str(parent_scope)] = parent_md_struct
sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections
self._global_context.add_module_struct(sub.pattern_id, sub)
depth -= 1

def _recursive_form_module(self):
"""Main routine in generator to build modules from bottom to top."""
# 1. List repeated submodules
repeated_submodules = self._list_repeated_submodules()
# 2. List reused parameters
formal_parameters = self._list_formal_parameters(repeated_submodules)
# 3. Build base subdmodules and set in/ext params translation
for module_struct in self.module_structs.values():
if module_struct.pattern_id == -1: # is main module
continue
formal_args = formal_parameters.get(module_struct.pattern_id)
module_struct.update_args_translation_list(formal_args)

# 4. Form parent modules
md_collection_len = len(self.module_structs.keys())
len_changes = True
while len_changes:
self._add_submodule_to_parent()
new_len = len(self.module_structs.keys())
if md_collection_len != new_len:
md_collection_len = new_len
else:
len_changes = False

# 5. Update all translated args from module map
self._update_all_modules_args_translator()

# 6. Update all nodes and moudles input/output
self.module_structs.get('[]').allocate_construct_header_x()
self.module_structs.get('[]').collect_returns()

def _update_all_modules_args_translator(self):
"""Update all modules' args translators."""
done_submodule = set()
for depth in range(self._module_depth_max, 0, -1):
# check modules from bottom to top
repeated_submodules = copy.deepcopy(self._repeated_submodules)
repeated_modules = repeated_submodules.get(depth)
if depth is None:
continue
for pattern_id in repeated_modules:
if pattern_id in done_submodule:
continue
# get all md_structs by same pattern
md_list = self._global_context.module_structs.get(pattern_id)
self._take_formal_args_from_updated_submodules(md_list)
args_translators = self.get_args_translator_from_module_structs_list(md_list)
formal_args_list = ArgsTranslationHelper.find_formal_args_in_modules(args_translators)
changed_args_translators = self.get_args_translator_from_module_structs_list(
md_list, exclude_root_son=True)
ArgsTranslationHelper.change_args_to_formal_for_all_translators(
formal_args_list, changed_args_translators)
done_submodule.add(pattern_id)

def _take_formal_args_from_updated_submodules(self, md_list):
"""
Take formal args from provided modules' nodes and submodules.

Args:
md_list (list): A list of ModuleStruct.
"""
if isinstance(md_list, ModuleStruct):
md_list = [md_list]

for md in md_list:
md.args_translator.take_formal_args_from_nodes_and_submodules(md.get_all_sub_translators())

def _update_module_depth_max(self, nd_struct: NodeStruct):
"""
Update the Generator attribute module_depth_max, which is the maximum depth in the Model.

Args:
nd_struct (NodeStruct): NodeStruct to be checked its depth.
"""
depth = nd_struct.scope.depth
if isinstance(depth, int):
if depth > self._module_depth_max:
self._module_depth_max = depth
else:
raise TypeError("Unable to update global depth due to TypeError in NodeStruct.scope.depth")

def add_node(self, node_identifier, node_instance=None, node_fragment=None, mapper_dict=None):
"""
Add Node information to the generator.

Args:
node_identifier (str): The unique identifier for the node passed in.
node_instance (GraphNode): The GraphNode instance of each node.
node_fragment (NodeFragment): The NodeFragment instance of this node passed in.
mapper_dict (dict): The dict contains converted attributes from mapper.
"""

if node_identifier is None:
raise ValueError("Node Identifier should not be None.")
self._global_context.node_fragments[node_identifier] = node_fragment
args = []
if node_instance is not None:
args.append(node_instance)
if mapper_dict is not None:
args.append(mapper_dict)
if node_fragment is not None:
args.append(node_fragment)

nd_struct = self.node_structs.get(node_identifier)
if nd_struct: # NodeStruct already exists
nd_struct.update(args)
else: # create new Node Struct
nd_struct = NodeStruct(args)
nd_struct.identifier = node_identifier
self._update_module_depth_max(nd_struct)
self.node_structs[node_identifier] = nd_struct

@property
def node_structs(self):
"""Return all NodeStructs in this model."""
return self._node_struct_collections

@property
def module_structs(self):
"""Return all ModuleStructs in this model."""
return self._module_struct_collections

def generate(self):
"""
Generate the final script file.

Returns:
list, a list of each line in script file.
"""
self._form_bottom_submodule()
self._recursive_form_module()
code = CodeStruct(self.module_structs.get('[]'), self._repeated_submodules)
return code.code_line_list

def get_node_struct(self, node_identifier):
"""
Get specific NodeStruct by node_identifier.

Args:
node_identifier (str): The node unique identifier.

Return:
NodeStruct, the node's NodeStruct.
"""
return self._node_struct_collections.get(node_identifier, None)

def get_module_struct(self, module_identifier):
"""
Get specific ModuleStruct by module_identifier.

Args:
module_identifier (str): The module unique identifier.

Return:
ModuleStruct, the node's ModuleStruct.
"""
return self._module_struct_collections.get(module_identifier, None)

def get_module_structs_by_pattern_under_same_parent_pattern(self, pattern_id, under_parent_pattern_id):
"""
Return a list of ModuleStruct by conditions of pattern and their parent parent's pattern.

Args:
pattern_id (int): The pattern id the returned ModuleSturct is.
under_parent_pattern_id (int): The pattern id the returned ModuleStruct's parent is.

Returns:
list, a list of MoudleStructs has the same pattern_id and the same parents' pattern.
"""
if not pattern_id:
raise ValueError("pattern_id is necessary to get the module struct.")
if not under_parent_pattern_id:
raise ValueError("under_parent_pattern_id is necessary to get the module struct.")
ret = []
md_list = self._global_context.module_structs.get(pattern_id)
for md in md_list:
if md.parent_id == under_parent_pattern_id:
ret.append(md)
return ret

def get_args_translator_from_module_structs_list(self, md_list, exclude_root_son=False):
"""
Return a list of args translators which belongs to given module structs.

Args:
md_list (list): A list of ModuleStruct.
exclude_root_son (Bool): If the returned result should include args translator belongs to
modules under the Main module.

Returns:
list, a list of args translators which belongs to given module structs.
"""
ret = []
for md in md_list:
if exclude_root_son and md.parent_id == -1:
continue
if md.args_translator is not None:
ret.append(md.args_translator)

return ret

def module_map_filter(self, depth=None, module_num=None, uid=None):
"""
Filter the module map by given conditions.

Args:
depth (int): Scope depth.
module_num (int): The submodule number.
uid (int): The unique identifier of a submodule.

Return:
list, list of NodeStruct list of each submodule.
"""
ret = list()
for scope_path, nd_struct_list in self._module_map.items():
path = Scope.path_str_to_list(scope_path)
if not path: # skip main
continue

# if depth not equals to the indicated depth, skip
if depth is not None and len(path) != depth:
continue

scope_at_depth = path[-1]
(m_num, m_uid) = scope_at_depth
if uid is not None:
if m_num == module_num and m_uid == uid:
ret.append(nd_struct_list)
else:
if m_num == module_num:
ret.append(nd_struct_list)
return ret

+ 710
- 0
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -0,0 +1,710 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define a struct for module converted and save all required information here."""

from collections import OrderedDict

from .node_struct import NodeStruct
from .scope_utils import Scope
from ..common.utils import get_dict_key_by_value
from .args_translator import ArgsTranslation
from ..common.code_fragment import ModuleFragment
from ..common.global_context import GlobalContext
from ..hierarchical_tree.name_mgr import LocalVarNameMgr


class ModuleStruct:
"""
Define a module struct which stores all info. to generate statement.

Args:
args (list): A list of node structs.
"""
GLOBAL_CONTEXT_MGR = GlobalContext()

def __init__(self, nd_struct_list):
"""Init. a module by NodeStructs."""
self.pattern_id = -1 # pattern num, -1 as Main module
self.pattern_uid = -1 # unique module id for this pattern
self.parent_id = None # parent's pattern num
self.parent_uid = None # parent's pattern module unique id
self.initialized = False
self.identifier = None
self.module_name = None
self.scope_depth = None
self.head_nd_struct = None
self.head_nd_struct_index = None
self.tail_nd_struct = None
self.tail_nd_struct_index = None
self._node_structs = list()
self._module_structs = list()

self._fragment = None
self._args_translator = None
self._setting = None
self._parent_module_struct = None
# only store original formal args name, not global
self._nodes_structs_formal_args_list = list()
# only store translated (globalized) formal args
self._nodes_structs_formal_args_translated_list = list()

# define other settings here
self._node_args_translation_list = list()
self._var_name_mgr = LocalVarNameMgr()
self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name
self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct
self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name

# key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name)
self.outputs_collection = dict()
self.matched_inputs = list() # Matched inputs will can be directly used by code line generation

# key is ext. succ node onnx name, value is local opt_var
self.external_successor_local_returns_map = OrderedDict()

# key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level
self.outputs_collection = dict()

# start initialization
if not self.initialized:
self._init_module(nd_struct_list)
else:
self._update_module(nd_struct_list)

# assign this module reference to node
for (_, nd_struct) in nd_struct_list:
nd_struct.parent_module_struct = self

def reset_as_parent(self):
"""
Reset all attributes and filled as a parent module of this module.

Note:
This function must be called only after a deepcopy of this instance!
"""
self.identifier.pop()
self.scope_depth = self.scope_depth - 1
self._set_pattern_id()
self._find_parent_module()
self.module_name = Scope.scope_to_module_name(self.identifier)
self._node_structs = list()
self._module_structs = list()
self._fragment = None
self._args_translator = None
self.init_args_translator()
self._setting = None
self._parent_module_struct = None
self._nodes_structs_formal_args_list = list()

self._node_args_translation_list = list()

def _set_pattern_id(self):
"""Set pattern id which matches the module fragment pattern."""
if not self.initialized:
return
if self.scope_depth < 1:
self.pattern_id = -1
self.pattern_uid = -1
return
self.pattern_id = self.identifier[-1][0]
self.pattern_uid = self.identifier[-1][1]

def _init_module(self, nd_struct_list):
"""Init this ModuleStruct by a list of Nodes."""
(nd_topo_idx, nd_struct) = nd_struct_list[0]
self.identifier = nd_struct.scope.path
self.module_name = nd_struct.scope.to_str
self.scope_depth = nd_struct.scope.depth
self.head_nd_struct = nd_struct
self.head_nd_struct_index = nd_topo_idx
self.tail_nd_struct = nd_struct_list[-1][1]
self.tail_nd_struct_index = nd_struct_list[-1][0]
self._node_structs = nd_struct_list
self.initialized = True
self._set_pattern_id()
self._find_parent_module()
self.init_args_translator()

def _update_module(self, nd_struct_list):
"""Update the ModuleStruct attributes from a list of Nodes."""
(nd_topo_idx_head, nd_struct_head) = nd_struct_list[0]
(nd_topo_idx_tail, nd_struct_tail) = nd_struct_list[-1]
if self.identifier != nd_struct_head.scope.path:
raise ValueError("Unable to update this module struct {} due to different identifier {}".format(
self.identifier, nd_struct_head.scope.path))
if nd_topo_idx_head < self.head_nd_struct_index:
self.head_nd_struct_index = nd_topo_idx_head
self.head_nd_struct = nd_struct_head
if nd_topo_idx_tail > self.tail_nd_struct_index:
self.tail_nd_struct_index = nd_topo_idx_tail
self.tail_nd_struct = nd_struct_tail
self._node_structs += nd_struct_list

def _find_parent_module(self):
"""Set the parent's module pattern and uid."""
if not self.initialized:
return
if self.scope_depth == 0: # is Main Module
pass
elif self.scope_depth == 1: # parent pattern is Main module
self.parent_id = -1
self.parent_uid = -1
else: # this is a submodule in a module
(self.parent_id, self.parent_uid) = Scope.get_parent_module_num_and_uid(
self.identifier)

def __repr__(self):
return str({
"address": hex(id(self)),
"identifier": self.identifier,
"parent": (self.parent_id, self.parent_uid),
"name": self.module_name,
"pattern": self.pattern_id,
"scope_depth": self.scope_depth,
"nd_idx_range": "{} -> {}".format(self.head_nd_struct_index, self.tail_nd_struct_index),
"initialized": self.initialized
})

def init_module_fragment(self):
"""Init the module fragment."""
if not self.initialized:
return
# check if fragment exists in global context
op = "Module{}".format(self.pattern_id)
if op == "Module-1": # reset as Main Model's op name
op = "Model"
frag = GlobalContext().get_module_fragment(op)
if frag is not None: # use exists fragment
self._fragment = frag
else:
frag = ModuleFragment(operation=op,
actual_args=None,
input_shape=None,
output_shape=None,
settings=None)
self._fragment = frag
# set fragment pattern
self._fragment.pattern = self._node_structs
GlobalContext().add_module_fragment(op, frag)

def init_args_translator(self):
"""Initialize the Args Translator for the module."""
var_name = "Module{}_{}".format(self.pattern_id, self.pattern_uid)
self._args_translator = ArgsTranslation(None, var_name, None)

def update_module_fragment(self):
"""Update this module's fragment."""
if self._fragment is None:
return

# update input output shape
self._fragment.input_shape = self.head_nd_struct.fragment.input_shape
self._fragment.output_shape = self.tail_nd_struct.fragment.output_shape

# update formal args
self._fragment.formal_args.update(self._args_translator.formal_args)
self._fragment.formal_args_value.update(self._args_translator.formal_args_values)
# update actual args
self._fragment.actual_args.update(self._args_translator.actual_args)
# update others..

def add_submodule(self, md_structs):
"""
Add another module struct(s) to this ModuleStruct.

Args:
md_structs ([ModuleStruct, list]): a (list) ModuleStruct to be added in this ModuleStruct.
"""
tail_md = md_structs
if isinstance(md_structs, ModuleStruct):
md_structs.args_translator.take_formal_args_from_nodes_and_submodules(md_structs.get_all_sub_translators())
self._module_structs.append(md_structs)
md_structs.parent_module_struct = self
elif isinstance(md_structs, list):
for md_s in md_structs:
md_s.args_translator.take_formal_args_from_nodes_and_submodules(md_s.get_all_sub_translators())
md_s.parent_module_struct = self
self._module_structs += md_structs
tail_md = md_structs[-1]
else:
raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format(
type(md_structs)))
# update tail node and index
if self.tail_nd_struct_index < tail_md.tail_nd_struct_index:
self.tail_nd_struct = tail_md.tail_nd_struct
self.tail_nd_struct_index = tail_md.tail_nd_struct_index

def _update_formal_args_for_all_nd_structs(self):
"""
Init nodes' args translator and find formal args.
And collect nodes' formal args.
"""
if len(self._node_args_translation_list) != len(self._node_structs):
raise ValueError(
"ModuleStruct cannot update nodes' formal args due to length inconsistent.")
for idx, (_, nd_struct) in enumerate(self._node_structs):
formal_arg_of_this_node = self._node_args_translation_list[idx]
# update var_name to ensure all node names' are unique in a module.
nd_struct.update_var_name(idx)
nd_struct.init_args_translator(formal_arg_of_this_node)
if nd_struct.args_translator is not None:
self._nodes_structs_formal_args_list.append(
nd_struct.args_translator.formal_args_values)
else:
self._nodes_structs_formal_args_list.append(None)

def update_args_translation_list(self, formal_args):
"""
Receive a list of args name to be changed to formal args, and change them.

Args:
formal_args (list[str]): a list of args name to be changed to formal args.
"""
self._node_args_translation_list = formal_args
self._update_formal_args_for_all_nd_structs()

def get_all_sub_translators(self):
"""
Return a list of args_translators of submodules / nodes affiliated to this module.

Note:
The order of returned list is followed by the actual topological order.

Returns:
list, a list of args_translators.
"""
ret = []
for (_, struct) in self.get_generate_order():
if struct.args_translator is not None:
ret.append(struct.args_translator)
return ret

def get_generate_order(self):
"""
Return the order of generated code by index.

Return:
list, a list of reference of node_struct or module_struct.
"""
ret = list()
if not self._module_structs:
return self._node_structs
# Generate a list of tuple (idx, module_structs)
for md_struct in self._module_structs:
ret.append((md_struct.head_nd_struct_index, md_struct))
if self.node_structs:
ret += self.node_structs
ret.sort(key=lambda x: x[0])
return ret

def code_line_in_init(self):
"""
Initialization line of code in module init block.

Args:
override_formal_val (dict): Indicate which args should be renamed for passing value from upper level.
"""
left = "self.{}".format(self.ms_var_name)
args_list = list()
# Load args in init statement.
if self._args_translator is not None: # from args_translator
if self._args_translator.actual_args: # load actual args
args_list += self._args_translator.actual_args_to_str_list
elif self._args_translator.actual_args_backup and self.parent_id == -1:
# For modules repeated in multiple levels, the module under main model should
# not use formal args as it is unnecessary -> load from actual args backup
args_list += self._args_translator.actual_args_backup_to_str_list
args_list += self._args_translator.formal_args_to_str_list # load from formal args
else:
args_list += self._fragment.actual_args
right = f"{self.class_name}({', '.join(args_list)})"
return (left, right)

def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block."""
# check number of outputs this module has
opt_var_name_in_module = list(self.external_successor_local_returns_map.values())
num_output = len(set(opt_var_name_in_module))
if num_output == 1: # single output
left = f"{self.ms_opt_var_name}"
else:
left = [f"{self.ms_opt_var_name}_{num}" for num in range(num_output)]

if inputs is None and self.matched_inputs:
inputs = self.matched_inputs

if isinstance(inputs, str):
inputs = [inputs]
right = f"self.{self.ms_var_name}({', '.join(inputs)})"
return (left, right)

@property
def node_structs(self):
"""Return all node structs in this module."""
return self._node_structs

@property
def module_structs(self):
"""Return all module structs in this module."""
return self._module_structs

@property
def parent_module_struct(self):
"""Return this module's parent module struct."""
return self._parent_module_struct

@parent_module_struct.setter
def parent_module_struct(self, ref):
"""Set this modu;e's parent module struct."""
self._parent_module_struct = ref

@property
def args_translator(self):
"""Return the args translator."""
return self._args_translator

@property
def head_nd_struct_precursor_nodes_names(self) -> list:
"""Return head node's precursor nodes names."""
return self.head_nd_struct.precursor_nodes_names

@property
def head_nd_struct_precursor_nodes_structs(self) -> list:
"""Return head node's precursor nodes structs."""
return self.head_nd_struct.precursor_nodes_structs

@property
def tail_nd_struct_successor_nodes_names(self) -> list:
"""Return tail node's successor nodes names."""
return self.tail_nd_struct.successor_nodes_names

@property
def tail_nd_struct_successor_nodes_structs(self) -> list:
"""Return tail node's successor nodes structs."""
return self.tail_nd_struct.successor_nodes_structs

@property
def onnx_names_from_nodes(self) -> list:
"""Return all nodes onnx names in this module."""
ret = []
for (_, node) in self.node_structs:
ret.append(node.onnx_name)
return ret

@property
def onnx_names_from_submodules(self) -> list:
"""Return all nodes onnx names in submodules of this module."""
ret = []
for md_struct in self.module_structs:
ret += md_struct.onnx_names
return ret

@property
def onnx_names(self) -> list:
"""Return all nodes' onnx names which contained by this module."""
return self.onnx_names_from_nodes + self.onnx_names_from_submodules

@property
def external_precursor_nodes_names(self) -> list:
"""Return all precursors nodes names not in this module."""
ret = []
for _, struct in self.get_generate_order():
if isinstance(struct, NodeStruct):
precursor_nodes_names = struct.precursor_nodes_names

if isinstance(struct, ModuleStruct):
precursor_nodes_names = struct.external_precursor_nodes_names

for p_name in precursor_nodes_names:
if p_name in self.onnx_names:
continue
ret.append(p_name)
return ret

@property
def external_successor_nodes_names(self) -> list:
"""Return all precursors nodes names not in this module."""
ret = []
for _, struct in self.get_generate_order():
if isinstance(struct, NodeStruct):
successor_nodes_names = struct.successor_nodes_names

if isinstance(struct, ModuleStruct):
successor_nodes_names = struct.external_successor_nodes_names

for s_name in successor_nodes_names:
if s_name in self.onnx_names:
continue
ret.append(s_name)
return ret

@property
def class_name(self) -> str:
"""Return the class name for generating code of this module."""
if self.pattern_id == -1:
return "Model"
return "Module{}".format(self.pattern_id)

@property
def ms_var_name(self) -> str:
"""Return the variable name for generated code statement of this module."""
if self.pattern_id == -1:
return "Model"
return "Module{}_{}".format(self.pattern_id, self.pattern_uid).lower()

@property
def ms_opt_var_name(self) -> str:
"""Return the variable name for generated code statement of the output of this module."""
return "{}_opt".format(self.ms_var_name).lower()

# The following part will be resetting nodes' external inputs for supporting multi-in/out
# and should be called after generator.recursive_form_modules()

def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name):
"""
Mark the registered external inputs for code generation.

Note:
This function to be called by its parent (ModuleStruct).

Args:
header_x (str): The `x` in module construct header.
onnx_precursor_node_name (str): The original onnx node name.
"""
if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None:
raise ValueError("The input from {} has already registered. Check this Module \
{} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier))
self.inputs_in_construct_header[onnx_precursor_node_name] = header_x

def allocate_construct_header_x(self, force_x=None):
"""
Allocate the x in construct header for each external input.

Args:
force_x (str): Force the arg name to customized.
"""
local_x_name = 'x'
if force_x: # name of x indicated by external
local_x_name = force_x

# set construct_header_x for current module
allocated = set()
for prec_name in self.external_precursor_nodes_names:
if prec_name in allocated:
continue
x_name_in_construct_header = self._var_name_mgr.get_name(local_x_name)
self.construct_header_x[x_name_in_construct_header] = prec_name
allocated.add(prec_name)

# Assign these inputs to nodes and submodules
for _, struct in self.get_generate_order():
if isinstance(struct, NodeStruct): # register node's ext input
self.reset_node_external_input_to_local(struct)
self.register_node_output_to_module(struct)
if isinstance(struct, ModuleStruct): # reg module's ext input
if not struct.construct_header_x:
struct.allocate_construct_header_x()
self.reset_submodule_external_input_to_local(struct)
self.register_submodule_output_to_module(struct)

# remove parent module's ext. map if ext nodes in this module (no need return)
for user_name in self.external_successor_local_returns_map.copy().keys():
if user_name in self.onnx_names:
self.external_successor_local_returns_map.pop(user_name)

def _match_node_inputs(self, struct):
"""Match node's inputs with its precursor nodes."""
for output_provider in struct.precursor_nodes_names:
output_list = self.outputs_collection.get(output_provider)
if output_list is None:
# not in this module, check construct header
for (self_x_name, self_output_provider) in self.construct_header_x.items():
if self_output_provider == output_provider:
struct.matched_inputs.append(self_x_name)
continue
for output in output_list:
(provider_succ, provider_closet_opt_var) = output
if provider_closet_opt_var in struct.matched_inputs:
continue # skip repeat
if provider_succ == struct.onnx_name:
struct.matched_inputs.append(provider_closet_opt_var)

def _match_sub_modules_inputs(self):
"""
Match current module's submodules' inputs with corresponding outputs registered in current module.

Description:
The function matches these inputs by the following steps:
1. For each submodule in the current module, take submodule's construct header
2. Check submodule's construct header element requires an input from current module's
construct header or outputs from other submodules.
3. If from current module's construct header, assign corresponding x to the submodule.
If from other submodules, assign required submodule output name to the submodule.
"""
if not self.outputs_collection:
return # skip first node
for (_, struct) in self.get_generate_order():
if isinstance(struct, NodeStruct):
self._match_node_inputs(struct)
continue # skip node
sub_construct_header = struct.construct_header_x
for (_, output_provider) in sub_construct_header.items():
# check from outputs collection
output_list = self.outputs_collection.get(output_provider)
if output_list is None:
# not in this module, need from current module construct header
for (self_x_name, self_output_provider) in self.construct_header_x.items():
if self_output_provider == output_provider:
struct.matched_inputs.append(self_x_name)
continue
for output in output_list:
(provider_succ, provider_closet_opt_var) = output
if provider_closet_opt_var in struct.matched_inputs:
continue # skip repeat
if provider_succ in struct.onnx_names:
struct.matched_inputs.append(provider_closet_opt_var)

def _append_to_outputs_collection(self, provider_name, val):
"""
Helper function to add a nodes or submodules outputs to current module return statement.

Args:
provider_name (str): The onnx name of the output provider.
val (list[tuple]): A list of tuple which contains
the output provider's successor name and its opt_var_name.
"""
exist_output = self.outputs_collection.get(provider_name)
if isinstance(val, tuple):
val = [val]
if exist_output is None: # add new entry
exist_output = list()
exist_output += (val)
self.outputs_collection[provider_name] = exist_output

def collect_returns(self):
"""
Collect all nodes and submodules' returns in the module.

Note:
The logic is to collect the return from nodes and submodules by the order
of topological index.

For returns from a node, it will check if the return will be used externally.
If external (external means the successor a.k.a the return user has different scope with the node),
add this return to current module's outputs_collection, where
key is this node's original onnx_name, and value is a list of
tuple(successor_name, this node's opt_var_name)

For returns from a submodule, it will check if the submodule has already collected returns,
If not, do it and then continue the following procedures.
Now we will check each element in submodule's outputs_collection. Note that we DO NOT check submodule's
returns should be continued returning, but just return them.
All these returns from submodules will be changes their original nodes output (a.k.a outputs provider)
`opt_var_name` to submodules' `opt_var_name`.

Finally, we match the outputs and inputs in the current module level.
"""
for (_, struct) in self.get_generate_order():
if isinstance(struct, NodeStruct):
outputs_list = []
# add these successor nodes name to collection for future use
for succ in struct.successor_nodes_names:
outputs_list.append((succ, struct.ms_opt_var_name))
if outputs_list:
self._append_to_outputs_collection(struct.onnx_name, outputs_list)
if isinstance(struct, ModuleStruct):
# Remove unnecessary returns, succ are all inside current
if not struct.outputs_collection:
struct.collect_returns()
sub_outputs_collection = struct.outputs_collection
# check each returns in sub
for provider_name, outputs_list in sub_outputs_collection.items():
for output in outputs_list:
(succ, _) = output # (succ, provider_opt_var_name) in output
new_output = (succ, struct.ms_opt_var_name)
self._append_to_outputs_collection(provider_name, new_output)
self._match_sub_modules_inputs()

def get_returned_opt_var_name(self) -> list:
"""Return a list of returned output var of this module."""
idx = 0
added_to_return = set()
ret = []
for ext_successor_requested, opt_var_name_in_this_module in self.external_successor_local_returns_map.items():
if ext_successor_requested in added_to_return:
continue
ret.append((ext_successor_requested, opt_var_name_in_this_module, idx))
added_to_return.add(ext_successor_requested)
return ret

def reset_node_external_input_to_local(self, nd_struct):
"""
Reset node's input to module's construct args
"""
for prec_node_name in nd_struct.precursor_nodes_names_external:
if prec_node_name in self.onnx_names: # prec node in current module's.
continue
if prec_node_name in self.construct_header_x.values():
# prec node assigned to construct header to passed in.
local_x = get_dict_key_by_value(prec_node_name, self.construct_header_x)
nd_struct.set_inputs_in_construct_header(local_x, prec_node_name)
else: # Extra precursor nodes, raise error
raise ValueError("Found external inputs of the Node but the module does not have it.")

def reset_submodule_external_input_to_local(self, md_struct):
"""
Reset submodule's external input to current module.

Args:
md_struct (ModuleStruct): The submodule in the current module.
"""
# check submodule's input
for _, submodule_precursor in md_struct.construct_header_x.items():
if submodule_precursor in self.onnx_names: # if internal, match with local nodes/submodules return
# but do nothing here
continue
else: # if external, match with current module construct header x
if submodule_precursor in self.construct_header_x.values():
local_x = get_dict_key_by_value(submodule_precursor, self.construct_header_x)
md_struct.set_inputs_in_construct_header(local_x, submodule_precursor)
else: # Extra precursor nodes, raise error
raise ValueError("Found external inputs of the submodule but the module does not have it.")

def register_node_output_to_module(self, nd_struct):
"""Register nodes outputs to this module's return."""
for succ_node_name in nd_struct.successor_nodes_names_external:
self.external_successor_local_returns_map[succ_node_name] = nd_struct.ms_opt_var_name

def register_submodule_output_to_module(self, md_struct):
"""Register submodule outputs to this module's return."""
submodule_returns = md_struct.get_returned_opt_var_name()
submodule_opt_var_name = md_struct.ms_opt_var_name
for (submodule_ext_succ, opt_var_name_in_this_module, ith_output) in submodule_returns:
self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output)
# edit external succ 's inputs in parent module
ext_node = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(submodule_ext_succ)
ext_node_parent = ext_node.parent_module_struct
while ext_node_parent != self.parent_module_struct:
ext_node_parent.inputs_in_parent_module[ext_node.onnx_name] = md_struct.ms_opt_var_name
ext_node_parent = ext_node_parent.parent_module_struct

# need find the prec_name?
for ext_node_prec, opt_var_name in ext_node.inputs_in_parent_module.copy().items():
if isinstance(opt_var_name, str):
if opt_var_name == opt_var_name_in_this_module:
ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output)
if isinstance(opt_var_name, tuple):
if opt_var_name[0] == opt_var_name_in_this_module:
ext_node.inputs_in_parent_module[ext_node_prec] = (self.ms_opt_var_name, ith_output)

+ 423
- 0
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -0,0 +1,423 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define the NodeStruct which stores all info. of a node."""
from collections import OrderedDict

from .scope_utils import Scope
from .args_translator import ArgsTranslation
from ..common.code_fragment import CodeFragment
from ..third_party_graph.pytorch_graph_node import PyTorchGraphNode
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext
from ..constant import InputType


class NodeStruct:
"""
Define a node struct which stores all info. to generate statement.

Args:
args (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj.

Note:
You can pass as many args as possible and the Node Struct will update
by arguments order.
"""
GLOBAL_CONTEXT_MGR = GlobalContext()

def __init__(self, args):
# define attributes here
self._identifier = None
self._fragment = None
self._args_translator = None
self._parent_module_struct = None
self.topo_idx = None
self.node_type = None
self.onnx_name = None
self.onnx_op = None
self.graph_node_ref = None # Our defined GraphNode
self.scope_name = None
self.ms_var_name = None
self.ms_opt_var_name = None # ms_opt_var_name = self.ms_var_name(...)
self.ms_op = None
self.ready_to_generate = False

self.ms_params = dict() # converted params from mapper
self.ms_settings = dict()
self.ms_weights = dict()
self.ms_inputs = OrderedDict()

self.scope = None # Defined Scope class
self.inputs_in_construct_header = OrderedDict() # key is prec_node_name, value is x; For code line use
self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name
self.matched_inputs = list() # Matched inputs will can be directly used by code line generation

# initialize funcs.
for arg in args:
self.update(arg)

def __repr__(self):
return str({
"address": hex(id(self)),
"idx": self.topo_idx,
"identifier": self.identifier
})

def ori_topo_idx(self):
"""Get the original topological index in the onnx graph."""
ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/')
self.onnx_name = ori_name
return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name)

def update_var_name(self, idx=None):
"""
Update the var_name of each node.

Args:
idx (int): The index of the node in this module.
"""
if idx is not None:
self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx)
elif self.topo_idx is not None:
self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx)
else:
raise ValueError("Unable to update var name when topo_idx is None.")
self.ms_opt_var_name = self.ms_var_name + '_opt'

def _update_basics_from_gn(self, gn):
"""Update basic info from GraphNode."""
self.graph_node_ref = gn
self.scope_name = gn.scope_name

def _update_from_pytorch_gn(self, gn: PyTorchGraphNode):
"""Update basic info from PyTorchGraphNode."""
self.node_type = "PyTorchGraphNode"
self._update_basics_from_gn(gn)

def _update_from_onnx_gn(self, gn: OnnxGraphNode):
"""Update basic info from OnnxGraphNode."""
self.node_type = "OnnxGraphNode"
self._update_basics_from_gn(gn)

def _update_from_mapper(self, d):
"""Update info from mapper."""
if d.get('op_name'):
self.ms_op = d.get('op_name')
if d.get('params'):
self.ms_params = d.get('params')
if d.get('settings'):
self.ms_settings = d.get('settings')
if d.get('weights'):
self.ms_weights = d.get('weights')

def _update_from_fragment(self, frag: CodeFragment):
"""Update info from CodeFragment."""
self._fragment = frag
if frag.operation:
self.ms_op = frag.operation
idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count
self.update_var_name(idx=idx)

def _set_scope_from_identifier(self):
"""Set the Node scope from identifier."""
parsed_scope = Scope.parse_scope_from_node_identifier(self.identifier)
self.scope = Scope(parsed_scope)

def init_args_translator(self, translated_args: list):
"""
Initialize the ArgsTranslator for each Node.

Args:
translated_args (list): The list of args should be translated to formal args.
"""
if not self._fragment:
raise ValueError("Initialize argument translator failed.")
if self._fragment.actual_args and translated_args:
self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args)

def check_if_generate_ready(self):
"""Check if the NodeStruct is able to generate code."""
# check essential params exists
if all([self.identifier,
self.node_type,
self.scope_name,
self.ms_var_name,
self.ms_opt_var_name,
self.ms_op]):
self.ready_to_generate = True

def update(self, arg, force_ready=False):
"""
Pass Node info. to generator NodeStruct.

Args:
arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj.
force_ready (bool): Force this NodeStruct is ready to generate.
"""
if isinstance(arg, PyTorchGraphNode):
self._update_from_pytorch_gn(arg)
elif isinstance(arg, OnnxGraphNode):
self._update_from_onnx_gn(arg)
elif isinstance(arg, (dict, OrderedDict)):
self._update_from_mapper(arg)
elif isinstance(arg, CodeFragment):
self._update_from_fragment(arg)
else:
raise TypeError("NodeStruct received an unsupported initializing argument.")

if force_ready:
self.ready_to_generate = True
else:
self.check_if_generate_ready()

@property
def identifier(self):
"""Return the identifier of the node."""
return self._identifier

@identifier.setter
def identifier(self, s):
"""
Set the Node identifier, and update the scope.

Args:
s (str): The node identifier string.
"""
self._identifier = s
self._set_scope_from_identifier()
self.topo_idx = self.ori_topo_idx()
self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self

@property
def fragment(self):
"""Return the fragment of the node."""
return self._fragment

@fragment.setter
def fragment(self, frag):
"""
Set the Node fragment.

Args:
s (NodeFragment): The node identifier string.
"""
self._fragment = frag

@property
def graph_node(self):
"""Return the GraphNode reference."""
return self.graph_node_ref

@graph_node.setter
def graph_node(self, graphnode):
"""Set the GraphNode reference."""
self.graph_node_ref = graphnode

@property
def onnx_node(self):
"""Return the original onnx node reference."""
return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name)

@property
def args_translator(self):
"""Return the args translator of this Node."""
return self._args_translator

@property
def precursor_nodes_names(self) -> list:
"""Return the names of precursor nodes."""
return self.graph_node_ref.precursor_nodes

@property
def precursor_nodes_structs(self) -> list:
"""Return the node struct instances of precursor nodes."""
ret = []
precursor_nodes_names = self.precursor_nodes_names
for pre_node_name in precursor_nodes_names:
nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name)
ret.append(nd_struct)
return ret

@property
def successor_nodes_names(self) -> list:
"""Return the names of successor nodes."""
return self.graph_node_ref.successor_nodes

@property
def successor_nodes_structs(self) -> list:
"""Return the node struct instances of successor nodes."""
ret = []
for pre_node_name in self.successor_nodes_names:
nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name)
ret.append(nd_struct)
return ret

@property
def parent_module_struct(self):
"""Return the parent struct of this node."""
return self._parent_module_struct

@parent_module_struct.setter
def parent_module_struct(self, ref):
self._parent_module_struct = ref

# Code Generation funcs below

def code_line_in_init(self):
"""Initialization line of code in module init block."""
left = "self.{}".format(self.ms_var_name)
args_list = list()
if self._args_translator is not None:
args_list += self._args_translator.actual_args_to_str_list
args_list += self._args_translator.formal_args_to_str_list
else:
actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args)
args_list += actual_args_str
right = f"{self.ms_op}({', '.join(args_list)})"
return left, right

def _get_correct_in_module_returns(self, prec_node, in_module_return):
"""
Find the correct precursor node name in return statement of its parent module.

Args:
prec_node (str): The onnx name of the precursor node given.
in_module_return (list[tuple]): The list of outputs which contains parent module identifier
and module opt_var_name.

Return:
str, correct opt_var_name to be passed in current node.
"""
found_return = False
for ret in in_module_return:
(md_identifier, input_name_to_use) = ret
p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node)
# recursive check the p node parent
parent = p_node_struct
while not found_return:
parent = parent.parent_module_struct
if parent is None:
break
if parent.identifier == md_identifier:
return input_name_to_use
return None

def code_line_in_construct(self, inputs=None, in_module_returns=None):
"""Construct line of code in module construct block. """
left = self.ms_opt_var_name
if inputs is None:
inputs = []
for idx, prec_node in enumerate(self.precursor_nodes_names):
if self.inputs_in_construct_header.get(prec_node):
inputs.append(self.inputs_in_construct_header.get(prec_node))
elif self._check_target_node_internal(prec_node):
inputs.append(self.precursor_nodes_structs[idx].ms_opt_var_name)
elif self.inputs_in_parent_module.get(prec_node):
inputs.append(self.inputs_in_parent_module.get(prec_node))
elif in_module_returns and in_module_returns.get(self.onnx_name) \
and (not self._check_target_node_internal(prec_node)):
inputs.append(self._get_correct_in_module_returns(prec_node, in_module_returns.get(self.onnx_name)))
else:
inputs.append("unk_{}_{}".format(idx, prec_node))

if self.matched_inputs:
inputs = self.matched_inputs

# Check original onnx node's input to ensure double inputs are not ignored
original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name)
new_inputs = []
for idx, prec_node in enumerate(self.precursor_nodes_names):
occurence = original_inputs.count(prec_node)
for _ in range(occurence):
new_inputs.append(inputs[idx])
inputs = new_inputs

if isinstance(inputs, str):
inputs = [inputs]

if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value:
inputs = [str(tuple(inputs)).replace("\'", "")]

if self._fragment.code_setting and self._fragment.code_setting.op_extra_input:
for _, val in self._fragment.code_setting.op_extra_input.items():
inputs.append(str(val))

if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor:
inputs.append(f"self.{self.ms_var_name}_w")
right = f"self.{self.ms_var_name}({', '.join(inputs)})"
return left, right

def add_extra_tensor(self):
""" Add extra tensor."""
left = "self.{}_w".format(self.ms_var_name)
shape = self._fragment.code_setting.op_extra_tensor.shape
right = f"Tensor(np.random.uniform(0, 1, {shape}), mindspore.float32)"
return left, right

# The following functions are specified for multiple in/out support.
# and should be called only after generator._recursive_form_modules()

def set_inputs_in_construct_header(self, header_x, onnx_precursor_node_name):
"""
Mark the registered external inputs for code generation.

Note:
This function to be called by its parent (ModuleStruct).

Args:
header_x (str): The `x` in module construct header.
onnx_precursor_node_name (str): The original onnx node name.
"""
if self.inputs_in_construct_header.get(onnx_precursor_node_name) is not None:
raise ValueError("The input from {} has already registered. Check this node \
{} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier))
self.inputs_in_construct_header[onnx_precursor_node_name] = header_x

def _check_target_node_internal(self, name: str) -> bool:
"""
Check given node under the same scope.

Args:
name (str): Can accept both node identifier or original onnx node name.
"""
target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \
or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name)
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
return False

if target_nd_struct is None:
raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name))
return target_nd_struct.scope.path == self.scope.path

@property
def has_successor_node_external(self) -> bool:
"""Check if any successor_node is in external module."""
for name in self.successor_nodes_names:
if not self._check_target_node_internal(name):
return False

return True

@property
def precursor_nodes_names_external(self) -> list:
"""Return a list of external precursor nodes names."""
return [name for name in self.precursor_nodes_names \
if not self._check_target_node_internal(name)]

@property
def successor_nodes_names_external(self) -> list:
"""Return a list of external successor nodes names."""
return [name for name in self.successor_nodes_names \
if not self._check_target_node_internal(name)]

+ 157
- 0
mindinsight/mindconverter/graph_based_converter/generator/scope_utils.py View File

@@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define a scope class processing all operations related to scope and scope name."""
import re


class Scope():
"""Define scope related operations."""

def __init__(self, scope_str):
scopes = scope_str.split('/')
self.module_path = list()
self.scope_list = scopes[:-1]
self.head = self.scope_list[0]
self.tail = self.scope_list[-1]
self.initialization()

def initialization(self):
"""Init scope class."""
self._update_module_path_from_scope_list()

def _update_module_path_from_scope_list(self):
"""Update the module scope path from a list of scope."""
self.module_path = list()
for scope in self.scope_list:
if scope == 'Model':
continue

if 'Module' in scope:
regex = r"Module(?P<num>\d+)_(?P<curr_level_unique_id>\d+)"
match = re.match(regex, scope)
if match:
module_num = match.group('num')
uid = match.group('curr_level_unique_id')
self.module_path.append((int(module_num), int(uid)))

@property
def path(self):
"""Return module scope path."""
return self.module_path

def set_path(self, ind, path_tuple: tuple):
"""
Set the module scope path.

Args:
ind (int): The index of the scope path to be set.
path_tuple ((int, int)): The tuple of the scope path.
"""
self.module_path[ind] = path_tuple

@property
def to_str(self):
"""Return the full module scope as the string format."""
full_str_list = ["Model"]
for (num, uid) in self.module_path:
local = "Module{}_{}".format(num, uid)
full_str_list.append(local)

return "/".join(full_str_list)

@property
def depth(self):
"""Return the depth of the scope path."""
return len(self.path)

@staticmethod
def scope_to_module_name(path):
"""
Helper function to convert any scope path string to the full module scope.

Args:
path (str): path string like "[(5, 0), (3, 0)]"

Returns:
str, the full module scope with format like "Model/Module5_0/Module3_0/"
"""
scope_str_list = ["Model"]
if isinstance(path, str):
path = Scope.path_str_to_list(path)
if isinstance(path, list):
for (num, uid) in path:
local = "Module{}_{}".format(num, uid)
scope_str_list.append(local)

return "/".join(scope_str_list)

@staticmethod
def parse_scope_from_node_identifier(node_identifier: str):
"""
Helper function to parse the scope string from node identifier.

Args:
node_identifier (str): The string of the node identifier.

Returns:
str, parsed scope string from node identifier.
"""
regex = r"(?P<scope>Model/.*)\$\S+\$"
match = re.match(regex, node_identifier)
if not match:
return None
return match.group('scope')

@staticmethod
def path_str_to_list(scope_path_str: str):
"""
Helper function to convert the scope path string back to list.

Args:
scope_path_str (str): The scope path string like "[(5, 0), (3, 0)]".

Returns:
list, a list of the scope path like [(5, 0), (3, 0)].
"""
ret = []
tmp = scope_path_str.strip('[').strip(']')
regex = r"\((?P<num>\d+), (?P<uid>\d+)\)"
s_all = re.findall(regex, tmp)
for (num, uid) in s_all:
ret.append((int(num), int(uid)))

return ret

@staticmethod
def get_parent_module_num_and_uid(path):
"""
Helper function to return its parent's scope tuple.

Args:
path (Union[str, list]): Module scope path string. e.g. "[(5, 0), (3, 0)]"

Returns:
tuple, parent's scope level. e.g. [(5, 0)]
"""
if isinstance(path, str):
path = Scope.path_str_to_list(path)
if isinstance(path, list):
if len(path) == 1: # modules under the main module, (-1, -1) means main module.
return (-1, -1)
if len(path) > 1: # modules under another non-main module. Return parent's scope.
parent = path[-2]
return parent

return None

+ 48
- 0
mindinsight/mindconverter/graph_based_converter/hierarchical_tree/name_mgr.py View File

@@ -106,3 +106,51 @@ class GlobalVarNameMgr:

global_var_namespace.add(new_name)
return new_name


class LocalVarNameMgr:
"""Local variable name mgr."""

def __init__(self):
self.local_op_namespace = dict()
self.local_var_namespace = set()

@staticmethod
def _get_name(name):
"""Deal with op name."""
if "::" in name:
return name.split("::")[1]
return name

def get_name(self, op_type):
"""
Get module/variable name.

If the module already existed, then add a suffix to it.

conv1 onnx::conv

Args:
op_type (str): Operator type in onnx.

Returns:
str, module name.
"""

def _gen(t):
t = t.lower()
if t not in self.local_op_namespace:
self.local_op_namespace[t] = START_IDX
suffix = ""
else:
self.local_op_namespace[t] += 1
suffix = f"{self.local_op_namespace[t] - 1}"

return f"{self._get_name(t)}{suffix}"

new_name = _gen(op_type)
while new_name in self.local_var_namespace:
new_name = _gen(op_type)

self.local_var_namespace.add(new_name)
return new_name

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -151,7 +151,7 @@ class OnnxGraph(Graph):
input_shape (tuple): Input shape.
"""
input_node = InputNode(input_shape)
input_node_name = "{}InputNode"
input_node_name = self._raw_input_nodes.replace(":0", "")
for node_name, node in self._nodes_collection.items():
if node_name in self._input_nodes:
ipt_nd_name = input_node_name.format(input_node.scope_name)


+ 12
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -23,6 +23,7 @@ import numpy as np

from mindinsight.mindconverter.common.log import logger as log
from ..common.utils import fetch_output_from_onnx_model
from ..common.global_context import GlobalContext

from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
@@ -110,6 +111,7 @@ class OnnxTensor:
self.to_nodes = []

def to_array(self):
"""Convert the tensor value from binary to np array."""
onnx = import_module("onnx")
# Convert binary data to np.array
if not isinstance(self.raw_tensor, (np.ndarray, list, tuple, int, float)):
@@ -264,7 +266,7 @@ class OnnxDataLoader:
self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes]
# args for init
self._is_infer_shape = infer_shape
self._global_context = GlobalContext()
# params parsed in init
self.inferred_model = None

@@ -375,12 +377,19 @@ class OnnxDataLoader:

def _parse_nodes(self):
"""Parse each onnx nodes in the model."""
for node in self.nodes:
nodes_topo_idx = []
for idx, node in enumerate(self.nodes):
n = OnnxNode(node)
self._nodes_dict[n.name] = n
nodes_topo_idx.append((idx, n.name))
if len(node.output) > 1:
raise ModelNotSupport(msg=f"{node.name} has multi-outputs which is not supported now.")
self.output_name_to_node_name[node.output[0]] = node.name
self._global_context.onnx_node_name_to_topo_idx[n.name] = idx
node_inputs = [i.replace(":0", "") for i in node.input]
self._global_context.onnx_node_inputs[n.name] = node_inputs
self._global_context.onnx_nodes_collection = self._nodes_dict
self._global_context.onnx_nodes_topo_index = nodes_topo_idx

def _parse_tensors(self):
"""Parse each onnx tensors in the model."""
@@ -388,6 +397,7 @@ class OnnxDataLoader:
for tensor in tensors:
t = OnnxTensor(tensor)
self.tensors_dict[t.name] = t
self._global_context.onnx_tensors_collection = self.tensors_dict

def _parse_node_output_shape(self):
"""


Loading…
Cancel
Save