Browse Source

minor bug fix for args translation

fix problem where incorrect logic for nodestruct's check target nodes internal when target node is graph inputs
update copyright
add a output mgr
adopt the generator with new fragment exchange msg. still need to adopt old ver. code_settings etc.
outputs mgr dev; adapt the new fragment. temp disable extra nodes and weights
adapt the NewFragment; re-imlement the module struct reset method
update ut mapper test for new fragment
tags/v1.2.0-rc1
liangtianshu 4 years ago
parent
commit
02773f4490
11 changed files with 537 additions and 194 deletions
  1. +6
    -1
      mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  2. +5
    -1
      mindinsight/mindconverter/graph_based_converter/common/global_context.py
  3. +200
    -0
      mindinsight/mindconverter/graph_based_converter/common/outputs.py
  4. +16
    -25
      mindinsight/mindconverter/graph_based_converter/generator/__init__.py
  5. +96
    -0
      mindinsight/mindconverter/graph_based_converter/generator/fragment_utils.py
  6. +84
    -27
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  7. +32
    -24
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  8. +66
    -104
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  9. +21
    -8
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  10. +10
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  11. +1
    -4
      tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py

+ 6
- 1
mindinsight/mindconverter/graph_based_converter/common/code_fragment.py View File

@@ -27,8 +27,9 @@ class Fragment(abc.ABC):
Args: Args:
operation (str): Operation name in MindSpore. operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values. actual_args (dict): Actual arg values.
input_shape (tuple): The input shape of the node.
output_shape (tuple): The output shape of the node.
settings (namedTuple): Code generation setting. settings (namedTuple): Code generation setting.

""" """


def __init__(self, operation, actual_args, input_shape, output_shape, settings=None): def __init__(self, operation, actual_args, input_shape, output_shape, settings=None):
@@ -46,6 +47,7 @@ class Fragment(abc.ABC):


@property @property
def code_setting(self): def code_setting(self):
"""Code Setting getter."""
return self._code_setting return self._code_setting


@property @property
@@ -152,10 +154,12 @@ class Fragment(abc.ABC):


@property @property
def input_shape(self): def input_shape(self):
"""Return the input shape."""
return self._input_shape return self._input_shape


@property @property
def output_shape(self): def output_shape(self):
"""Return the output shape."""
return self._output_shape return self._output_shape




@@ -196,6 +200,7 @@ class CodeFragment(Fragment):


@property @property
def trainable_params(self): def trainable_params(self):
"""Return the trainable parameters."""
return self._trainable_params return self._trainable_params






+ 5
- 1
mindinsight/mindconverter/graph_based_converter/common/global_context.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Define GlobalContext class to save required resources during whole conversion procedure.""" """Define GlobalContext class to save required resources during whole conversion procedure."""
from collections import OrderedDict from collections import OrderedDict
from .outputs import OutputStorage




class Singleton(type): class Singleton(type):
@@ -45,6 +46,7 @@ class GlobalContext(metaclass=Singleton):
self.onnx_node_name_to_topo_idx = dict() self.onnx_node_name_to_topo_idx = dict()
self.onnx_node_inputs = dict() self.onnx_node_inputs = dict()
self._onnx_tensors_collection = dict() self._onnx_tensors_collection = dict()
self.onnx_graph_info = dict()


# Define data stored from generator # Define data stored from generator
# Key as Node Identifier # Key as Node Identifier
@@ -72,6 +74,8 @@ class GlobalContext(metaclass=Singleton):
# key is target node (which use this opt), value is opt_var_name # key is target node (which use this opt), value is opt_var_name
self.extra_input_dict = dict() self.extra_input_dict = dict()


self.outputs_storage = OutputStorage()

def get_onnx_node_from_identifier(self, identifier): def get_onnx_node_from_identifier(self, identifier):
"""Return an OnnxUtils defined node by its identifier.""" """Return an OnnxUtils defined node by its identifier."""
onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier) onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier)


+ 200
- 0
mindinsight/mindconverter/graph_based_converter/common/outputs.py View File

@@ -0,0 +1,200 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define basic classes for generator use."""
import abc
import copy
from typing import Union, Iterable

class BaseOutput:
"""
Define the class of output providing a universal nodes' and modules' output data collection.

Args:
output_mapping (tuple[tuple]): The mapping of outputs from onnx to mindspore.
"""
def __init__(self, output_mapping) -> None:
super(BaseOutput).__init__()
self.idx_in_ms_provider = output_mapping[0]
self.idx_in_onnx_provider = output_mapping[1]

# For multi users, key as user and value as index
self.idx_in_ms_user = dict()
self.idx_in_onnx_user = dict()

# The following attributes to be set by who referenced this object.
self.onnx_edge_name = None
self.to_external = False

@property
def ms_user(self):
"""Return the output's user in the MindSpore."""
return self.idx_in_ms_user.keys()

@property
def onnx_user(self):
"""Return the output's user in the ONNX."""
return self.idx_in_onnx_user.keys()

def deepcopy(self):
"""Return a deepcopy of self instance."""
return copy.deepcopy(self)


class BaseOutputManager(abc.ABC):
"""
Base Output Manager class.

Args:
output_mappings (list): A list of output mapping.
"""
def __init__(self, output_mappings):
if isinstance(self.__class__, ModuleOutputManager):
return
self._base_output_list = list()

# init base output obj
for mapping in output_mappings:
obj = BaseOutput(mapping)
self._base_output_list.append(obj)

@property
def outputs(self):
"""Return the list of BaseOutput in this manager."""
return self._base_output_list

@outputs.setter
def outputs(self, val: list):
"""Set the list of BaseOutput in this manager."""
for v in val:
if not isinstance(v, BaseOutput):
raise TypeError(f"{self.__class__} does not accept the type {type(v)} in the list given.")
self._base_output_list = val

@abc.abstractmethod
def deepcopy(self):
"""Return the deepcopy of this instance."""
cls = self.__class__
result = cls.__new__(cls)
result.outputs = list()
for out in self._base_output_list:
result.outputs.append(out.deepcopy())
return result


