Browse Source

minor bug fix for args translation

fix incorrect logic in NodeStruct's check of whether a target node is internal when the target node is a graph input
update copyright
add an output mgr
adopt the generator with new fragment exchange msg. still need to adopt old ver. code_settings etc.
outputs mgr dev; adapt the new fragment. temp disable extra nodes and weights
adapt the NewFragment; re-implement the module struct reset method
update ut mapper test for new fragment
tags/v1.2.0-rc1
liangtianshu 4 years ago
parent
commit
02773f4490
11 changed files with 537 additions and 194 deletions
  1. +6
    -1
      mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  2. +5
    -1
      mindinsight/mindconverter/graph_based_converter/common/global_context.py
  3. +200
    -0
      mindinsight/mindconverter/graph_based_converter/common/outputs.py
  4. +16
    -25
      mindinsight/mindconverter/graph_based_converter/generator/__init__.py
  5. +96
    -0
      mindinsight/mindconverter/graph_based_converter/generator/fragment_utils.py
  6. +84
    -27
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  7. +32
    -24
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  8. +66
    -104
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  9. +21
    -8
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  10. +10
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  11. +1
    -4
      tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py

+ 6
- 1
mindinsight/mindconverter/graph_based_converter/common/code_fragment.py View File

@@ -27,8 +27,9 @@ class Fragment(abc.ABC):
Args:
operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values.
input_shape (tuple): The input shape of the node.
output_shape (tuple): The output shape of the node.
settings (namedTuple): Code generation setting.

"""

def __init__(self, operation, actual_args, input_shape, output_shape, settings=None):
@@ -46,6 +47,7 @@ class Fragment(abc.ABC):

@property
def code_setting(self):
"""Code Setting getter."""
return self._code_setting

@property
@@ -152,10 +154,12 @@ class Fragment(abc.ABC):

@property
def input_shape(self):
"""Return the input shape."""
return self._input_shape

@property
def output_shape(self):
"""Return the output shape."""
return self._output_shape


@@ -196,6 +200,7 @@ class CodeFragment(Fragment):

@property
def trainable_params(self):
"""Return the trainable parameters."""
return self._trainable_params




+ 5
- 1
mindinsight/mindconverter/graph_based_converter/common/global_context.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
# ==============================================================================
"""Define GlobalContext class to save required resources during whole conversion procedure."""
from collections import OrderedDict
from .outputs import OutputStorage


class Singleton(type):
@@ -45,6 +46,7 @@ class GlobalContext(metaclass=Singleton):
self.onnx_node_name_to_topo_idx = dict()
self.onnx_node_inputs = dict()
self._onnx_tensors_collection = dict()
self.onnx_graph_info = dict()

# Define data stored from generator
# Key as Node Identifier
@@ -72,6 +74,8 @@ class GlobalContext(metaclass=Singleton):
# key is target node (which use this opt), value is opt_var_name
self.extra_input_dict = dict()

self.outputs_storage = OutputStorage()

def get_onnx_node_from_identifier(self, identifier):
"""Return an OnnxUtils defined node by its identifier."""
onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier)


+ 200
- 0
mindinsight/mindconverter/graph_based_converter/common/outputs.py View File

@@ -0,0 +1,200 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define basic classes for generator use."""
import abc
import copy
from typing import Union, Iterable

class BaseOutput:
    """
    Define the class of output providing a universal nodes' and modules' output data collection.

    Args:
        output_mapping (tuple): A two-item sequence. Item 0 is stored as the output index on the
            MindSpore provider side and item 1 as the output index on the ONNX provider side.
    """
    def __init__(self, output_mapping) -> None:
        # `super(BaseOutput)` without an instance builds an unbound super object;
        # the zero-argument form is the correct cooperative call.
        super().__init__()
        self.idx_in_ms_provider = output_mapping[0]
        self.idx_in_onnx_provider = output_mapping[1]

        # For multi users, key as user and value as index
        self.idx_in_ms_user = dict()
        self.idx_in_onnx_user = dict()

        # The following attributes to be set by who referenced this object.
        self.onnx_edge_name = None
        self.to_external = False

    @property
    def ms_user(self):
        """Return a keys view of the users registered for this output on the MindSpore side."""
        return self.idx_in_ms_user.keys()

    @property
    def onnx_user(self):
        """Return a keys view of the users registered for this output on the ONNX side."""
        return self.idx_in_onnx_user.keys()

    def deepcopy(self):
        """Return an independent deep copy of this instance."""
        return copy.deepcopy(self)


class BaseOutputManager(abc.ABC):
    """
    Base Output Manager class.

    Args:
        output_mappings (list): A list of output mapping. One BaseOutput is created per mapping.
            None is accepted and treated as "no mappings", so a subclass that fills the
            output list itself can pass None.
    """
    def __init__(self, output_mappings):
        # The previous guard, isinstance(self.__class__, ModuleOutputManager), tested a class
        # object and could never be True, so a None argument reached the loop below and
        # raised TypeError. Treating None as empty covers that case for every subclass.
        self._base_output_list = list()

        # init base output obj
        for mapping in output_mappings or ():
            self._base_output_list.append(BaseOutput(mapping))

    @property
    def outputs(self):
        """Return the list of BaseOutput in this manager."""
        return self._base_output_list

    @outputs.setter
    def outputs(self, val: list):
        """
        Set the list of BaseOutput in this manager.

        Raises:
            TypeError: If any item of `val` is not a BaseOutput.
        """
        for v in val:
            if not isinstance(v, BaseOutput):
                raise TypeError(f"{self.__class__} does not accept the type {type(v)} in the list given.")
        self._base_output_list = val

    @abc.abstractmethod
    def deepcopy(self):
        """
        Return the deepcopy of this instance.

        Subclasses must override this method; they can call this implementation through
        super() to get a new instance whose outputs are copies of the current ones, and
        then copy their own attributes onto it.
        """
        cls = self.__class__
        # __new__ skips __init__, so only the attributes assigned here exist on the copy.
        result = cls.__new__(cls)
        result.outputs = list()
        result.outputs.extend(out.deepcopy() for out in self._base_output_list)
        return result


class NodeOutputManager(BaseOutputManager):
    """
    Node Output Manager class.

    Args:
        identifier (str): The identifier of the node.
        output_mappings (list): A list of the output mapping. Defaults to None, which
            creates a manager without outputs.
    """
    def __init__(self, identifier, output_mappings=None) -> None:
        # The base initializer iterates over the mappings, so the None default
        # must be normalized to an empty list before it is handed over.
        if output_mappings is None:
            output_mappings = list()
        super().__init__(output_mappings)
        self.identifier = identifier

    def deepcopy(self):
        """Return a deepcopy of current instance, keeping the same identifier."""
        new_mgr = super().deepcopy()
        new_mgr.identifier = self.identifier
        return new_mgr


class ModuleOutputManager(BaseOutputManager):
    """
    Module Output Manager class.

    Args:
        identifier (str): The identifier of the module.
        base_out (Union[BaseOutput, Iterable[BaseOutput]]): One output, or an iterable of
            outputs, that this module manages.
    """
    def __init__(self, identifier, base_out: Union[BaseOutput, Iterable[BaseOutput]]) -> None:
        # Hand the base initializer an empty list rather than None: it iterates over the
        # mappings it receives, and a module builds its output list from `base_out` below.
        super().__init__(list())
        self.identifier = identifier
        self._return_list_counter = 0
        self._base_output_list = list()
        if isinstance(base_out, BaseOutput):
            self._base_output_list.append(base_out)
        else:
            self._base_output_list.extend(base_out)

    @property
    def return_num(self):
        """Return the number of outputs to be returned."""
        return self._return_list_counter

    @return_num.setter
    def return_num(self, num: int):
        """Set the number of outputs to be returned."""
        self._return_list_counter = num

    def deepcopy(self):
        """Return a deepcopy of current instance, keeping identifier and return number."""
        new_mgr = super().deepcopy()
        new_mgr.identifier = self.identifier
        new_mgr.return_num = self._return_list_counter
        return new_mgr


class OutputStorage:
    """A class saves all outputs, indexed by their ONNX edge name."""
    def __init__(self):
        self._base_output_edge_to_instance = dict()
        self._base_output_edge_to_onnx_node_name = dict()
        self._base_output_edge_to_ms_identifier = dict()

    @property
    def outputs_collections(self) -> dict:
        """Return the dict of edge name to output instance."""
        return self._base_output_edge_to_instance

    def onnx_name(self, output_edge) -> str:
        """
        Return the onnx node name registered for the edge.

        Args:
            output_edge (str): The edge name of the output.

        Returns:
            str, the onnx node name, or None if the edge has not been registered.
        """
        return self._base_output_edge_to_onnx_node_name.get(output_edge)

    def node_identifier(self, output_edge):
        """
        Return the node identifier registered for the edge.

        Args:
            output_edge (str): The edge name of the output.

        Returns:
            str, the node identifier, or None if the edge has not been registered.
        """
        return self._base_output_edge_to_ms_identifier.get(output_edge)

    def add_output(self, out: "BaseOutput") -> None:
        """
        Add a BaseOutput instance to the storage.

        Args:
            out (BaseOutput): The BaseOutput instance.

        Raises:
            ValueError: If the instance has no (or an empty) ONNX edge name.
        """
        if not out.onnx_edge_name:
            raise ValueError("Unable to add a BaseOutput instance with unknown ONNX edge.")
        self._base_output_edge_to_instance[out.onnx_edge_name] = out

    def add_onnx_node_name(self, edge: str, onnx_node_name: str):
        """
        Add the onnx node name with the edge name.

        Args:
            edge (str): The edge name of this output.
            onnx_node_name (str): The onnx node which has the edge.
        """
        self._base_output_edge_to_onnx_node_name[edge] = onnx_node_name

    def add_ms_identifier(self, edge: str, ms_identifier: str):
        """
        Add the node identifier with the edge name.

        Args:
            edge (str): The edge name of this output.
            ms_identifier (str): The identifier of the node which has the edge.
        """
        self._base_output_edge_to_ms_identifier[edge] = ms_identifier