class NodeOutputManager(BaseOutputManager):
"""
Node Output Manager class.

Args:
identifier (str): The identifier of the node.
output_mappings (list): A list of the output mapping.
"""
def __init__(self, identifier, output_mappings=None) -> None:
super(NodeOutputManager, self).__init__(output_mappings)
self.identifier = identifier

def deepcopy(self):
new_mgr = super().deepcopy()
new_mgr.identifier = self.identifier
return new_mgr


class ModuleOutputManager(BaseOutputManager):
"""
Module Output Manager class.

Args:
identifier (str): The identifier of the module.
output_mappings (list): a list of output mapping
"""
def __init__(self, identifier, base_out: Union[BaseOutput, Iterable[BaseOutput]]) -> None:
super(ModuleOutputManager, self).__init__(None)
self.identifier = identifier
self._return_list_counter = 0
self._base_output_list = list()
if isinstance(base_out, BaseOutput):
self._base_output_list.append(base_out)
else:
self._base_output_list += base_out

@property
def return_num(self):
"""Return the number of outputs to be returned."""
return self._return_list_counter

@return_num.setter
def return_num(self, num: int):
"""Set the number of outputs to be returned."""
self._return_list_counter = num

def deepcopy(self):
"""Return a deepcopy of current instance."""
new_mgr = super().deepcopy()
new_mgr.identifier = self.identifier
new_mgr.return_num = self._return_list_counter
return new_mgr


class OutputStorage:
"""A class saves all outputs."""
def __init__(self):
self._base_output_edge_to_instance = dict()
self._base_output_edge_to_onnx_node_name = dict()
self._base_output_edge_to_ms_identifier = dict()

@property
def outputs_collections(self) -> dict:
"""Return the dict of edge name to output instance."""
return self._base_output_edge_to_instance

def onnx_name(self, output_edge) -> str:
"""Return the dict of edge name to onnx node name."""
return self._base_output_edge_to_onnx_node_name.get(output_edge)

def node_identifier(self, output_edge):
"""Return the dict of edge name to node identifier."""
return self._base_output_edge_to_ms_identifier.get(output_edge)

def add_output(self, out: BaseOutput) -> str:
"""
Add a BaseOutput instance to the storage.

Args:
out (BaseOutput): The BaseOutput instance.
"""
if out.onnx_edge_name:
self._base_output_edge_to_instance[out.onnx_edge_name] = out
else:
raise ValueError("Unable to add a BaseOutput instance with unknown ONNX edge.")

def add_onnx_node_name(self, edge: str, onnx_node_name: str):
"""
Add the onnx node name with the edge name.

Args:
edge (str): The edge name of this output.
onnx_node_name (str): The onnx node which has the edge.
"""
self._base_output_edge_to_onnx_node_name[edge] = onnx_node_name

def add_ms_identifier(self, edge: str, ms_identifier: str):
"""
Add the node identifier with the edge name.

Args:
edge (str): The edge name of this output.
ms_identifier (str): The identifier of the node which has the edge.
"""
self._base_output_edge_to_ms_identifier[edge] = ms_identifier

+ 16
- 25
mindinsight/mindconverter/graph_based_converter/generator/__init__.py View File

@@ -18,9 +18,10 @@ __all__ = ["batch_add_nodes"]
import re import re
import copy import copy


from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords
from .generator import Generator, CodeStruct from .generator import Generator, CodeStruct
from ..common.code_fragment import CodeFragment, NewFragment
from ..common.outputs import NodeOutputManager
from ..constant import ExchangeMessageKeywords




def _tf_model_node_name_reformat(node, node_name): def _tf_model_node_name_reformat(node, node_name):
@@ -56,6 +57,7 @@ def batch_add_nodes(graph_obj, mapper) -> Generator:


""" """
generator_inst = Generator() generator_inst = Generator()
external_inputs = graph_obj.user_provided_input_nodes
for node_name in graph_obj.nodes_in_topological_order: for node_name in graph_obj.nodes_in_topological_order:
node_inst = graph_obj.get_node(node_name) node_inst = graph_obj.get_node(node_name)
node_input = graph_obj.get_input_shape(node_name) node_input = graph_obj.get_input_shape(node_name)
@@ -66,16 +68,11 @@ def batch_add_nodes(graph_obj, mapper) -> Generator:
node_name = node_name_with_scope node_name = node_name_with_scope


node_inst.add_input_and_output_shape(node_input, node_output) node_inst.add_input_and_output_shape(node_input, node_output)
op_name, params, settings, weights = _convert_params(node_inst, mapper)
generator_inst.add_node(
node_name,
node_instance=node_inst,
node_fragment=CodeFragment(op_name, params,
settings,
node_inst.input_shape,
node_inst.output_shape,
weights)
)
code_template, exchange_msg, outputs_lst, outputs_mapping = _convert_params(node_inst, mapper, external_inputs)
outputs_mapping = NodeOutputManager(node_name, output_mappings=outputs_mapping)
fragment = NewFragment(data_entity=exchange_msg, code_template=code_template,
outputs=outputs_lst, outputs_mapping=outputs_mapping)
generator_inst.add_node(node_name, node_instance=node_inst, node_fragment=fragment)
return generator_inst return generator_inst




@@ -105,13 +102,14 @@ def _supply_graph_info(node, external_inputs):
} }




def _convert_params(node, mapper):
def _convert_params(node, mapper, external_inputs):
""" """
Call mapper to convert node's params from ONNX to MindSpore. Call mapper to convert node's params from ONNX to MindSpore.


Args: Args:
node (GraphNode): Our defined GraphNode instance. node (GraphNode): Our defined GraphNode instance.
mapper (Mapper): The mapper instance which indicating conversion method. mapper (Mapper): The mapper instance which indicating conversion method.
external_inputs (list[str]): External inputs provided by users.


Returns: Returns:
tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters, tuple[str, dict, dict, dict], op name in MindSpore, MindSpore parameters,
@@ -121,18 +119,11 @@ def _convert_params(node, mapper):
params.update({"input_shape": node.input_shape, params.update({"input_shape": node.input_shape,
"output_shape": node.output_shape}) "output_shape": node.output_shape})


op_in_ms, ms_params, ms_settings, weights = mapper.convert(op_name=node.op_name,
params=params,
weights=node.weight)
if "input_shape" in ms_params:
ms_params.pop("input_shape")
if "output_shape" in ms_params:
ms_params.pop("output_shape")

if op_in_ms:
return op_in_ms, ms_params, ms_settings, weights

return node.op_name, node.node_params, dict(), dict()
code_template, exchange_msg, outputs_lst, outputs_order_mapping = mapper.convert(op_name=node.op_name,
params=params,
weights=node.weight)
exchange_msg[ExchangeMessageKeywords.METADATA.value] = _supply_graph_info(node, external_inputs)
return code_template, exchange_msg, outputs_lst, outputs_order_mapping




def _combine_external_inputs_with_precursor_nodes(node, external_inputs): def _combine_external_inputs_with_precursor_nodes(node, external_inputs):


+ 96
- 0
mindinsight/mindconverter/graph_based_converter/generator/fragment_utils.py View File

@@ -0,0 +1,96 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellaneous Fragment related classes and functions. """

from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment


class FragmentHandler:
"""
Define a handler to process the infomation contained by Fragment.