+ 16
- 25
mindinsight/mindconverter/graph_based_converter/generator/__init__.py View File

@@ -18,9 +18,10 @@ __all__ = ["batch_add_nodes"]
import re
import copy

from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords
from .generator import Generator, CodeStruct
from ..common.code_fragment import CodeFragment, NewFragment
from ..common.outputs import NodeOutputManager
from ..constant import ExchangeMessageKeywords


def _tf_model_node_name_reformat(node, node_name):
@@ -56,6 +57,7 @@ def batch_add_nodes(graph_obj, mapper) -> Generator:

"""
generator_inst = Generator()
external_inputs = graph_obj.user_provided_input_nodes
for node_name in graph_obj.nodes_in_topological_order:
node_inst = graph_obj.get_node(node_name)
node_input = graph_obj.get_input_shape(node_name)
@@ -66,16 +68,11 @@ def batch_add_nodes(graph_obj, mapper) -> Generator:
node_name = node_name_with_scope

node_inst.add_input_and_output_shape(node_input, node_output)
op_name, params, settings, weights = _convert_params(node_inst, mapper)
generator_inst.add_node(
node_name,
node_instance=node_inst,
node_fragment=CodeFragment(op_name, params,
settings,
node_inst.input_shape,
node_inst.output_shape,
weights)
)
code_template, exchange_msg, outputs_lst, outputs_mapping = _convert_params(node_inst, mapper, external_inputs)
outputs_mapping = NodeOutputManager(node_name, output_mappings=outputs_mapping)
fragment = NewFragment(data_entity=exchange_msg, code_template=code_template,
outputs=outputs_lst, outputs_mapping=outputs_mapping)
generator_inst.add_node(node_name, node_instance=node_inst, node_fragment=fragment)
return generator_inst


@@ -105,13 +102,14 @@ def _supply_graph_info(node, external_inputs):
}


def _convert_params(node, mapper):
def _convert_params(node, mapper, external_inputs):
"""
Call mapper to convert node's params from ONNX to MindSpore.

Args:
node (GraphNode): Our defined GraphNode instance.
mapper (Mapper): The mapper instance which indicating conversion method.
external_inputs (list[str]): External inputs provided by users.

Returns:
tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters,
@@ -121,18 +119,11 @@ def _convert_params(node, mapper):
params.update({"input_shape": node.input_shape,
"output_shape": node.output_shape})

op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name,
params=params,
weights=node.weight)
if "input_shape" in ms_params:
ms_params.pop("input_shape")
if "output_shape" in ms_params:
ms_params.pop("output_shape")

if op_in_ms:
return op_in_ms, ms_params, ms_settings, weights

return node.op_name, node.node_params, dict(), dict()
code_template, exchange_msg, outputs_lst, outputs_order_mapping = mapper.convert(op_name=node.op_name,
params=params,
weights=node.weight)
exchange_msg[ExchangeMessageKeywords.METADATA.value] = _supply_graph_info(node, external_inputs)
return code_template, exchange_msg, outputs_lst, outputs_order_mapping


def _combine_external_inputs_with_precursor_nodes(node, external_inputs):


+ 96
- 0
mindinsight/mindconverter/graph_based_converter/generator/fragment_utils.py View File