Args:
fragment (NewFragment): The refactored fragment class.
"""
def __init__(self, fragment: NewFragment):
self._fragment = fragment
# set the var in the fragment to be load and save.
self._target_var = "var_0"

@property
def target_var(self):
"""Return the target var name the handler currently set to be read."""
return self._target_var

@target_var.setter
def target_var(self, target):
"""Set the target var the handler will read."""
if not target in self.exchange_msg.keys():
raise ValueError(f"Unable to set target var {target} where fragment does not have it.")
self._target_var = target

@property
def fragment(self):
"""Return the fragment instance the handler currently processed."""
return self._fragment

@property
def converted(self):
"""Return the status of the op successfully converted."""
return bool(self._fragment.exchange_msg)

# The following section is intended for Fragment exchange message.
@property
def exchange_msg(self):
"""Return the exchange message dictionary the fragment contains."""
return self._fragment.exchange_msg

@property
def var(self):
"""Return the var dictionary the handler currently set to be processed."""
try:
return self.exchange_msg.get(self.target_var)
except AttributeError:
return None

@property
def default_var(self):
"""Return the default var dictionary the handler processed."""
try:
return self.exchange_msg.get("var_0")
except AttributeError:
return None

# For metadata
@property
def metadata(self):
"""Return the metadata of the onnx node info dictionary."""
return self._fragment.exchange_msg.get("metadata")

@property
def input_shape(self):
"""Return the input shape of this node."""
return self.metadata.get('inputs_shape')

@property
def output_shape(self):
"""Return the output shape of this node."""
return self.metadata.get('outputs_shape')

# For outputs
@property
def outputs_manager(self):
"""Return the outputs manager of this node."""
return self._fragment.outputs_mapping

+ 84
- 27
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -23,6 +23,7 @@ from .node_struct import NodeStruct
from .module_struct import ModuleStruct from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext from ..common.global_context import GlobalContext
from ..common.outputs import BaseOutput, ModuleOutputManager
from ...common.exceptions import GeneratorError from ...common.exceptions import GeneratorError
from ..common.name_mgr import GlobalVarNameMgr from ..common.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
@@ -39,21 +40,10 @@ class CodeStruct:


def __init__(self, struct, repeated_submodules=None): def __init__(self, struct, repeated_submodules=None):
"""Initialize the CodeStruct.""" """Initialize the CodeStruct."""
self.output_order = None # output order
self.input = None # opt_var_name for prev. node
self.extra_input = list() # extra_input(s) at construct method args
self.output = None # opt_var_name for next node
self.extra_output = list() # extra_output(s)
self.extra_comment = None # comments for this code line / block.
self.code_line_list = list() # list of code line, a item is a line. self.code_line_list = list() # list of code line, a item is a line.
self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module self._global_var_mgr = GlobalVarNameMgr() # var name procs within same module


self.formal_args_collections = None

if isinstance(struct, NodeStruct):
self.output_order = struct.topo_idx
if isinstance(struct, ModuleStruct): if isinstance(struct, ModuleStruct):
self.output_order = struct.head_nd_struct_index
self._generate_from_module_struct(struct, repeated_submodules) self._generate_from_module_struct(struct, repeated_submodules)


def _add_line(self, s): def _add_line(self, s):
@@ -102,13 +92,15 @@ class CodeStruct:
cons_lines = list() cons_lines = list()
for (_, struct) in md_struct.get_generate_order(): for (_, struct) in md_struct.get_generate_order():
if isinstance(struct, NodeStruct): # Generate code line for Node. if isinstance(struct, NodeStruct): # Generate code line for Node.
code_line_init = struct.code_line_in_init()
code_line_construct = struct.code_line_in_construct()
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_init)}")
cons_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(code_line_construct)}")
# add extra tensor
if struct.fragment.code_setting and struct.fragment.code_setting.op_extra_tensor:
init_lines.append(f"{SECOND_LEVEL_INDENT}{' = '.join(struct.add_extra_tensor())}")
_ = struct.code_line_in_init()
_ = struct.code_line_in_construct()

init_str, cons_str = struct.fragment.fragment()
init_str = [f"{SECOND_LEVEL_INDENT}{x}" for x in init_str]
cons_str = [f"{SECOND_LEVEL_INDENT}{x}" for x in cons_str]
code_line_construct = cons_str
init_lines += init_str
cons_lines += cons_str


elif isinstance(struct, ModuleStruct): elif isinstance(struct, ModuleStruct):
# check if this instance generated CodeStruct # check if this instance generated CodeStruct
@@ -145,7 +137,8 @@ class CodeStruct:
returns.append(r) returns.append(r)
returns = list(set(returns)) returns = list(set(returns))
else: else:
returns = [code_line_construct[0]]
returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \
else [code_line_construct[-1].replace(' ', '').split('=')[0]]
self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}"
self.new_line = f"{NEW_LINE * 2}" self.new_line = f"{NEW_LINE * 2}"
self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self
@@ -156,9 +149,6 @@ class Generator:


def __init__(self): def __init__(self):
"""Init the generator.""" """Init the generator."""
# define basic attributes
self.framework = None

# define MUST have params # define MUST have params
self._node_struct_collections = OrderedDict() self._node_struct_collections = OrderedDict()
self._module_struct_collections = OrderedDict() self._module_struct_collections = OrderedDict()
@@ -244,9 +234,9 @@ class Generator:
if len(nd_struct_list) < 2: if len(nd_struct_list) < 2:
return formal_args return formal_args
(_, base_nd_struct) = nd_struct_list[0] (_, base_nd_struct) = nd_struct_list[0]
for (base_parameter, base_value) in base_nd_struct.fragment.actual_args.items(): # for each param
for (base_parameter, base_value) in base_nd_struct.fragment.default_var["args"].items(): # for each param
for (_, nd_struct) in nd_struct_list[1:]: for (_, nd_struct) in nd_struct_list[1:]:
compared_value = nd_struct.fragment.actual_args.get(base_parameter)
compared_value = nd_struct.fragment.default_var["args"].get(base_parameter)
if compared_value == base_value: if compared_value == base_value:
continue continue
formal_args.add(base_parameter) formal_args.add(base_parameter)
@@ -340,9 +330,9 @@ class Generator:
self.module_structs[str(parent_scope)] = parent_md_struct self.module_structs[str(parent_scope)] = parent_md_struct
else: else:
# 1B. not has parent, generate a new ModuleStruct # 1B. not has parent, generate a new ModuleStruct
parent_md_struct = copy.deepcopy(md_struct) # use this submodule to create a parent module
# use this submodule to create a parent module
parent_md_struct = ModuleStruct(None, init_as_parent=True, parent_base=md_struct)
# rewrite parent md struct # rewrite parent md struct
parent_md_struct.reset_as_parent()
parent_md_struct.add_submodule(md_struct) parent_md_struct.add_submodule(md_struct)
self.module_structs[str(parent_scope)] = parent_md_struct self.module_structs[str(parent_scope)] = parent_md_struct
sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections sub = self.module_structs.pop(scope_path_str) # remove this submodule from collections
@@ -378,6 +368,7 @@ class Generator:
self._update_all_modules_args_translator() self._update_all_modules_args_translator()


# 6. Update all nodes and moudles input/output # 6. Update all nodes and moudles input/output
# Enable build_output_connections later.
self.module_structs.get('[]').allocate_construct_header_x() self.module_structs.get('[]').allocate_construct_header_x()
self.module_structs.get('[]').collect_returns() self.module_structs.get('[]').collect_returns()


@@ -488,7 +479,7 @@ class Generator:


for code_struct in self._global_context.code_structs.values(): for code_struct in self._global_context.code_structs.values():
for line in code_struct.code_line_list: for line in code_struct.code_line_list:
outputs.append(line.replace("onnx::", ""))
outputs.append(line)


formatted_code, _ = FormatCode("\n".join(outputs), formatted_code, _ = FormatCode("\n".join(outputs),
style_config=CodeFormatConfig.PEP8.value) style_config=CodeFormatConfig.PEP8.value)
@@ -575,3 +566,69 @@ class Generator:
if m_num == module_num: if m_num == module_num:
ret.append(nd_struct_list) ret.append(nd_struct_list)
return ret return ret

def build_outputs_connection(self):
"""Build all nodes and modules outputs connections."""
for nd_struct in self.node_structs.values():
# for each output in curr node output manager
for out in nd_struct.outputs_manager.outputs:
# Set the onnx output edge name to this output
out.onnx_edge_name = nd_struct.fragment.metadata.get('outputs')[out.idx_in_onnx_provider]
self._global_context.outputs_storage.add_output(out)
self._global_context.outputs_storage.add_onnx_node_name(out.onnx_edge_name,
nd_struct.fragment.metadata.get('source'))
self._global_context.outputs_storage.add_ms_identifier(out.onnx_edge_name, nd_struct.identifier)

# Set input with existing output mapping
for idx, inp in enumerate(nd_struct.fragment.metadata.get('inputs')):
if inp in self._global_context.outputs_storage.outputs_collections:
output_obj = self._global_context.outputs_storage.outputs_collections[inp]
output_obj.idx_in_onnx_user[nd_struct.onnx_name] = idx

# set ms_user idx, need to modify if not follow onnx order
output_obj.idx_in_ms_user[nd_struct.identifier] = idx

# set this output to be returned to external
output_obj.to_external = not(nd_struct.check_target_node_internal(
self._global_context.outputs_storage.onnx_name(inp)
))

# collect submodule's and nodes' outputs mgr
self._collect_output_mgr()

def _collect_output_mgr(self, module=None):
"""
Collect the outputs manager from nodes and submodules the current module has.

Args:
module (ModuleStruct): The module struct collecting its nodes and submodules.
"""
root_module = module or self.get_module_struct('[]')
output_mgr_list = list()
for struct in root_module.get_generate_order():
if isinstance(struct, tuple):
# index 1 is the NodeStruct while 0 is topological index.
struct = struct[1]
if isinstance(struct, ModuleStruct) and struct.outputs_manager is None:
self._collect_output_mgr(module=struct)
for out in struct.outputs_manager.outputs:
if Generator.check_output_need_to_external(root_module, out):
output_mgr_list.append(out)
root_module.outputs_manager = ModuleOutputManager(root_module.identifier, base_out=output_mgr_list)

@staticmethod
def check_output_need_to_external(root_module: ModuleStruct, checked_output: BaseOutput):
"""
Check the output still need to be returned to module external.