@@ -0,0 +1,96 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellaneous Fragment related classes and functions. """

from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment


class FragmentHandler:
    """
    Define a handler to process the information contained by Fragment.

    Args:
        fragment (NewFragment): The refactored fragment class.
    """
    def __init__(self, fragment: NewFragment):
        # "var_0" is the var entry loaded and saved until target_var is changed.
        self._target_var = "var_0"
        self._fragment = fragment

    @property
    def target_var(self):
        """Return the name of the var the handler is currently set to read."""
        return self._target_var

    @target_var.setter
    def target_var(self, target):
        """Select the var the handler will read; it must be a key of the exchange message."""
        known_vars = self.exchange_msg.keys()
        if target not in known_vars:
            raise ValueError(f"Unable to set target var {target} where fragment does not have it.")
        self._target_var = target

    @property
    def fragment(self):
        """Return the fragment instance wrapped by this handler."""
        return self._fragment

    @property
    def converted(self):
        """Return True if the fragment carries a non-empty exchange message."""
        exchange_msg = self._fragment.exchange_msg
        return bool(exchange_msg)

    # The following section is intended for Fragment exchange message.
    @property
    def exchange_msg(self):
        """Return the exchange message dictionary held by the fragment."""
        return self._fragment.exchange_msg

    @property
    def var(self):
        """Return the exchange message entry of the current target var, or None on AttributeError."""
        try:
            selected = self.exchange_msg.get(self.target_var)
        except AttributeError:
            selected = None
        return selected

    @property
    def default_var(self):
        """Return the exchange message entry stored under "var_0", or None on AttributeError."""
        try:
            default = self.exchange_msg.get("var_0")
        except AttributeError:
            default = None
        return default

    # For metadata
    @property
    def metadata(self):
        """Return the "metadata" entry of the exchange message."""
        exchange_msg = self._fragment.exchange_msg
        return exchange_msg.get("metadata")

    @property
    def input_shape(self):
        """Return the 'inputs_shape' entry of the metadata."""
        meta = self.metadata
        return meta.get('inputs_shape')

    @property
    def output_shape(self):
        """Return the 'outputs_shape' entry of the metadata."""
        meta = self.metadata
        return meta.get('outputs_shape')

    # For outputs
    @property
    def outputs_manager(self):
        """Return the outputs mapping object stored on the fragment."""
        return self._fragment.outputs_mapping

+ 84
- 27
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -23,6 +23,7 @@ from .node_struct import NodeStruct
from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext
from ..common.outputs import BaseOutput, ModuleOutputManager
from ...common.exceptions import GeneratorError
from ..common.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
@@ -39,21 +40,10 @@ class CodeStruct:

def __init__(self, struct, repeated_submodules=None):
"""Initialize the CodeStruct."""
self.output_order = None # output order
self.input = None # opt_var_name for prev. node
self.extra_input = list() # extra_input(s) at construct method args
self.output = None # opt_var_name for next node
self.extra_output = list() # extra_output(s)
self.extra_comment = None # comments for this code line / block.
self.code_line_list = list() # list of code line, a item is a line.
self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module

self.formal_args_collections = None

if isinstance(struct, NodeStruct):
self.output_order = struct.topo_idx
if isinstance(struct, ModuleStruct):
self.output_order = struct.head_nd_struct_index
self._generate_from_module_struct(struct, repeated_submodules)

def _add_line(self, s):
@@ -102,13 +92,15 @@ class CodeStruct:
cons_lines = list()
for (_, struct) in md_struct.get_generate_order():
if isinstance(struct, NodeStruct): # Generate code line for Node.
code_line_init = struct.code_line_in_init()
code_line_construct = struct.code_line_in_construct()
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")
# add extra tensor
if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor:
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}")
_ = struct.code_line_in_init()
_ = struct.code_line_in_construct()

init_str, cons_str = struct.fragment.fragment()
init_str = [f"{SECOND_LEVEL_INDENT}{x}" for x in init_str]
cons_str = [f"{SECOND_LEVEL_INDENT}{x}" for x in cons_str]
code_line_construct = cons_str
init_lines += init_str
cons_lines += cons_str

elif isinstance(struct, ModuleStruct):
# check if this instance generated CodeStruct
@@ -145,7 +137,8 @@ class CodeStruct:
returns.append(r)
returns = list(set(returns))
else:
returns = [code_line_construct[0]]
returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \
else [code_line_construct[-1].replace(' ', '').split('=')[0]]
self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}"
self.new_line = f"{NEW_LINE * 2}"
self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self
@@ -156,9 +149,6 @@ class Generator:

def __init__(self):
"""Init the generator."""
# define basic attributes
self.framework = None

# define MUST have params
self._node_struct_collections = OrderedDict()
self._module_struct_collections = OrderedDict()
@@ -244,9 +234,9 @@ class Generator:
if len(nd_struct_list) < 2:
return formal_args
(_, base_nd_struct) = nd_struct_list[0]
for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param
for (base_parameter, base_value) in base_nd_struct.fragment.default_var["args"].items(): # for each param
for (_, nd_struct) in nd_struct_list[1:]:
compared_value = nd_struct.fragment.actual_args.get(base_parameter)
compared_value = nd_struct.fragment.default_var["args"].get(base_parameter)
if compared_value == base_value:
continue
formal_args.add(base_parameter)
@@ -340,9 +330,9 @@ class Generator:
self.module_structs[str(parent_scope)] = parent_md_struct
else:
# 1B. not has parent, generate a new ModuleStruct
parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module
# use this submodule to create a parent module
parent_md_struct = ModuleStruct(None, init_as_parent=True, parent_base=md_struct)
# rewrite parent md struct
parent_md_struct.reset_as_parent()
parent_md_struct.add_submodule(md_struct)
self.module_structs[str(parent_scope)] = parent_md_struct
sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections
@@ -378,6 +368,7 @@ class Generator:
self._update_all_modules_args_translator()

# 6. Update all nodes and modules input/output
# Enable build_output_connections later.
self.module_structs.get('[]').allocate_construct_header_x()
self.module_structs.get('[]').collect_returns()

@@ -488,7 +479,7 @@ class Generator:

for code_struct in self._global_context.code_structs.values():
for line in code_struct.code_line_list:
outputs.append(line.replace("onnx::", ""))
outputs.append(line)

formatted_code, _ = FormatCode("\n".join(outputs),
style_config=CodeFormatConfig.PEP8.value)
@@ -575,3 +566,69 @@ class Generator:
if m_num == module_num:
ret.append(nd_struct_list)
return ret

def build_outputs_connection(self):
    """
    Build all nodes and modules outputs connections.

    Every node output is registered in the global outputs storage under its ONNX edge
    name. Each node input that matches a registered edge is then recorded as a user of
    that output. Finally the module level output managers are collected.
    """
    storage = self._global_context.outputs_storage
    for node in self.node_structs.values():
        # Register each output of the current node under its onnx edge name.
        for output in node.outputs_manager.outputs:
            edge_names = node.fragment.metadata.get('outputs')
            output.onnx_edge_name = edge_names[output.idx_in_onnx_provider]
            storage.add_output(output)
            storage.add_onnx_node_name(output.onnx_edge_name, node.fragment.metadata.get('source'))
            storage.add_ms_identifier(output.onnx_edge_name, node.identifier)

        # Attach the current node as a user of every registered edge it consumes.
        for position, edge in enumerate(node.fragment.metadata.get('inputs')):
            if edge not in storage.outputs_collections:
                continue
            provided = storage.outputs_collections[edge]
            provided.idx_in_onnx_user[node.onnx_name] = position

            # set ms_user idx, need to modify if not follow onnx order
            provided.idx_in_ms_user[node.identifier] = position

            # The output goes to external when its provider is not internal to this node.
            provider_is_internal = node.check_target_node_internal(storage.onnx_name(edge))
            provided.to_external = not provider_is_internal

    # collect submodule's and nodes' outputs mgr
    self._collect_output_mgr()

def _collect_output_mgr(self, module=None):
    """
    Collect the outputs manager from nodes and submodules the current module has.

    Submodules that have no outputs manager yet are processed recursively first. The
    module then receives a ModuleOutputManager holding every collected output that
    still has to be returned to the module external.

    Args:
        module (ModuleStruct): The module struct collecting its nodes and submodules.
            When not given, the module stored under '[]' is used.
    """
    root_module = module
    if not root_module:
        root_module = self.get_module_struct('[]')

    external_outputs = list()
    for entry in root_module.get_generate_order():
        # A tuple entry is (topological index, NodeStruct); keep the struct only.
        struct = entry[1] if isinstance(entry, tuple) else entry
        if isinstance(struct, ModuleStruct) and struct.outputs_manager is None:
            self._collect_output_mgr(module=struct)
        external_outputs.extend(
            out for out in struct.outputs_manager.outputs
            if Generator.check_output_need_to_external(root_module, out)
        )
    root_module.outputs_manager = ModuleOutputManager(root_module.identifier, base_out=external_outputs)

@staticmethod
def check_output_need_to_external(root_module: ModuleStruct, checked_output: BaseOutput):
    """
    Tell whether an output has to be returned out of the given module.

    Args:
        root_module (ModuleStruct): The Module that the output to be determined.
        checked_output (BaseOutput): The output to be checked whether returned by the Module.

    Returns:
        bool, True if at least one ONNX user of the output is among the external
        successor nodes of the module, otherwise False.
    """
    return any(
        user in root_module.external_successor_nodes_names
        for user in checked_output.onnx_user
    )

+ 32
- 24
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -14,6 +14,7 @@
# ==============================================================================
"""Define a struct for module converted and save all required information here."""

import copy
from collections import OrderedDict

from .node_struct import NodeStruct
@@ -31,10 +32,12 @@ class ModuleStruct:

Args:
args (list): A list of node structs.
init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct.
parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as.
"""
GLOBAL_CONTEXT_MGR = GlobalContext()

def __init__(self, nd_struct_list):
def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None):
"""Init. a module by NodeStructs."""
self.pattern_id = -1 # pattern num, -1 as Main module
self.pattern_uid = -1 # unique module id for this pattern
@@ -53,19 +56,15 @@ class ModuleStruct:

self._fragment = None
self._args_translator = None
self._setting = None
self._parent_module_struct = None
# only store original formal args name, not global
self._nodes_structs_formal_args_list = list()
# only store translated (globalized) formal args
self._nodes_structs_formal_args_translated_list = list()

# define other settings here
self._node_args_translation_list = list()
self._var_name_mgr = LocalVarNameMgr()
self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name
self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct
self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name

# key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name)
self.outputs_collection = dict()
@@ -74,40 +73,49 @@ class ModuleStruct:
# key is ext. succ node onnx name, value is local opt_var
self.external_successor_local_returns_map = OrderedDict()

# key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level
self.outputs_collection = dict()
# Define outputs manager, note this will be assigned later by Generator.
self.outputs_manager = None

# start initialization
if not self.initialized:
self._init_module(nd_struct_list)
if init_as_parent and (parent_base is not None):
self.reset_as_parent_passed_in(parent_base)
else:
self._update_module(nd_struct_list)
# start initialization
if not self.initialized:
self._init_module(nd_struct_list)
else:
self._update_module(nd_struct_list)

# assign this module reference to node
for (_, nd_struct) in nd_struct_list:
nd_struct.parent_module_struct = self
# assign this module reference to node
for (_, nd_struct) in nd_struct_list:
nd_struct.parent_module_struct = self