Args:
root_module (ModuleStruct): The Module that the output to be determined.
checked_output (BaseOutput): The output to be checked whether returned by the Module.

Returns:
bool, True if the output need to be returned to the module external.
"""
for user in checked_output.onnx_user:
if user in root_module.external_successor_nodes_names:
return True
return False

+ 32
- 24
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -14,6 +14,7 @@
# ============================================================================== # ==============================================================================
"""Define a struct for module converted and save all required information here.""" """Define a struct for module converted and save all required information here."""


import copy
from collections import OrderedDict from collections import OrderedDict


from .node_struct import NodeStruct from .node_struct import NodeStruct
@@ -31,10 +32,12 @@ class ModuleStruct:


Args: Args:
args (list): A list of node structs. args (list): A list of node structs.
init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct.
parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as.
""" """
GLOBAL_CONTEXT_MGR = GlobalContext() GLOBAL_CONTEXT_MGR = GlobalContext()


def __init__(self, nd_struct_list):
def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None):
"""Init. a module by NodeStructs.""" """Init. a module by NodeStructs."""
self.pattern_id = -1 # pattern num, -1 as Main module self.pattern_id = -1 # pattern num, -1 as Main module
self.pattern_uid = -1 # unique module id for this pattern self.pattern_uid = -1 # unique module id for this pattern
@@ -53,19 +56,15 @@ class ModuleStruct:


self._fragment = None self._fragment = None
self._args_translator = None self._args_translator = None
self._setting = None
self._parent_module_struct = None self._parent_module_struct = None
# only store original formal args name, not global # only store original formal args name, not global
self._nodes_structs_formal_args_list = list() self._nodes_structs_formal_args_list = list()
# only store translated (globalized) formal args
self._nodes_structs_formal_args_translated_list = list()


# define other settings here # define other settings here
self._node_args_translation_list = list() self._node_args_translation_list = list()
self._var_name_mgr = LocalVarNameMgr() self._var_name_mgr = LocalVarNameMgr()
self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name self.construct_header_x = OrderedDict() # key is header x, value is precursors onnx name
self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct self.inputs_in_construct_header = OrderedDict() # key is precursors onnx name, value is x in parent construct
self.inputs_in_parent_module = OrderedDict() # key is prec_node_name, value is its closet opt_var_name


# key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name) # key is node's onnx name(output provider), value is (provider_succ_name, opt_var_name)
self.outputs_collection = dict() self.outputs_collection = dict()
@@ -74,40 +73,49 @@ class ModuleStruct:
# key is ext. succ node onnx name, value is local opt_var # key is ext. succ node onnx name, value is local opt_var
self.external_successor_local_returns_map = OrderedDict() self.external_successor_local_returns_map = OrderedDict()


# key is node's onnx_name, value is (successor_name, opt_var_name) <- node's level
self.outputs_collection = dict()
# Define outputs manager, note this will be assigned later by Generator.
self.outputs_manager = None


# start initialization
if not self.initialized:
self._init_module(nd_struct_list)
if init_as_parent and (parent_base is not None):
self.reset_as_parent_passed_in(parent_base)
else: else:
self._update_module(nd_struct_list)
# start initialization
if not self.initialized:
self._init_module(nd_struct_list)
else:
self._update_module(nd_struct_list)


# assign this module reference to node
for (_, nd_struct) in nd_struct_list:
nd_struct.parent_module_struct = self
# assign this module reference to node
for (_, nd_struct) in nd_struct_list:
nd_struct.parent_module_struct = self


def reset_as_parent(self):
def reset_as_parent_passed_in(self, parent_base):
""" """
Reset all attributes and filled as a parent module of this module.
Reset all attributes and filled as a parent module of the module passed in.

Args:
parent_base(ModuleStruct): The base ModuleStruct to be passed in for ModuleStruct init.


Note: Note:
This function must be called only after a deepcopy of this instance!
This function must be called only if the new ModuleStruct is a parent of parent_base.
""" """
self.identifier.pop()
self.scope_depth = self.scope_depth - 1
self._set_pattern_id()
self._find_parent_module()
self.identifier = copy.deepcopy(parent_base.identifier)[:-1]
self.scope_depth = copy.deepcopy(parent_base.scope_depth) - 1
self.module_name = Scope.scope_to_module_name(self.identifier) self.module_name = Scope.scope_to_module_name(self.identifier)
self.head_nd_struct = parent_base.head_nd_struct
self.head_nd_struct_index = parent_base.head_nd_struct_index
self.tail_nd_struct = parent_base.tail_nd_struct
self.tail_nd_struct_index = parent_base.tail_nd_struct_index
self._node_structs = list() self._node_structs = list()
self._module_structs = list() self._module_structs = list()
self._fragment = None self._fragment = None
self._args_translator = None self._args_translator = None
self.initialized = True
self._set_pattern_id()
self._find_parent_module()
self.init_args_translator() self.init_args_translator()
self._setting = None
self._parent_module_struct = None self._parent_module_struct = None
self._nodes_structs_formal_args_list = list() self._nodes_structs_formal_args_list = list()

self._node_args_translation_list = list() self._node_args_translation_list = list()


def _set_pattern_id(self): def _set_pattern_id(self):
@@ -435,7 +443,7 @@ class ModuleStruct:


@property @property
def external_successor_nodes_names(self) -> list: def external_successor_nodes_names(self) -> list:
"""Return all precursors nodes names not in this module."""
"""Return all successor nodes names not in this module."""
ret = [] ret = []
for _, struct in self.get_generate_order(): for _, struct in self.get_generate_order():
if isinstance(struct, NodeStruct): if isinstance(struct, NodeStruct):


+ 66
- 104
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -15,15 +15,14 @@
"""Define the NodeStruct which stores all info. of a node.""" """Define the NodeStruct which stores all info. of a node."""
from collections import OrderedDict from collections import OrderedDict


from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment
from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler
from .scope_utils import Scope from .scope_utils import Scope
from .args_translator import ArgsTranslation from .args_translator import ArgsTranslation
from ..common.code_fragment import CodeFragment
from ..third_party_graph.onnx_graph_node import OnnxGraphNode from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext from ..common.global_context import GlobalContext
from ..constant import InputType
from ...common.exceptions import GeneratorError from ...common.exceptions import GeneratorError



class NodeStruct: class NodeStruct:
""" """
Define a node struct which stores all info. to generate statement. Define a node struct which stores all info. to generate statement.
@@ -44,21 +43,11 @@ class NodeStruct:
self._args_translator = None self._args_translator = None
self._parent_module_struct = None self._parent_module_struct = None
self.topo_idx = None self.topo_idx = None
self.node_type = None
self.onnx_name = None self.onnx_name = None
self.onnx_op = None
self.graph_node_ref = None self.graph_node_ref = None
self.scope_name = None self.scope_name = None
self.ms_var_name = None
self.ms_op = None
self.ready_to_generate = False self.ready_to_generate = False


# Define attributes converted from mapper
self.ms_params = dict()
self.ms_settings = dict()
self.ms_weights = dict()
self.ms_inputs = OrderedDict()

# Defined Scope class # Defined Scope class
self.scope = None self.scope = None


@@ -67,9 +56,6 @@ class NodeStruct:
# key is prec_node_name, value is x; For code line use # key is prec_node_name, value is x; For code line use
self.inputs_in_construct_header = OrderedDict() self.inputs_in_construct_header = OrderedDict()


# key is prec_node_name, value is its closet opt_var_name
self.inputs_in_parent_module = OrderedDict()

# Matched inputs will can be directly used by code line generation # Matched inputs will can be directly used by code line generation
self.matched_inputs = list() self.matched_inputs = list()


@@ -86,7 +72,7 @@ class NodeStruct:


def ori_topo_idx(self): def ori_topo_idx(self):
"""Get the original topological index in the onnx graph.""" """Get the original topological index in the onnx graph."""
ori_name = self.identifier.replace('$', '').split('/')[-1].replace("::", '/')
ori_name = self._fragment.metadata.get('source')
self.onnx_name = ori_name self.onnx_name = ori_name
return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name) return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name)


@@ -97,12 +83,20 @@ class NodeStruct:
Args: Args:
idx (int): The index of the node in this module. idx (int): The index of the node in this module.
""" """
def _remove_op_header(op_name):
"""Remove op header which indicating their sources of op set."""
op_name = op_name.replace('nn.', '')
op_name = op_name.replace('P.', '')
op_name = op_name.replace('onnx.', '')
return op_name

if idx is not None: if idx is not None:
self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(idx)
self.ms_var_name = "{}_{}".format(_remove_op_header(self.ms_op), str(idx)).lower()
elif self.topo_idx is not None: elif self.topo_idx is not None:
self.ms_var_name = self.ms_op.replace('nn.', '').replace('P.', '').lower() + '_' + str(self.topo_idx)
self.ms_var_name = "{}_{}".format(_remove_op_header(self.ms_op), str(self.topo_idx)).lower()
else: else:
raise ValueError("Unable to update var name when topo_idx is None.") raise ValueError("Unable to update var name when topo_idx is None.")
self.fragment.default_var['variable_name'] = self.ms_var_name


def _update_basics_from_gn(self, gn): def _update_basics_from_gn(self, gn):
"""Update basic info from GraphNode.""" """Update basic info from GraphNode."""
@@ -111,25 +105,13 @@ class NodeStruct:


def _update_from_onnx_gn(self, gn: OnnxGraphNode): def _update_from_onnx_gn(self, gn: OnnxGraphNode):
"""Update basic info from OnnxGraphNode.""" """Update basic info from OnnxGraphNode."""
self.node_type = "OnnxGraphNode"
self._update_basics_from_gn(gn) self._update_basics_from_gn(gn)


def _update_from_mapper(self, d):
"""Update info from mapper."""
if d.get('op_name'):
self.ms_op = d.get('op_name')
if d.get('params'):
self.ms_params = d.get('params')
if d.get('settings'):
self.ms_settings = d.get('settings')
if d.get('weights'):
self.ms_weights = d.get('weights')

def _update_from_fragment(self, frag: CodeFragment):
def _update_from_fragment(self, frag: NewFragment):
"""Update info from CodeFragment.""" """Update info from CodeFragment."""
self._fragment = frag
if frag.operation:
self.ms_op = frag.operation
self._fragment = FragmentHandler(frag)

if self.ms_op:
idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count
self.update_var_name(idx=idx) self.update_var_name(idx=idx)


@@ -148,22 +130,13 @@ class NodeStruct:
""" """
if not self._fragment: if not self._fragment:
raise ValueError("Initialize argument translator failed.") raise ValueError("Initialize argument translator failed.")
if self._fragment.actual_args and translated_args:
self._args_translator = ArgsTranslation(self._fragment.actual_args, self.ms_var_name, translated_args)

def check_if_generate_ready(self):
"""Check if the NodeStruct is able to generate code."""
# check essential params exists
if all([self.identifier,
self.node_type,
self.scope_name,
self.ms_var_name,
self.ms_opt_var_name,
self.ms_op]):
self.ready_to_generate = True
if self._fragment.converted and self._fragment.default_var["args"] and translated_args:
self._args_translator = ArgsTranslation(self._fragment.default_var["args"],
self.ms_var_name,
translated_args)


@GeneratorError.check_except("Generator occurs an error when creating node struct.") @GeneratorError.check_except("Generator occurs an error when creating node struct.")
def update(self, arg, force_ready=False):
def update(self, arg):
""" """
Pass Node info. to generator NodeStruct. Pass Node info. to generator NodeStruct.


@@ -174,18 +147,11 @@ class NodeStruct:


if isinstance(arg, OnnxGraphNode): if isinstance(arg, OnnxGraphNode):
self._update_from_onnx_gn(arg) self._update_from_onnx_gn(arg)
elif isinstance(arg, (dict, OrderedDict)):
self._update_from_mapper(arg)
elif isinstance(arg, CodeFragment):
elif isinstance(arg, NewFragment):
self._update_from_fragment(arg) self._update_from_fragment(arg)
else: else:
raise TypeError("NodeStruct received an unsupported initializing argument.") raise TypeError("NodeStruct received an unsupported initializing argument.")


if force_ready:
self.ready_to_generate = True
else:
self.check_if_generate_ready()

@property @property
def identifier(self): def identifier(self):
"""Return the identifier of the node.""" """Return the identifier of the node."""
@@ -234,10 +200,30 @@ class NodeStruct:
"""Return the original onnx node reference.""" """Return the original onnx node reference."""
return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name) return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name)


@property
def ms_op(self):
"""Return the operation name in MindSpore."""
return self._fragment.default_var.get('operation')

@ms_op.setter
def ms_op(self, ms_op_name: str):
"""Set the operation name in MindSpore."""
self._fragment.default_var['operation'] = ms_op_name

@property
def ms_var_name(self):
"""Return the variable name of this Node in the MindSpore script."""
return self._fragment.default_var.get('variable_name')

@ms_var_name.setter
def ms_var_name(self, ms_var_name: str):
"""Set the variable name of this Node in the MindSpore script."""
self._fragment.default_var['variable_name'] = ms_var_name

@property @property
def ms_opt_var_name(self): def ms_opt_var_name(self):
"""Return the output variable name of current node.""" """Return the output variable name of current node."""
return "{}_opt".format(self.ms_var_name).lower()
return self.fragment.fragment.get_outputs_by_idx(0)


@property @property
def args_translator(self): def args_translator(self):
@@ -282,25 +268,32 @@ class NodeStruct:
def parent_module_struct(self, ref): def parent_module_struct(self, ref):
self._parent_module_struct = ref self._parent_module_struct = ref


@property
def outputs_manager(self):
"""Return the outputs manager instance."""
return self.fragment.outputs_manager

@property
def outputs_in_construct(self):
"""Return the outputs var(s) in construct statement."""
return self.fragment.fragment.outputs()

# Code Generation funcs below # Code Generation funcs below


def code_line_in_init(self): def code_line_in_init(self):
"""Initialization line of code in module init block.""" """Initialization line of code in module init block."""
unconverted = False
if "onnx::" in self.ms_var_name:
unconverted = True
self.ms_var_name = self.ms_var_name.replace("onnx::", "")
left = "self.{}".format(self.ms_var_name) left = "self.{}".format(self.ms_var_name)

args_list = list() args_list = list()
if self._args_translator is not None: if self._args_translator is not None:
self.fragment.default_var['args'] = {**self._args_translator.actual_args,
**self._args_translator.formal_args}
args_list += self._args_translator.actual_args_to_str_list args_list += self._args_translator.actual_args_to_str_list
args_list += self._args_translator.formal_args_to_str_list args_list += self._args_translator.formal_args_to_str_list
else: else:
actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.actual_args)
actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.default_var['args'])
args_list += actual_args_str args_list += actual_args_str


if unconverted:
if not self._fragment.converted:
args_list.append('='.join(["input_shape", str(self._fragment.input_shape)])) args_list.append('='.join(["input_shape", str(self._fragment.input_shape)]))
args_list.append('='.join(["output_shape", str(self._fragment.output_shape)])) args_list.append('='.join(["output_shape", str(self._fragment.output_shape)]))
right = f"{self.ms_op.replace('::', '.')}({', '.join(args_list)})" right = f"{self.ms_op.replace('::', '.')}({', '.join(args_list)})"
@@ -308,32 +301,6 @@ class NodeStruct:
right = f"{self.ms_op}({', '.join(args_list)})" right = f"{self.ms_op}({', '.join(args_list)})"
return left, right return left, right


def _get_correct_in_module_returns(self, prec_node, in_module_return):
"""
Find the correct precursor node name in return statement of its parent module.