def reset_as_parent(self):
def reset_as_parent_passed_in(self, parent_base):
"""
Reset all attributes and filled as a parent module of this module.
Reset all attributes and filled as a parent module of the module passed in.

Args:
parent_base(ModuleStruct): The base ModuleStruct to be passed in for ModuleStruct init.

Note:
This function must be called only after a deepcopy of this instance!
This function must be called only if the new ModuleStruct is a parent of parent_base.
"""
self.identifier.pop()
self.scope_depth = self.scope_depth - 1
self._set_pattern_id()
self._find_parent_module()
self.identifier = copy.deepcopy(parent_base.identifier)[:-1]
self.scope_depth = copy.deepcopy(parent_base.scope_depth) - 1
self.module_name = Scope.scope_to_module_name(self.identifier)
self.head_nd_struct = parent_base.head_nd_struct
self.head_nd_struct_index = parent_base.head_nd_struct_index
self.tail_nd_struct = parent_base.tail_nd_struct
self.tail_nd_struct_index = parent_base.tail_nd_struct_index
self._node_structs = list()
self._module_structs = list()
self._fragment = None
self._args_translator = None
self.initialized = True
self._set_pattern_id()
self._find_parent_module()
self.init_args_translator()
self._setting = None
self._parent_module_struct = None
self._nodes_structs_formal_args_list = list()

self._node_args_translation_list = list()

def _set_pattern_id(self):
@@ -435,7 +443,7 @@ class ModuleStruct:

@property
def external_successor_nodes_names(self) -> list:
"""Return all precursors nodes names not in this module."""
"""Return all successor nodes names not in this module."""
ret = []
for _, struct in self.get_generate_order():
if isinstance(struct, NodeStruct):


+ 66
- 104
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -15,15 +15,14 @@
"""Define the NodeStruct which stores all info. of a node."""
from collections import OrderedDict

from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment
from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler
from .scope_utils import Scope
from .args_translator import ArgsTranslation
from ..common.code_fragment import CodeFragment
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext
from ..constant import InputType
from ...common.exceptions import GeneratorError


class NodeStruct:
"""
Define a node struct which stores all info. to generate statement.
@@ -44,21 +43,11 @@ class NodeStruct:
self._args_translator = None
self._parent_module_struct = None
self.topo_idx = None
self.node_type = None
self.onnx_name = None
self.onnx_op = None
self.graph_node_ref = None
self.scope_name = None
self.ms_var_name = None
self.ms_op = None
self.ready_to_generate = False

# Define attributes converted from mapper
self.ms_params = dict()
self.ms_settings = dict()
self.ms_weights = dict()
self.ms_inputs = OrderedDict()

# Defined Scope class
self.scope = None

@@ -67,9 +56,6 @@ class NodeStruct:
# key is prec_node_name, value is x; For code line use
self.inputs_in_construct_header = OrderedDict()

# key is prec_node_name, value is its closet opt_var_name
self.inputs_in_parent_module = OrderedDict()

# Matched inputs will can be directly used by code line generation
self.matched_inputs = list()

@@ -86,7 +72,7 @@ class NodeStruct:

def ori_topo_idx(self):
"""Get the original topological index in the onnx graph."""
ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/')
ori_name = self._fragment.metadata.get('source')
self.onnx_name = ori_name
return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name)

@@ -97,12 +83,20 @@ class NodeStruct:
Args:
idx (int): The index of the node in this module.
"""
def _remove_op_header(op_name):
"""Remove op header which indicating their sources of op set."""
op_name = op_name.replace('nn.', '')
op_name = op_name.replace('P.', '')
op_name = op_name.replace('onnx.', '')
return op_name

if idx is not None:
self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx)
self.ms_var_name = "{}_{}".format(_remove_op_header(self.ms_op), str(idx)).lower()
elif self.topo_idx is not None:
self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx)
self.ms_var_name = "{}_{}".format(_remove_op_header(self.ms_op), str(self.topo_idx)).lower()
else:
raise ValueError("Unable to update var name when topo_idx is None.")
self.fragment.default_var['variable_name'] = self.ms_var_name

def _update_basics_from_gn(self, gn):
"""Update basic info from GraphNode."""
@@ -111,25 +105,13 @@ class NodeStruct:

def _update_from_onnx_gn(self, gn: OnnxGraphNode):
"""Update basic info from OnnxGraphNode."""
self.node_type = "OnnxGraphNode"
self._update_basics_from_gn(gn)

def _update_from_mapper(self, d):
"""Update info from mapper."""
if d.get('op_name'):
self.ms_op = d.get('op_name')
if d.get('params'):
self.ms_params = d.get('params')
if d.get('settings'):
self.ms_settings = d.get('settings')
if d.get('weights'):
self.ms_weights = d.get('weights')

def _update_from_fragment(self, frag: CodeFragment):
def _update_from_fragment(self, frag: NewFragment):
"""Update info from CodeFragment."""
self._fragment = frag
if frag.operation:
self.ms_op = frag.operation
self._fragment = FragmentHandler(frag)

if self.ms_op:
idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count
self.update_var_name(idx=idx)

@@ -148,22 +130,13 @@ class NodeStruct:
"""
if not self._fragment:
raise ValueError("Initialize argument translator failed.")
if self._fragment.actual_args and translated_args:
self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args)

def check_if_generate_ready(self):
"""Check if the NodeStruct is able to generate code."""
# check essential params exists
if all([self.identifier,
self.node_type,
self.scope_name,
self.ms_var_name,
self.ms_opt_var_name,
self.ms_op]):
self.ready_to_generate = True
if self._fragment.converted and self._fragment.default_var["args"] and translated_args:
self._args_translator = ArgsTranslation(self._fragment.default_var["args"],
self.ms_var_name,
translated_args)

@GeneratorError.check_except("Generator occurs an error when creating node struct.")
def update(self, arg, force_ready=False):
def update(self, arg):
"""
Pass Node info. to generator NodeStruct.

@@ -174,18 +147,11 @@ class NodeStruct:

if isinstance(arg, OnnxGraphNode):
self._update_from_onnx_gn(arg)
elif isinstance(arg, (dict, OrderedDict)):
self._update_from_mapper(arg)
elif isinstance(arg, CodeFragment):
elif isinstance(arg, NewFragment):
self._update_from_fragment(arg)
else:
raise TypeError("NodeStruct received an unsupported initializing argument.")

if force_ready:
self.ready_to_generate = True
else:
self.check_if_generate_ready()

@property
def identifier(self):
"""Return the identifier of the node."""
@@ -234,10 +200,30 @@ class NodeStruct:
"""Return the original onnx node reference."""
return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name)

@property
def ms_op(self):
    """Operation name in MindSpore, read from the fragment's default variables."""
    default_var = self._fragment.default_var
    return default_var.get('operation')

@ms_op.setter
def ms_op(self, ms_op_name: str):
    """Store the operation name in MindSpore in the fragment's default variables."""
    default_var = self._fragment.default_var
    default_var['operation'] = ms_op_name

@property
def ms_var_name(self):
    """Variable name of this node in the MindSpore script, read from the fragment's default variables."""
    default_var = self._fragment.default_var
    return default_var.get('variable_name')

@ms_var_name.setter
def ms_var_name(self, ms_var_name: str):
    """Store the variable name of this node in the fragment's default variables."""
    default_var = self._fragment.default_var
    default_var['variable_name'] = ms_var_name

@property
def ms_opt_var_name(self):
"""Return the output variable name of current node."""
return "{}_opt".format(self.ms_var_name).lower()
return self.fragment.fragment.get_outputs_by_idx(0)

@property
def args_translator(self):
@@ -282,25 +268,32 @@ class NodeStruct:
def parent_module_struct(self, ref):
self._parent_module_struct = ref

@property
def outputs_manager(self):
    """Outputs manager instance exposed by this node's fragment."""
    holder = self.fragment
    return holder.outputs_manager

@property
def outputs_in_construct(self):
    """Output var(s) used in the construct statement, as reported by the wrapped fragment."""
    wrapped = self.fragment.fragment
    return wrapped.outputs()

# Code Generation funcs below

def code_line_in_init(self):
"""Initialization line of code in module init block."""
unconverted = False
if "onnx::" in self.ms_var_name:
unconverted = True
self.ms_var_name = self.ms_var_name.replace("onnx::", "")
left = "self.{}".format(self.ms_var_name)

args_list = list()
if self._args_translator is not None:
self.fragment.default_var['args'] = {**self._args_translator.actual_args,
**self._args_translator.formal_args}
args_list += self._args_translator.actual_args_to_str_list
args_list += self._args_translator.formal_args_to_str_list
else:
actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args)
actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.default_var['args'])
args_list += actual_args_str

if unconverted:
if not self._fragment.converted:
args_list.append('='.join(["input_shape", str(self._fragment.input_shape)]))
args_list.append('='.join(["output_shape", str(self._fragment.output_shape)]))
right = f"{self.ms_op.replace('::', '.')}({', '.join(args_list)})"
@@ -308,32 +301,6 @@ class NodeStruct:
right = f"{self.ms_op}({', '.join(args_list)})"
return left, right

def _get_correct_in_module_returns(self, prec_node, in_module_return):
    """
    Find the correct precursor node name in return statement of its parent module.

    The precursor's chain of parent module structs is walked upwards, and the
    first entry of `in_module_return` whose module identifier belongs to that
    chain supplies the name to use.

    Args:
        prec_node (str): The onnx name of the precursor node given.
        in_module_return (list[tuple]): The list of outputs which contains parent module identifier
            and module opt_var_name.

    Returns:
        str, correct opt_var_name to be passed in current node. None if the
        precursor is unknown or none of its enclosing modules is listed.
    """
    # The lookup does not depend on the candidate being examined, so do it once.
    p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node)
    if p_node_struct is None:
        # Unknown precursor: report "not found" instead of failing on None.
        return None

    # Identifiers of every module enclosing the precursor, innermost first.
    enclosing_identifiers = []
    parent = p_node_struct.parent_module_struct
    while parent is not None:
        enclosing_identifiers.append(parent.identifier)
        parent = parent.parent_module_struct

    # Candidates are tried in the given order; the first match wins.
    for md_identifier, input_name_to_use in in_module_return:
        if md_identifier in enclosing_identifiers:
            return input_name_to_use
    return None