Args:
prec_node (str): The onnx name of the precursor node given.
in_module_return (list[tuple]): The list of outputs which contains parent module identifier
and module opt_var_name.

Return:
str, correct opt_var_name to be passed in current node.
"""
found_return = False
for ret in in_module_return:
(md_identifier, input_name_to_use) = ret
p_node_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(prec_node)
# recursive check the p node parent
parent = p_node_struct
while not found_return:
parent = parent.parent_module_struct
if parent is None:
break
if parent.identifier == md_identifier:
return input_name_to_use
return None

def code_line_in_construct(self, inputs=None): def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block. """ """Construct line of code in module construct block. """
left = self.ms_opt_var_name left = self.ms_opt_var_name
@@ -356,15 +323,7 @@ class NodeStruct:
if isinstance(inputs, str): if isinstance(inputs, str):
inputs = [inputs] inputs = [inputs]


if self._fragment.code_setting and self._fragment.code_setting.op_ipt_type == InputType.LIST.value:
inputs = [str(tuple(inputs)).replace("\'", "")]

if self._fragment.code_setting and self._fragment.code_setting.op_extra_input:
for _, val in self._fragment.code_setting.op_extra_input.items():
inputs.append(str(val))

if self._fragment.code_setting and self._fragment.code_setting.op_extra_tensor:
inputs.append(f"self.{self.ms_var_name}_w")
self.fragment.default_var['inputs'] = inputs
right = f"self.{self.ms_var_name}({', '.join(inputs)})" right = f"self.{self.ms_var_name}({', '.join(inputs)})"
return left, right return left, right


@@ -394,7 +353,7 @@ class NodeStruct:
{} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier)) {} has duplicate inputs or not.".format(onnx_precursor_node_name, self.identifier))
self.inputs_in_construct_header[onnx_precursor_node_name] = header_x self.inputs_in_construct_header[onnx_precursor_node_name] = header_x


def _check_target_node_internal(self, name: str) -> bool:
def check_target_node_internal(self, name: str) -> bool:
""" """
Check given node under the same scope. Check given node under the same scope.


@@ -406,6 +365,9 @@ class NodeStruct:
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
return False return False


if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')):
return False

if target_nd_struct is None: if target_nd_struct is None:
raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name)) raise ValueError("Unable to find the NodeStruct of given target node {}.".format(name))
return target_nd_struct.scope.path == self.scope.path return target_nd_struct.scope.path == self.scope.path
@@ -414,7 +376,7 @@ class NodeStruct:
def has_successor_node_external(self) -> bool: def has_successor_node_external(self) -> bool:
"""Check if any successor_node is in external module.""" """Check if any successor_node is in external module."""
for name in self.successor_nodes_names: for name in self.successor_nodes_names:
if not self._check_target_node_internal(name):
if not self.check_target_node_internal(name):
return False return False


return True return True
@@ -423,10 +385,10 @@ class NodeStruct:
def precursor_nodes_names_external(self) -> list: def precursor_nodes_names_external(self) -> list:
"""Return a list of external precursor nodes names.""" """Return a list of external precursor nodes names."""
return [name for name in self.precursor_nodes_names return [name for name in self.precursor_nodes_names
if not self._check_target_node_internal(name)]
if not self.check_target_node_internal(name)]