def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block. """
left = self.ms_opt_var_name
@@ -356,15 +323,7 @@ class NodeStruct:
if isinstance(inputs, str):
inputs = [inputs]

if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value:
inputs = [str(tuple(inputs)).replace("\'", "")]

if self._fragment.code_setting and self._fragment.code_setting.op_extra_input:
for _, val in self._fragment.code_setting.op_extra_input.items():
inputs.append(str(val))

if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor:
inputs.append(f"self.{self.ms_var_name}_w")
self.fragment.default_var['inputs'] = inputs
right = f"self.{self.ms_var_name}({', '.join(inputs)})"
return left, right

@@ -394,7 +353,7 @@ class NodeStruct:
{} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier))
self.inputs_in_construct_header[onnx_precursor_node_name] = header_x

def _check_target_node_internal(self, name: str) -> bool:
def check_target_node_internal(self, name: str) -> bool:
"""
Check given node under the same scope.

@@ -406,6 +365,9 @@ class NodeStruct:
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
return False

if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')):
return False

if target_nd_struct is None:
raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name))
return target_nd_struct.scope.path == self.scope.path
@@ -414,7 +376,7 @@ class NodeStruct:
def has_successor_node_external(self) -> bool:
"""Check if any successor_node is in external module."""
for name in self.successor_nodes_names:
if not self._check_target_node_internal(name):
if not self.check_target_node_internal(name):
return False

return True
@@ -423,10 +385,10 @@ class NodeStruct:
def precursor_nodes_names_external(self) -> list:
"""Return a list of external precursor nodes names."""
return [name for name in self.precursor_nodes_names
if not self._check_target_node_internal(name)]
if not self.check_target_node_internal(name)]

@property
def successor_nodes_names_external(self) -> list:
"""Return a list of external successor nodes names."""
return [name for name in self.successor_nodes_names
if not self._check_target_node_internal(name)]
if not self.check_target_node_internal(name)]

+ 21
- 8
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -103,25 +103,38 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
op_name_converter = getattr(converter, GET_OP_NAME)
params_converter = getattr(converter, GET_OP_PARAMS)
weights_converter = getattr(converter, GET_OP_WEIGHTS)
settings_converter = getattr(converter, GET_OP_SETTINGS)
template_generator = getattr(converter, GET_OP_TEMPLATE)
except (ModuleNotFoundError,) as e:
# If mapper can not be found, then skip it.
err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg)
return None, dict(), None, dict()
return None, None, None, None

try:
converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
converted_params = params_converter(params=params, weights=weights)
converted_weights = weights_converter(weights=weights) if weights else dict()
converted_params.update(converted_weights)
converted_settings = settings_converter(params=params, weights=weights)
if "input_shape" in converted_params:
converted_params.pop("input_shape")
if "output_shape" in converted_params:
converted_params.pop("output_shape")
# set to converted_weights to enable weight migration
_ = weights_converter(weights=weights) if weights else dict()
code_template, exchange_msg, outputs_list, outputs_mapping = template_generator(
operation=converter_name,
converted_params=converted_params,
raw_params=params,
weights=weights
)
except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg)
return None, dict(), None, dict()
code_template, exchange_msg, outputs_list, outputs_mapping = template_generator(
operation=op_name,
params=params,
weights=weights
)

return converter_name, converted_params, converted_settings, converted_weights
return code_template, exchange_msg, outputs_list, outputs_mapping

@staticmethod
def _operation_name_in_ms(*args, **kwargs):
@@ -142,7 +155,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
@staticmethod
def _generate_snippet_template(**kwargs):
op = kwargs.get("operation")
args = kwargs.get("converted_params")
args = kwargs.get("converted_params", dict())
weights = kwargs.get("weights")
if not op:
raise ValueError("Can not get MindSpore operation name.")


+ 10
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -464,6 +464,13 @@ class OnnxDataLoader:
else:
self._global_context.onnx_node_inputs[node.name] = [input_node_name]

def _parse_graph(self):
    """Record the names of the ONNX graph's inputs and outputs in the global context for the generator."""
    # Collect both name lists before anything is written to the context.
    collected = {
        'graph_inputs': [tensor.name for tensor in self.graph.input],
        'graph_outputs': [tensor.name for tensor in self.graph.output],
    }
    for key, names in collected.items():
        self._global_context.onnx_graph_info[key] = names

def initialize(self):
"""Initialize the OnnxDataLoader."""

@@ -473,6 +480,9 @@ class OnnxDataLoader:
log.error(str(err))
log.exception(err)

# Parse ONNX Graph level info
self._parse_graph()

# 1. parse all nodes
self._parse_nodes()



+ 1
- 4
tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py View File

@@ -217,8 +217,5 @@ class TestMappers:
def test_mapper(self, params):
"""Test mapper function."""
mapper = ONNXToMindSporeMapper()
converter_name, converted_params, converted_settings, _ = \
_, _, _, _ = \
mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights'])
assert params['expected_output']['converter_name'] == converter_name
assert params['expected_output']['converted_params'] == converted_params
assert isinstance(converted_settings, Setting)

Loading…
Cancel
Save