@property @property
def successor_nodes_names_external(self) -> list: def successor_nodes_names_external(self) -> list:
"""Return a list of external successor nodes names.""" """Return a list of external successor nodes names."""
return [name for name in self.successor_nodes_names return [name for name in self.successor_nodes_names
if not self._check_target_node_internal(name)]
if not self.check_target_node_internal(name)]

+ 21
- 8
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -103,25 +103,38 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
op_name_converter = getattr(converter, GET_OP_NAME) op_name_converter = getattr(converter, GET_OP_NAME)
params_converter = getattr(converter, GET_OP_PARAMS) params_converter = getattr(converter, GET_OP_PARAMS)
weights_converter = getattr(converter, GET_OP_WEIGHTS) weights_converter = getattr(converter, GET_OP_WEIGHTS)
settings_converter = getattr(converter, GET_OP_SETTINGS)
template_generator = getattr(converter, GET_OP_TEMPLATE)
except (ModuleNotFoundError,) as e: except (ModuleNotFoundError,) as e:
# If mapper can not be found, then skip it. # If mapper can not be found, then skip it.
err_msg = f"Converting {op_name} failed, see {str(e)}" err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg) log.error(err_msg)
return None, dict(), None, dict()
return None, None, None, None


try: try:
converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
converted_params = params_converter(params=params, weights=weights) converted_params = params_converter(params=params, weights=weights)
converted_weights = weights_converter(weights=weights) if weights else dict()
converted_params.update(converted_weights)
converted_settings = settings_converter(params=params, weights=weights)
if "input_shape" in converted_params:
converted_params.pop("input_shape")
if "output_shape" in converted_params:
converted_params.pop("output_shape")
# set to converted_weights to enable weight migration
_ = weights_converter(weights=weights) if weights else dict()
code_template, exchange_msg, outputs_list, outputs_mapping = template_generator(
operation=converter_name,
converted_params=converted_params,
raw_params=params,
weights=weights
)
except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
err_msg = f"Converting {op_name} failed, see {str(e)}" err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg) log.error(err_msg)
return None, dict(), None, dict()
code_template, exchange_msg, outputs_list, outputs_mapping = template_generator(
operation=op_name,
params=params,
weights=weights
)


return converter_name, converted_params, converted_settings, converted_weights
return code_template, exchange_msg, outputs_list, outputs_mapping


@staticmethod @staticmethod
def _operation_name_in_ms(*args, **kwargs): def _operation_name_in_ms(*args, **kwargs):
@@ -142,7 +155,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
@staticmethod @staticmethod
def _generate_snippet_template(**kwargs): def _generate_snippet_template(**kwargs):
op = kwargs.get("operation") op = kwargs.get("operation")
args = kwargs.get("converted_params")
args = kwargs.get("converted_params", dict())
weights = kwargs.get("weights") weights = kwargs.get("weights")
if not op: if not op:
raise ValueError("Can not get MindSpore operation name.") raise ValueError("Can not get MindSpore operation name.")


+ 10
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -464,6 +464,13 @@ class OnnxDataLoader:
else: else:
self._global_context.onnx_node_inputs[node.name] = [input_node_name] self._global_context.onnx_node_inputs[node.name] = [input_node_name]


def _parse_graph(self):
"""Parse ONNX Graph Info For usage in generator."""
graph_inputs = [inp.name for inp in self.graph.input]
graph_outputs = [out.name for out in self.graph.output]
self._global_context.onnx_graph_info['graph_inputs'] = graph_inputs
self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs

def initialize(self): def initialize(self):
"""Initialize the OnnxDataLoader.""" """Initialize the OnnxDataLoader."""


@@ -473,6 +480,9 @@ class OnnxDataLoader:
log.error(str(err)) log.error(str(err))
log.exception(err) log.exception(err)


# Parse ONNX Graph level info
self._parse_graph()

# 1. parse all nodes # 1. parse all nodes
self._parse_nodes() self._parse_nodes()




+ 1
- 4
tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py View File

@@ -217,8 +217,5 @@ class TestMappers:
def test_mapper(self, params): def test_mapper(self, params):
"""Test mapper function.""" """Test mapper function."""
mapper = ONNXToMindSporeMapper() mapper = ONNXToMindSporeMapper()
converter_name, converted_params, converted_settings, _ = \
_, _, _, _ = \
mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights']) mapper.convert(params['input']['op_name'], params['input']['params'], params['input']['weights'])
assert params['expected_output']['converter_name'] == converter_name
assert params['expected_output']['converted_params'] == converted_params
assert isinstance(converted_settings, Setting)

Loading…
Cancel
Save