
!1217 [MindConverter] Add the outputs manager and shared weights processing

From: @liangtianshu
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot (Gitee), 4 years ago · commit be07a6834b
14 changed files with 738 additions and 130 deletions
  1. +4 -1  mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  2. +4 -1  mindinsight/mindconverter/graph_based_converter/common/global_context.py
  3. +68 -15  mindinsight/mindconverter/graph_based_converter/common/outputs.py
  4. +22 -2  mindinsight/mindconverter/graph_based_converter/generator/__init__.py
  5. +127 -21  mindinsight/mindconverter/graph_based_converter/generator/generator.py
  6. +172 -0  mindinsight/mindconverter/graph_based_converter/generator/matcher.py
  7. +97 -16  mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  8. +97 -41  mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  9. +91 -0  mindinsight/mindconverter/graph_based_converter/generator/shared_weights.py
  10. +15 -3  mindinsight/mindconverter/graph_based_converter/mapper/base.py
  11. +9 -9  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py
  12. +10 -7  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py
  13. +10 -8  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py
  14. +12 -6  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
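
To illustrate the intent of the shared-weights part of this change: a weight reused by several ONNX nodes is now declared once, as a Parameter of the main model, and threaded into each submodule's __init__ as a passthrough argument. Below is a minimal hand-written sketch of that generated pattern (the module names, the MatMul op, and the (16, 16) shape are illustrative assumptions, not actual converter output):

import numpy as np
from mindspore import nn, Parameter, Tensor
from mindspore.ops import operations as P


class Module0(nn.Cell):
    """Hypothetical generated submodule whose weight is passed in from the parent."""

    def __init__(self, passthrough_w_0):
        super(Module0, self).__init__()
        # The rewritten parameter declaration points at the passthrough weight.
        self.matmul_0_w = passthrough_w_0
        self.matmul_0 = P.MatMul()

    def construct(self, x):
        return self.matmul_0(x, self.matmul_0_w)


class Model(nn.Cell):
    """Main (public) module: declares the shared weight once and forwards it to both submodules."""

    def __init__(self):
        super(Model, self).__init__()
        self.passthrough_w_0 = Parameter(
            Tensor(np.random.uniform(0, 1, (16, 16)).astype(np.float32)), name=None)
        self.module0_0 = Module0(passthrough_w_0=self.passthrough_w_0)
        self.module0_1 = Module0(passthrough_w_0=self.passthrough_w_0)

    def construct(self, x):
        opt_module0_0 = self.module0_0(x)
        return self.module0_1(opt_module0_0)

In terms of the diffs below, the submodule's extra __init__ argument comes from SharedWeightHelper.add_shared_weights_in_init_statement, the one-time declaration in the main model comes from public_module_shared_weight_statement_generation, and the passthrough_w_0=self.passthrough_w_0 argument comes from ModuleStruct._code_line_init_statement_shared_weights_args.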

+4 -1  mindinsight/mindconverter/graph_based_converter/common/code_fragment.py

@@ -310,11 +310,14 @@ class NewFragment:
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value])
if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data:
rewrite_params = {
f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value][slot]
f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value].get(slot)
for slot in data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value]
}
rewrite_data.update(rewrite_params)
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value])
template = template.format(**{
k: str(rewrite_data[k]) for k in rewrite_data
})
return template.format(**{
k: str(rewrite_data[k]) for k in rewrite_data
})

+4 -1  mindinsight/mindconverter/graph_based_converter/common/global_context.py

@@ -83,6 +83,9 @@ class GlobalContext(metaclass=Singleton):

# Record weights name that used many times.
self.repeated_weights = dict()
self.repeated_weights_declaration = dict()
# Define Module Struct Build Status
self.build_struct_finished = False

def get_onnx_node_from_identifier(self, identifier):
"""Return an OnnxUtils defined node by its identifier."""
@@ -144,7 +147,7 @@ class GlobalContext(metaclass=Singleton):
@property
def onnx_tensors_collection(self):
"""Return the onnx tensors collection."""
return self.onnx_tensors_collection
return self._onnx_tensors_collection

@onnx_tensors_collection.setter
def onnx_tensors_collection(self, arg):


+68 -15  mindinsight/mindconverter/graph_based_converter/common/outputs.py

@@ -17,6 +17,9 @@ import abc
import copy
from typing import Union, Iterable

from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment


class BaseOutput:
"""
Define the class of output providing a universal nodes' and modules' output data collection.
@@ -36,6 +39,9 @@ class BaseOutput:
# The following attributes to be set by who referenced this object.
self.onnx_edge_name = None
self.to_external = False
self.opt_var_name = None
# Only used for module output edges: the edge's return name inside its own module
self.inner_ret_name = None

@property
def ms_user(self):
@@ -59,28 +65,41 @@ class BaseOutputManager(abc.ABC):
Args:
output_mappings (list): A list of output mapping.
"""
def __init__(self, output_mappings):
if isinstance(self.__class__, ModuleOutputManager):
def __init__(self, identifier, output_mappings: Iterable):
if isinstance(self, ModuleOutputManager):
return
self._base_output_list = list()
self._base_output_dict = dict()
self.identifier = identifier

# init base output obj
for mapping in output_mappings:
for (onnx_edge_name, mapping) in output_mappings:
obj = BaseOutput(mapping)
self._base_output_list.append(obj)
self._base_output_dict[onnx_edge_name] = obj
obj.onnx_edge_name = onnx_edge_name

@property
def outputs(self):
"""Return the list of BaseOutput in this manager."""
return self._base_output_list
return self._base_output_dict.values()

@property
def outputs_edges(self):
"""Return the list of outputs edge names in this manager."""
return self._base_output_dict.keys()

@outputs.setter
def outputs(self, val: list):
"""Set the list of BaseOutput in this manager."""
tmp = dict()
for v in val:
if not isinstance(v, BaseOutput):
raise TypeError(f"{self.__class__} does not accept the type {type(v)} in the list given.")
self._base_output_list = val
tmp[v.onnx_edge_name] = v
self._base_output_dict = tmp

def get_base_out(self, onnx_edge_name: str) -> BaseOutput:
"""Return the BaseOut by key."""
return self._base_output_dict.get(onnx_edge_name)

@abc.abstractmethod
def deepcopy(self):
@@ -88,7 +107,7 @@ class BaseOutputManager(abc.ABC):
cls = self.__class__
result = cls.__new__(cls)
result.outputs = list()
for out in self._base_output_list:
for out in self._base_output_dict.values():
result.outputs.append(out.deepcopy())
return result

@@ -102,14 +121,19 @@ class NodeOutputManager(BaseOutputManager):
output_mappings (list): A list of the output mapping.
"""
def __init__(self, identifier, output_mappings=None) -> None:
super(NodeOutputManager, self).__init__(output_mappings)
self.identifier = identifier
super(NodeOutputManager, self).__init__(identifier, output_mappings)

def deepcopy(self):
"""Self defined deepcopy method."""
new_mgr = super().deepcopy()
new_mgr.identifier = self.identifier
return new_mgr

def bind_opt_var_names(self, fragment: NewFragment):
"""Get the opt_var_name in return statement."""
for base_out in self._base_output_dict.values():
base_out.opt_var_name = fragment.get_outputs_by_idx(base_out.idx_in_ms_provider)


class ModuleOutputManager(BaseOutputManager):
"""
@@ -120,14 +144,13 @@ class ModuleOutputManager(BaseOutputManager):
output_mappings (list): a list of output mapping
"""
def __init__(self, identifier, base_out: Union[BaseOutput, Iterable[BaseOutput]]) -> None:
super(ModuleOutputManager, self).__init__(None)
self.identifier = identifier
super(ModuleOutputManager, self).__init__(identifier, None)
self._return_list_counter = 0
self._base_output_list = list()
self._base_output_dict = dict()
if isinstance(base_out, BaseOutput):
self._base_output_list.append(base_out)
self.outputs = [base_out]
else:
self._base_output_list += base_out
self.outputs = base_out

@property
def return_num(self):
@@ -139,6 +162,12 @@ class ModuleOutputManager(BaseOutputManager):
"""Set the number of outputs to be returned."""
self._return_list_counter = num

def assign_opt_var_name_to_each_output(self, opt_var_name_base: str):
"""Assign opt_var_name for each output."""
for idx, base_out in enumerate(self._base_output_dict.values()):
postfix = str(idx) if idx > 0 else ""
base_out.opt_var_name = '_'.join([opt_var_name_base, postfix]) if idx > 0 else opt_var_name_base

def deepcopy(self):
"""Return a deepcopy of current instance."""
new_mgr = super().deepcopy()
@@ -146,6 +175,30 @@ class ModuleOutputManager(BaseOutputManager):
new_mgr.return_num = self._return_list_counter
return new_mgr

def bind_module_outputs_internal_name(self, outputs_register: dict):
"""
Bind each output's internal return name using the module outputs register.

Args:
outputs_register (dict): The module outputs register mapping onnx edge names to internal names, registered by submodules and nodes.
"""
for base_out in self._base_output_dict.values():
# bind the edge name inside module
base_out.inner_ret_name = outputs_register.get(base_out.onnx_edge_name)

def bind_opt_var_name(self, opt_var_names: list):
"""
Assign the opt_var_name for outputs of this module.

Args:
opt_var_names (list): A list of opt_var_name of this module, generated by module itself.
"""
if len(opt_var_names) != len(self._base_output_dict.values()):
raise ValueError(f"Unable to bind opt_var_name: Module {self.identifier}" \
f" has an inconsistent number of outputs.")
for idx, base_out in enumerate(self._base_output_dict.values()):
base_out.opt_var_name = opt_var_names[idx]


class OutputStorage:
"""A class that saves all outputs."""


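As a quick reference for the reworked output managers above, a small usage sketch (the identifier and edge name are made up; the (0, 0) mapping tuple follows the outputs_mapping convention returned by the mappers):

from mindinsight.mindconverter.graph_based_converter.common.outputs import NodeOutputManager

# BaseOutput objects are now keyed by their onnx edge name instead of being kept in a plain list.
edges = ["conv_0_output"]        # hypothetical onnx output edge name
mappings = ((0, 0),)             # per-output mapping tuple, as produced by the mappers
manager = NodeOutputManager("Model/Conv_0", output_mappings=zip(edges, mappings))

base_out = manager.get_base_out("conv_0_output")    # direct lookup by edge name
assert base_out.onnx_edge_name == "conv_0_output"
assert "conv_0_output" in manager.outputs_edges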
+22 -2  mindinsight/mindconverter/graph_based_converter/generator/__init__.py

@@ -18,10 +18,10 @@ __all__ = ["batch_add_nodes"]
import re
import copy

from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator, CodeStruct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment
from mindinsight.mindconverter.graph_based_converter.common.outputs import NodeOutputManager
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords
from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator


def _tf_model_node_name_reformat(node, node_name):
@@ -123,6 +123,7 @@ def _convert_params(node, mapper, external_inputs):
params=params,
weights=node.weight)
exchange_msg[ExchangeMessageKeywords.METADATA.value] = _supply_graph_info(node, external_inputs)
outputs_order_mapping = _bind_outputs_edges(exchange_msg=exchange_msg, outputs_order_mapping=outputs_order_mapping)
return code_template, exchange_msg, outputs_lst, outputs_order_mapping


@@ -145,3 +146,22 @@ def _combine_external_inputs_with_precursor_nodes(node, external_inputs):
node_idx = node.ir_node_inputs.index(item)
precursor.insert(node_idx, item)
return precursor

def _bind_outputs_edges(exchange_msg, outputs_order_mapping):
"""
Bind the outputs edges names with the outputs order mapping.

Args:
exchange_msg (dict): The dict of exchange messages of this node.
outputs_order_mapping (tuple): The outputs mapping of this node.

Returns:
zip, a zip object pairing each output edge name with its mapping.
"""
outputs_edges = exchange_msg.get('metadata').get('outputs')
if not outputs_edges:
raise ValueError(f"ONNX Node {exchange_msg.get('metadata').get('source')} has no outputs info.")
if len(outputs_edges) != len(outputs_order_mapping):
raise ValueError(f"ONNX Node {exchange_msg.get('metadata').get('source')} has an inconsistent " \
f"number of output edges and mappings.")
return zip(outputs_edges, outputs_order_mapping)

+127 -21  mindinsight/mindconverter/graph_based_converter/generator/generator.py

@@ -17,6 +17,7 @@ import copy
from collections import OrderedDict
from importlib import import_module

import numpy as np
from yapf.yapflib.yapf_api import FormatCode

from mindinsight.mindconverter.common.exceptions import GeneratorError
@@ -32,6 +33,8 @@ from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, S
FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID, WeightType, LINK_IN_WEIGHT_NAME
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list
from mindinsight.mindconverter.graph_based_converter.generator.matcher import MatcherLauncher
from mindinsight.mindconverter.graph_based_converter.generator.shared_weights import SharedWeightHelper


class CodeStruct:
@@ -90,6 +93,10 @@ class CodeStruct:
for formal in md_struct.args_translator.formal_args.keys():
module_def_args.append(formal)

# Set passthrough weights for shared weights; not needed for the main model
if md_struct.identifier != []:
module_def_args = SharedWeightHelper.add_shared_weights_in_init_statement(md_struct, module_def_args)

# For code line in init & construct blocks
init_lines = list()
cons_lines = list()
@@ -105,7 +112,7 @@ class CodeStruct:
init_lines += init_str
cons_lines += cons_str

else: # is ModuleStruct
else: # is ModuleStruct
# check if this instance generated CodeStruct
if GlobalContext().code_structs.get(struct.pattern_id) is None:
CodeStruct(struct, repeated_submodules)
@@ -118,6 +125,13 @@ class CodeStruct:
# define header of init block
self.new_line = f"{FIRST_LEVEL_INDENT}def __init__({', '.join(module_def_args)}):"
self.new_line = f"{SECOND_LEVEL_INDENT}super({class_name}, self).__init__()"

# Add shared weights declaration to the init code block
if md_struct.identifier == []:
passthrough_w_declaration = SharedWeightHelper.public_module_shared_weight_statement_generation(md_struct)
for s in passthrough_w_declaration:
self.new_line = f"{SECOND_LEVEL_INDENT}{s}"

# add init code lines to code line list.
self.code_line_list += init_lines
self.new_line = f"{NEW_LINE * 2}"
@@ -129,16 +143,14 @@ class CodeStruct:
self.code_line_list += cons_lines
# define returns
returns = []
if md_struct.external_successor_local_returns_map:
for r in list(md_struct.external_successor_local_returns_map.values()):
if isinstance(r, tuple): # results return with index nth output
returns.append(r[0])
else:
returns.append(r)
returns = list(set(returns))
else:
returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \
else [code_line_construct[-1].replace(' ', '').split('=')[0]]

# Put each output edge's opt_var_name into the return list
for output_edge in md_struct.outputs_register.keys():
opt_var_name = md_struct.internal_outputs_collection.get(output_edge)
if opt_var_name is None:
raise ValueError(f"Module {md_struct.identifier} has an output {output_edge} with an unknown opt_var_name.")
returns.append(opt_var_name)

self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}"
self.new_line = f"{NEW_LINE * 2}"
GlobalContext().code_structs[md_struct.pattern_id] = self
@@ -244,6 +256,83 @@ class Generator:

return formal_args

@staticmethod
def _set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list):
"""Set the weight with given param postfix to args translation."""
for _, nd_struct in nd_struct_list:
nparr = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix).get('data')
nd_struct.fragment.default_var["args"][f"{t_param_postfix}_shape"] = nparr.shape
nd_struct.fragment.default_var["args"][f"{t_param_postfix}_dtype"] = nparr.dtype
init_tensor_template = f"Parameter(Tensor(np.random.uniform(0, 1, "\
f"{{{t_param_postfix}_shape}}).astype(np.{{{t_param_postfix}_dtype}})), "\
f"name=None)"
nd_struct.fragment.default_var["parameters"][t_param_postfix] = init_tensor_template

def _get_same_trainable_params_onnx_name_from_repeated_nodes(self,
t_param_postfix,
t_param_data_dict,
nd_struct_list: list):
"""Return all onnx names from the same weights in repeated nodes."""
(_, base_nd_struct) = nd_struct_list[0]
t_base_name = t_param_data_dict.get('onnx_name')
t_onnx_names = [t_base_name]
for (_, nd_struct) in nd_struct_list[1:]:
compared_t_param_data_dict = nd_struct.fragment.default_var["trainable_params"].get(t_param_postfix)
if not compared_t_param_data_dict:
raise ValueError(f"Inconsistent trainable params detected for node "\
f"{nd_struct.topo_idx} with base node {base_nd_struct.topo_idx}")
compared_t_name = compared_t_param_data_dict.get('onnx_name')
t_onnx_names.append(compared_t_name)
return t_onnx_names

def _partial_shared_weights_in_repeated_submodule_procs(self, nd_struct_list):
"""
Check each node in repeated submodules to determine whether it has fully or partially shared weights.

Args:
nd_struct_list (list): A list of node structs which are same node in repeated modules.
"""
# Nodes that are not repeated skip this processing
if len(nd_struct_list) < 2:
return
(_, base_nd_struct) = nd_struct_list[0]
shared_w_list = self._global_context.repeated_weights.keys()
if not shared_w_list:
if base_nd_struct.fragment.default_var.get("parameters"):
# Set only if the node has parameters, since rewriting is required
for (t_param_postfix, t_param_data_dict) in \
base_nd_struct.fragment.default_var["trainable_params"].items():
if not isinstance(t_param_data_dict.get('data'), np.ndarray):
continue
Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list)
return

for (t_param_postfix, t_param_data_dict) in base_nd_struct.fragment.default_var["trainable_params"].items():
# check each weight if partial shared or fully shared weight
if not t_param_data_dict:
continue
t_onnx_names = self._get_same_trainable_params_onnx_name_from_repeated_nodes(t_param_postfix,
t_param_data_dict,
nd_struct_list)
t_shared_status = [name in shared_w_list for name in t_onnx_names]
if True in t_shared_status and False in t_shared_status:
# Partially shared: register the unshared occurrences as shared in GlobalContext
for idx, (name, status) in enumerate(zip(t_onnx_names, t_shared_status)):
if status:
# actual shared, do nothing, skip
continue
node_onnx_name = nd_struct_list[idx][1].onnx_name
if not self._global_context.repeated_weights.get(name):
self._global_context.repeated_weights[name] = [node_onnx_name]
else:
self._global_context.repeated_weights[name] += [node_onnx_name]
if True not in t_shared_status and base_nd_struct.fragment.default_var.get("parameters"):
# If the repeated node has no shared weight and the mapper accepts parameter rewriting
if not isinstance(t_param_data_dict.get('data'), np.ndarray):
continue
Generator._set_translated_args_for_unshared_weights(t_param_postfix, nd_struct_list)


def _list_formal_parameters_in_a_module(self, module_filter_return):
"""
Find all formal args / params from nodes in a module.
@@ -256,6 +345,9 @@ class Generator:
"""
formal_params_list = list()
transposed = [list(e) for e in zip(*module_filter_return)]
for operation in transposed:
# use the map filtered result for partial shared weights procs
self._partial_shared_weights_in_repeated_submodule_procs(operation)
for operation in transposed:
formal_parameters = self._compare_with_base_parameters(operation)
if formal_parameters:
@@ -363,19 +455,34 @@ class Generator:
md_collection_len = new_len
else:
len_changes = False
GlobalContext().build_struct_finished = True
# 5. Update all translated args from module map
self._update_all_modules_args_translator()

# 6. Update all nodes' and modules' inputs/outputs
# Enable build_output_connections later.
self.build_outputs_connection()
self.module_structs.get('[]').allocate_construct_header_x()
self.module_structs.get('[]').collect_returns()

matcher = MatcherLauncher(self.module_structs.get('[]'))
matcher.matching_process()

for nd_struct in self.node_structs.values():
if nd_struct.fragment.metadata.get("operation") == "Split":
self._split_op_procs(nd_struct)

def _shared_weights_processing(self):
"""Process shared weights."""
# Check whether each node has shared weights
for nd_struct in self.node_structs.values():
shared_weights = SharedWeightHelper.check_node_has_shared_weight(nd_struct)
if shared_weights:
# register each shared weight to public module
for shared_w in shared_weights:
SharedWeightHelper.register_shared_weight_to_public_parent(nd_struct,
shared_w,
pub_module_identifier=[])

def _update_all_modules_args_translator(self):
"""Update all modules' args translators."""
done_submodule = set()
@@ -426,7 +533,7 @@ class Generator:
else:
raise TypeError("Unable to update global depth due to TypeError in NodeStruct.scope.depth")

def add_node(self, node_identifier, node_instance=None, node_fragment=None, mapper_dict=None):
def add_node(self, node_identifier, node_instance=None, node_fragment=None):
"""
Add Node information to the generator.

@@ -434,7 +541,6 @@ class Generator:
node_identifier (str): The unique identifier for the node passed in.
node_instance (GraphNode): The GraphNode instance of each node.
node_fragment (NodeFragment): The NodeFragment instance of this node passed in.
mapper_dict (dict): The dict contains converted attributes from mapper.
"""

if node_identifier is None:
@@ -443,8 +549,6 @@ class Generator:
args = []
if node_instance is not None:
args.append(node_instance)
if mapper_dict is not None:
args.append(mapper_dict)
if node_fragment is not None:
args.append(node_fragment)

@@ -551,6 +655,7 @@ class Generator:
"""
self._form_bottom_submodule()
self._recursive_form_module()
self._shared_weights_processing()

ckpt_data_list, weight_map = self.generate_checkpoint()

@@ -654,14 +759,13 @@ class Generator:
# for each output in curr node output manager
for out in nd_struct.outputs_manager.outputs:
# Set the onnx output edge name to this output
out.onnx_edge_name = nd_struct.fragment.metadata.get('outputs')[out.idx_in_onnx_provider]
self._global_context.outputs_storage.add_output(out)
self._global_context.outputs_storage.add_onnx_node_name(out.onnx_edge_name,
nd_struct.fragment.metadata.get('source'))
self._global_context.outputs_storage.add_ms_identifier(out.onnx_edge_name, nd_struct.identifier)

# Set input with existing output mapping
for idx, inp in enumerate(nd_struct.fragment.metadata.get('inputs')):
for idx, inp in enumerate(nd_struct.inputs_edges_names):
if inp in self._global_context.outputs_storage.outputs_collections:
output_obj = self._global_context.outputs_storage.outputs_collections[inp]
output_obj.idx_in_onnx_user[nd_struct.onnx_name] = idx
@@ -694,8 +798,10 @@ class Generator:
self._collect_output_mgr(module=struct)
for out in struct.outputs_manager.outputs:
if Generator.check_output_need_to_external(root_module, out):
output_mgr_list.append(out)
root_module.outputs_manager = ModuleOutputManager(root_module.identifier, base_out=output_mgr_list)
output_mgr_list.append(out.deepcopy())
root_module.outputs_manager = ModuleOutputManager(root_module.identifier,
base_out=output_mgr_list)
root_module.outputs_manager.assign_opt_var_name_to_each_output(root_module.ms_opt_var_name)

@staticmethod
def check_output_need_to_external(root_module: ModuleStruct, checked_output: BaseOutput):


+172 -0  mindinsight/mindconverter/graph_based_converter/generator/matcher.py

@@ -0,0 +1,172 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Processing the node's and modules' inputs and outputs matching."""
from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct
from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext

class MatcherHelper:
"""
Helper function for matching processing.
"""
@staticmethod
def main_model_special_process_inputs(main_model: ModuleStruct):
"""Called during pre-processing to allocate the main model's construct inputs (x, x0, ...)."""
# allocate main model construct x
prec_edges = main_model.external_precursor_nodes_names
default_x_str = "x"
inputs = dict()
for idx, edge in enumerate(prec_edges):
if not edge in inputs:
# Use idx-1 because the first input is named plain "x" and indexed names start from "x0",
# so "x0" refers to the second input, not the first.
inputs[edge] = "".join([default_x_str, str(idx-1)]) if idx > 0 else default_x_str
main_model.inputs_register = inputs

@staticmethod
def get_public_parent_module(node_a: NodeStruct, node_b: NodeStruct):
"""Return the public parent module of both Node A and Node B."""
find = False
b_onnx_name = node_b.onnx_name
tmp = node_a
while not find:
parent_struct = tmp.parent_module_struct
if b_onnx_name in parent_struct.onnx_names:
find = True
tmp = parent_struct
return tmp

@staticmethod
def get_submodule_has_out_user_under_public_parent(public_module: ModuleStruct, node_out_user: NodeStruct):
"""Return the ModuleStruct under the public module which contains the provided NodeStruct."""
for module_struct in public_module.module_structs:
if node_out_user.onnx_name in module_struct.onnx_names:
return module_struct
return None

@staticmethod
def register_outputs_to_main_model(output_edge_name: str, output_edge_provider: NodeStruct):
"""
Register the output edge to the main model and through all modules.

Args:
output_edge_name (str): The name of this output edge.
output_edge_provider (NodeStruct): The node which produces this output.
"""
base_out = output_edge_provider.outputs_manager.get_base_out(output_edge_name)
nd_parent = output_edge_provider.parent_module_struct
while nd_parent:
nd_parent.add_outputs_edge(base_out.onnx_edge_name)
nd_parent = nd_parent.parent_module_struct

@staticmethod
def register_inputs_to_main_model(input_edge_name: str, input_edge_user: NodeStruct):
"""
Register the input edge to the main model and through all modules.

Args:
input_edge_name (str): The name of this input edge.
input_edge_user (NodeStruct): The node uses this input.
"""
nd_parent = input_edge_user.parent_module_struct
while nd_parent:
nd_parent.add_inputs_edge(input_edge_name)
nd_parent = nd_parent.parent_module_struct


class MatcherLauncher:
"""Process Node-to-Node inputs outputs matching."""
def __init__(self, main_model: ModuleStruct):
super(MatcherLauncher, self).__init__()
self.main_model = main_model
self._global_context = GlobalContext()
self._graph_inputs = self._global_context.onnx_graph_info.get("graph_inputs")
self._graph_outputs = self._global_context.onnx_graph_info.get("graph_outputs")

def matching_process(self):
"""The matching process."""
# 0. Pre-process
MatcherHelper.main_model_special_process_inputs(self.main_model)

# 1. Set all module's return dict
self._register_module_inputs_x_header()

# 2. Set module returns
self._register_module_returns()


def _register_module_inputs_x_header(self):
"""Recursively register the inputs to module init header."""
# Use nearest parent module algorithm
for nd_struct in self._global_context.node_struct_collections.values():
if not nd_struct.precursor_nodes_names_external:
# No precursor nodes, but still check whether any input is a graph-level input
has_graph_input = False
for edge in nd_struct.inputs_edges_names:
if edge in self._global_context.onnx_graph_info.get('graph_inputs'):
has_graph_input = True
break
if not has_graph_input:
continue # avoid unnecessary checking

for inp in nd_struct.inputs_edges_names:
if inp in self._global_context.onnx_graph_info.get('graph_inputs'):
# when the input edge is from graph level.
MatcherHelper.register_inputs_to_main_model(inp, nd_struct)
continue
out_provider_onnx_name = self._global_context.outputs_storage.onnx_name(inp)
out_provider_struct = \
self._global_context.onnx_node_name_to_node_struct_map.get(out_provider_onnx_name)
if out_provider_struct is None:
raise ValueError(f"The Matcher detected an output with an unknown provider for the edge {inp}")
public_parent = MatcherHelper.get_public_parent_module(nd_struct, out_provider_struct)
nd_parent = nd_struct.parent_module_struct
# Recursively register x in all parents until the public module
while public_parent.identifier != nd_parent.identifier:
nd_parent.add_inputs_edge(inp)
nd_parent = nd_parent.parent_module_struct


def _register_module_returns(self):
"""Recursively register the node outputs to parent modules."""
# Use nearest parent module algorithm
for nd_struct in self._global_context.node_struct_collections.values():
if not nd_struct.successor_nodes_names_external:
# check if any edge to graph output
has_graph_output = False
for edge in nd_struct.fragment.metadata.get('outputs'):
if edge in self._global_context.onnx_graph_info.get('graph_outputs'):
has_graph_output = True
break
if not has_graph_output:
continue # avoid unnecessary checking
for base_out in nd_struct.outputs_manager.outputs:
if base_out.onnx_edge_name in self._global_context.onnx_graph_info.get('graph_outputs'):
MatcherHelper.register_outputs_to_main_model(base_out.onnx_edge_name, nd_struct)
continue
out_user_onnx_names = base_out.onnx_user
for out_user_onnx_name in out_user_onnx_names:
out_user_struct = \
self._global_context.onnx_node_name_to_node_struct_map.get(out_user_onnx_name)
if out_user_struct is None:
raise ValueError(f"The Matcher detected an output with an unknown user for the edge "\
f"{base_out.onnx_edge_name}")
public_parent = MatcherHelper.get_public_parent_module(nd_struct, out_user_struct)
nd_parent = nd_struct.parent_module_struct
# Recursively register outputs to parents until the public module
while public_parent.identifier != nd_parent.identifier:
nd_parent.add_outputs_edge(base_out.onnx_edge_name)
nd_parent = nd_parent.parent_module_struct

+97 -16  mindinsight/mindconverter/graph_based_converter/generator/module_struct.py

@@ -34,7 +34,6 @@ class ModuleStruct:
init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct.
parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as.
"""

def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None):
"""Init. a module by NodeStructs."""
self.pattern_id = -1 # pattern num, -1 as Main module
@@ -74,6 +73,20 @@ class ModuleStruct:
# Define outputs manager, note this will be assigned later by Generator.
self.outputs_manager = None

self._global_context = GlobalContext()

# Define a dict to store the reference for quick searching
self.rapid_reference = dict()

# new vars for matcher
self.inputs_register = dict() # reg by sub
self.outputs_register = OrderedDict() # reg by sub
self.internal_outputs_collection = dict() # reg by sub

# new vars for shared weights
self.shared_weights_collection = dict() # reg by sub
self.shared_weights_counter = 0 # updated by sub

if init_as_parent and (parent_base is not None):
self.reset_as_parent_passed_in(parent_base)
else:
@@ -293,8 +306,26 @@ class ModuleStruct:
ret.sort(key=lambda x: x[0])
return ret

def _code_line_init_statement_shared_weights_args(self):
"""Generate the shared-weight args used when this module is instantiated."""
args_list = list()
for passthrough_w_onnx_name, passthrough_w_var_name in self.shared_weights_collection.items():
passthrough_w_var_name_in_parent = \
self.parent_module_struct.shared_weights_collection.get(passthrough_w_onnx_name)
if self.parent_module_struct.identifier == []: # now only consider declaration in main model
args_list.append(f"{passthrough_w_var_name}=self.{passthrough_w_var_name_in_parent}")
else:
args_list.append(f"{passthrough_w_var_name}={passthrough_w_var_name_in_parent}")
return args_list

def _code_line_init_generate_shared_w_declaration_for_repeated(self):
"""Regenerate sub-nodes' init code lines so that shared weight declarations reach the main model."""
for _, nd_struct in self._node_structs:
nd_struct.code_line_in_init()

def code_line_in_init(self):
"""Initialization line of code in module init block."""
self._code_line_init_generate_shared_w_declaration_for_repeated()
left = "self.{}".format(self.ms_var_name)
args_list = list()
# Load args in init statement.
@@ -308,26 +339,40 @@ class ModuleStruct:
args_list += self._args_translator.formal_args_to_str_list # load from formal args
else:
args_list += self._fragment.actual_args
args_list += self._code_line_init_statement_shared_weights_args()
right = f"{self.class_name}({', '.join(args_list)})"
return left, right

def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block."""
# check number of outputs this module has
opt_var_name_in_module = list(self.external_successor_local_returns_map.values())
num_output = len(set(opt_var_name_in_module))
outputs_edges = list(self.outputs_register.keys())
num_output = len(outputs_edges)

# Allocate opt_var_name
if num_output == 1: # single output
left = f"{self.ms_opt_var_name}"
left = [f"{self.ms_opt_var_name}"]
else:
left = [f"{self.ms_opt_var_name}_{num}" for num in range(num_output)]

if inputs is None and self.matched_inputs:
inputs = self.matched_inputs
inputs = []
# Update self's outputs mgr
for idx, edge in enumerate(outputs_edges):
base_out = self.outputs_manager.get_base_out(edge)
if base_out.opt_var_name is None:
print(f"ModuleStruct {self.identifier} has an output {base_out.onnx_edge_name} without an opt_var_name")
base_out.opt_var_name = left[idx]
self.parent_module_struct.internal_outputs_collection[base_out.onnx_edge_name] = base_out.opt_var_name

# Take inputs from parent & previous
for input_edge in self.inputs_register:
if input_edge in self.parent_module_struct.inputs_register:
inputs.append(self.parent_module_struct.inputs_register.get(input_edge))
elif input_edge in self.parent_module_struct.internal_outputs_collection:
inputs.append(self.parent_module_struct.internal_outputs_collection.get(input_edge))

if isinstance(inputs, str):
inputs = [inputs]
right = f"self.{self.ms_var_name}({', '.join(inputs)})"
return left, right
left = ", ".join(left)
return (left, right)

@property
def node_structs(self):
@@ -377,23 +422,36 @@ class ModuleStruct:
@property
def onnx_names_from_nodes(self) -> list:
"""Return all nodes onnx names in this module."""
ret = []
for (_, node) in self.node_structs:
ret.append(node.onnx_name)
if self._global_context.build_struct_finished and "_onnx_names_from_nodes" in self.rapid_reference:
return self.rapid_reference["_onnx_names_from_nodes"]
ret = [node.onnx_name for (_, node) in self.node_structs]
if self._global_context.build_struct_finished:
self.rapid_reference["_onnx_names_from_nodes"] = ret
return ret

@property
def onnx_names_from_submodules(self) -> list:
"""Return all nodes onnx names in submodules of this module."""
if self._global_context.build_struct_finished and "_onnx_names_from_submodules" in self.rapid_reference:
return self.rapid_reference["_onnx_names_from_submodules"]

ret = []
for md_struct in self.module_structs:
ret += md_struct.onnx_names
if self._global_context.build_struct_finished:
self.rapid_reference["_onnx_names_from_submodules"] = ret

return ret

@property
def onnx_names(self) -> list:
"""Return all nodes' onnx names which contained by this module."""
return self.onnx_names_from_nodes + self.onnx_names_from_submodules
if self._global_context.build_struct_finished and "_onnx_names" in self.rapid_reference:
return self.rapid_reference["_onnx_names"]
ret = self.onnx_names_from_nodes + self.onnx_names_from_submodules
if self._global_context.build_struct_finished:
self.rapid_reference["_onnx_names"] = ret
return ret

@property
def external_precursor_nodes_names(self) -> list:
@@ -434,8 +492,8 @@ class ModuleStruct:
"""Return the class name for generating code of this module."""
if self.pattern_id == -1:
return "Model"
if GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) is not None:
class_name = GlobalContext().known_module_name.get("Module{}".format(self.pattern_id))
if self._global_context.known_module_name.get("Module{}".format(self.pattern_id)) is not None:
class_name = self._global_context.known_module_name.get("Module{}".format(self.pattern_id))
else:
class_name = "Module{}".format(self.pattern_id)
return class_name
@@ -676,3 +734,26 @@ class ModuleStruct:
submodule_opt_var_name = md_struct.ms_opt_var_name
for (submodule_ext_succ, _, ith_output) in submodule_returns:
self.external_successor_local_returns_map[submodule_ext_succ] = (submodule_opt_var_name, ith_output)

# The following funcs are designated to be invoked by matcher.
def add_inputs_edge(self, edge_name: str):
"""Register an input edge and allocate a construct-header variable name (x, x0, x1, ...) for it."""
construct_header_length = len(self.inputs_register.values())
default_x_str = "x"
if not edge_name in self.inputs_register:
self.inputs_register[edge_name] = "".join([default_x_str, str(construct_header_length-1)]) \
if construct_header_length > 0 else default_x_str

def add_outputs_edge(self, edge_name: str):
"""Register an output edge; its var name is filled later from the submodule's opt_var_name."""
if edge_name in self.outputs_register:
return
self.outputs_register[edge_name] = "<placeholder>"

def fill_outputs_edge(self, edge_name: str, opt_var_name: str):
"""Fill the output edge once the opt_var_name of the corresponding node is known."""
if edge_name not in self.outputs_register:
raise ValueError(f"ModuleStruct {self.identifier} does not have edge "\
f"{edge_name}, so its output var name cannot be filled.")
if self.outputs_register[edge_name] != "<placeholder>":
raise ValueError(f"The edge has already been filled as {self.outputs_register[edge_name]}" \
f" instead of {opt_var_name}.")
self.outputs_register[edge_name] = opt_var_name

+97 -41  mindinsight/mindconverter/graph_based_converter/generator/node_struct.py

@@ -35,7 +35,6 @@ class NodeStruct:
You can pass as many args as possible and the Node Struct will update
by arguments order.
"""

def __init__(self, args):
# define attributes here
self.global_context_mgr = GlobalContext()
@@ -43,6 +42,7 @@ class NodeStruct:
self._fragment = None
self._args_translator = None
self._parent_module_struct = None
self._global_context = GlobalContext()
self.topo_idx = None
self.onnx_name = None
self.graph_node_ref = None
@@ -75,7 +75,7 @@ class NodeStruct:
"""Get the original topological index in the onnx graph."""
ori_name = self._fragment.metadata.get('source')
self.onnx_name = ori_name
return GlobalContext().onnx_node_name_to_topo_idx.get(ori_name)
return self._global_context.onnx_node_name_to_topo_idx.get(ori_name)

def update_var_name(self, idx=None):
"""
@@ -114,7 +114,7 @@ class NodeStruct:
self._fragment = FragmentHandler(frag)

if self.ms_op:
idx = GlobalContext().latest_node_struct_count
idx = self._global_context.latest_node_struct_count
self.update_var_name(idx=idx)

def _set_scope_from_identifier(self):
@@ -168,7 +168,7 @@ class NodeStruct:
self._identifier = s
self._set_scope_from_identifier()
self.topo_idx = self.ori_topo_idx()
GlobalContext().onnx_node_name_to_node_struct_map[self.onnx_name] = self
self._global_context.onnx_node_name_to_node_struct_map[self.onnx_name] = self

@property
def fragment(self):
@@ -198,7 +198,7 @@ class NodeStruct:
@property
def onnx_node(self):
"""Return the original onnx node reference."""
return GlobalContext().onnx_nodes_collection.get(self.onnx_name)
return self._global_context.onnx_nodes_collection.get(self.onnx_name)

@property
def ms_op(self):
@@ -241,7 +241,7 @@ class NodeStruct:
ret = []
precursor_nodes_names = self.precursor_nodes_names
for pre_node_name in precursor_nodes_names:
nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name)
nd_struct = self._global_context.onnx_node_name_to_node_struct_map.get(pre_node_name)
ret.append(nd_struct)
return ret

@@ -255,7 +255,7 @@ class NodeStruct:
"""Return the node struct instances of successor nodes."""
ret = []
for pre_node_name in self.successor_nodes_names:
nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name)
nd_struct = self._global_context.onnx_node_name_to_node_struct_map.get(pre_node_name)
ret.append(nd_struct)
return ret

@@ -278,54 +278,110 @@ class NodeStruct:
"""Return the outputs var(s) in construct statement."""
return self.fragment.fragment.outputs()

@property
def inputs_edges_names(self):
"""Return the inputs edges of this node."""
# Consider moving this process to metadata.
ret = []
for edge in self.fragment.metadata.get('inputs'):
if not self._global_context.get_onnx_tensor(edge):
ret.append(edge)
return ret

@property
def shared_weights(self):
"""Return the shared weights in this node."""
shared_weight_names = []
for shared_weight_name, repeated_node_list in self._global_context.repeated_weights.items():
if self.onnx_name in repeated_node_list:
shared_weight_names.append(shared_weight_name)
return shared_weight_names

# Code Generation funcs below

def _get_shared_weight_var_names_from_parent(self, onnx_name=None):
"""
Get shared weight var name in the parent module.

Args:
onnx_name (str): The onnx name of this weight. Default None.

Returns:
[List, str], a list of all shared weights the node has or the specific name provided.
"""
if onnx_name is None:
shared_weights_var_name_in_module = []
for shared_w in self.shared_weights:
for passthrough_w, passthrough_w_var_name in \
self._parent_module_struct.shared_weights_collection.items():
if shared_w == passthrough_w:
shared_weights_var_name_in_module.append(passthrough_w_var_name)
return shared_weights_var_name_in_module
if isinstance(onnx_name, str):
return self._parent_module_struct.shared_weights_collection.get(onnx_name)

return []


def code_line_in_init(self):
"""Initialization line of code in module init block."""
left = "self.{}".format(self.ms_var_name)
args_list = list()
if self._args_translator is not None:
self.fragment.default_var['args'] = {**self._args_translator.actual_args,
**self._args_translator.formal_args}
args_list += self._args_translator.actual_args_to_str_list
args_list += self._args_translator.formal_args_to_str_list
else:
actual_args_str = ArgsTranslation.dict_data_to_args_str_list(self._fragment.default_var['args'])
args_list += actual_args_str

if not self._fragment.converted:
args_list.append('='.join(["input_shape", str(self._fragment.input_shape)]))
args_list.append('='.join(["output_shape", str(self._fragment.output_shape)]))
right = f"{self.ms_op.replace('::', '.')}({', '.join(args_list)})"
else:
right = f"{self.ms_op}({', '.join(args_list)})"
return left, right
# create a parameter for shared weight scenario
trainable_params = self.fragment.default_var.get("trainable_params")
if trainable_params and self.fragment.default_var.get("parameters"):
# Only if there are trainable params and the mapper accepts parameter declaration rewriting.
for trainable_param_postfix, data_dict in trainable_params.items():
onnx_name = data_dict.get('onnx_name')
nparray = data_dict.get('data')
try:
shape = nparray.shape
dtype = nparray.dtype
except Exception:
raise ValueError("Parameters have an inconsistent data type.")
# set declare statement
declare_statement = self.fragment.fragment.create_parameter(shape, dtype)
if onnx_name not in self._global_context.repeated_weights.keys():
# if the weight is not a shared weight, set to actual declaration.
if not self.fragment.default_var["parameters"].get(trainable_param_postfix):
self.fragment.default_var["parameters"][trainable_param_postfix] = declare_statement
continue # not a shared weight, skip the rest

if onnx_name in self._global_context.repeated_weights_declaration.keys():
continue # already declared, skip
self._global_context.repeated_weights_declaration[onnx_name] = declare_statement

# set template to mapper parameter rewritten.
shared_w_var_in_parent = self._get_shared_weight_var_names_from_parent(onnx_name=onnx_name)
# Add the "self." prefix for nodes directly under the public parent module
if self.parent_module_struct.identifier == []:
# Now only consider declaration in the main model
shared_w_var_in_parent = f"self.{shared_w_var_in_parent}"
self.fragment.default_var["parameters"][trainable_param_postfix] = shared_w_var_in_parent

def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block. """
left = self.ms_opt_var_name

if not self.matched_inputs and inputs is None:
raise ValueError("Unable to generate the code construct statement due to empty inputs.")
inputs = []

if self.matched_inputs:
inputs = self.matched_inputs
# Bind current node opt_var_name & register to parent
self.outputs_manager.bind_opt_var_names(self.fragment.fragment)
for base_out in self.outputs_manager.outputs:
opt_var = base_out.opt_var_name
self.parent_module_struct.internal_outputs_collection[base_out.onnx_edge_name] = opt_var

# Check original onnx node's input to ensure double inputs are not ignored
original_inputs = GlobalContext().onnx_node_inputs.get(self.onnx_name)
new_inputs = []
for idx, prec_node in enumerate(self.precursor_nodes_names):
occurrence = original_inputs.count(prec_node)
for _ in range(occurrence):
new_inputs.append(inputs[idx])
inputs = new_inputs

if isinstance(inputs, str):
inputs = [inputs]
# Take inputs from the parent module
for input_edge in self.inputs_edges_names:
if input_edge in self.parent_module_struct.inputs_register:
inputs.append(self.parent_module_struct.inputs_register.get(input_edge))
elif input_edge in self.parent_module_struct.internal_outputs_collection:
inputs.append(self.parent_module_struct.internal_outputs_collection.get(input_edge))

self.fragment.default_var['inputs'] = inputs
right = f"self.{self.ms_var_name}({', '.join(inputs)})"
return left, right
return left

def add_extra_tensor(self):
""" Add extra tensor."""
@@ -360,12 +416,12 @@ class NodeStruct:
Args:
name (str): Can accept both node identifier or original onnx node name.
"""
target_nd_struct = GlobalContext().node_struct_collections.get(name) \
or GlobalContext().onnx_node_name_to_node_struct_map.get(name)
target_nd_struct = self._global_context.node_struct_collections.get(name) \
or self._global_context.onnx_node_name_to_node_struct_map.get(name)
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
return False

if target_nd_struct is None and (name in GlobalContext().onnx_graph_info.get('graph_inputs')):
if target_nd_struct is None and (name in self._global_context.onnx_graph_info.get('graph_inputs')):
return False

if target_nd_struct is None:


+91 -0  mindinsight/mindconverter/graph_based_converter/generator/shared_weights.py

@@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module processing for shared weights."""
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct
from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct


class SharedWeightHelper:
"""Helper function to process shared weights."""

@staticmethod
def check_node_has_shared_weight(node: NodeStruct):
"""
Check the node has shared weight and return all of them.

Args:
node (NodeStruct): NodeStruct instance.

Returns:
list, a list of shared weight onnx names
"""
shared_weight_names = []
for shared_weight_name, repeated_node_list in GlobalContext().repeated_weights.items():
if node.onnx_name in repeated_node_list:
shared_weight_names.append(shared_weight_name)

return shared_weight_names

@staticmethod
def add_shared_weight_to_parent_module(shared_weight_name: str, module_to_be_registered: ModuleStruct):
"""Register the shared weight name to module and assign a local var name for it."""
default_weight_name = f"passthrough_w_{module_to_be_registered.shared_weights_counter}"
if shared_weight_name not in module_to_be_registered.shared_weights_collection:
module_to_be_registered.shared_weights_collection[shared_weight_name] = default_weight_name
module_to_be_registered.shared_weights_counter += 1

@staticmethod
def register_shared_weight_to_public_parent(node: NodeStruct, shared_weight_name: str, pub_module_identifier):
"""
Register shared weight from bottom to top until its public module.

Note:
For now, we always consider the public module to be the main model, since searching for the
public module among multiple nodes is time-consuming.

Args:
node (NodeStruct): The NodeStruct instance which has the shared weight.
shared_weight_name (str): The onnx name of the shared weight.
pub_module_identifier (list): The identifier of the public module where the shared weight is used.
"""
parent_module = node.parent_module_struct
exit_flag = False
while True:
if parent_module.identifier == pub_module_identifier:
exit_flag = True
SharedWeightHelper.add_shared_weight_to_parent_module(shared_weight_name, parent_module)
parent_module = parent_module.parent_module_struct
if exit_flag:
break
if parent_module is None:
break

@staticmethod
def add_shared_weights_in_init_statement(md_struct: ModuleStruct, module_def_args: list):
"""Add shared weights to the module init statement."""
if md_struct.shared_weights_collection:
return module_def_args + list(md_struct.shared_weights_collection.values())
return module_def_args

@staticmethod
def public_module_shared_weight_statement_generation(public_module: ModuleStruct):
"""Return the statement of declaration of shared weights in its public module."""
statements = []
for passthrough_w_onnx_name, passthrough_w_var_name in public_module.shared_weights_collection.items():
parameter_statement = GlobalContext().repeated_weights_declaration.get(passthrough_w_onnx_name)
declare_statement = f"self.{passthrough_w_var_name} = {parameter_statement}"
statements.append(declare_statement)
return statements

+15 -3  mindinsight/mindconverter/graph_based_converter/mapper/base.py

@@ -181,9 +181,9 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
return template, exchange_msg, outputs_list, outputs_mapping

@staticmethod
def _find_val_by_index(loc_index, weights_list, default_value=None):
def _find_val_by_index(loc_index, weights_list, default_val=None):
"""Find value by location index of weights_list."""
result = default_value
result = default_val
if loc_index < 0:
return weights_list[loc_index].value

@@ -196,7 +196,6 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
@staticmethod
def _find_location_by_index(loc_index, weights_list):
"""Find weight location in inputs of Node."""

result = -1
if loc_index < 0:
return weights_list[loc_index].location
@@ -206,3 +205,16 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
result = weight.location
break
return result

@staticmethod
def _find_onnx_name_by_index(loc_index, weights_list):
"""Find weight onnx name in inputs of Node."""
result = -1
if loc_index < 0:
return weights_list[loc_index].name

for idx, weight in enumerate(weights_list):
if idx == loc_index:
result = weight.name
break
return result

+9 -9  mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py

@@ -33,7 +33,8 @@ class MatMulMapper(ONNXToMindSporeMapper):
def _convert_trained_weights(**kwargs):
weights = kwargs['weights']
weight = MatMulMapper._find_val_by_index(0, weights)
return {'w': {'data': weight, 'type': WeightType.PARAMETER.value}}
onnx_name = MatMulMapper._find_onnx_name_by_index(0, weights)
return {'w': {'data': weight, 'type': WeightType.PARAMETER.value, 'onnx_name': onnx_name}}

@staticmethod
def _generate_snippet_template(**kwargs):
@@ -48,15 +49,11 @@ class MatMulMapper(ONNXToMindSporeMapper):
if not weights:
return template, exchange_msg, outputs_list, outputs_mapping

tensor = MatMulMapper._find_val_by_index(0, weights)

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
args["weight_shape"] = tensor.shape
args["weight_dtype"] = tensor.dtype
init_tensor = f"self.{{{variable_slot}}}_w = " \
f"Parameter(Tensor(np.random.uniform(0, 1, {{weight_shape}}).astype(np.{{weight_dtype}})), " \
f"name=None)"
# Note: adding the weight shape to args is now deprecated due to conflicts with partial weight-sharing processing.
variable_slot_param_name = f"{variable_slot}/w"
init_tensor = f"self.{{{variable_slot}}}_w = {{{variable_slot_param_name}}}"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}," \
f"self.{{{variable_slot}}}_w)"
@@ -75,7 +72,10 @@ class MatMulMapper(ONNXToMindSporeMapper):
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params,
ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value: {
"w": ""
}
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]


+10 -7  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py

@@ -35,8 +35,9 @@ class AddMapper(ONNXToMindSporeMapper):
def _convert_trained_weights(**kwargs):
weights = kwargs.get('weights', list())
tensor = AddMapper._find_val_by_index(0, weights)
onnx_name = AddMapper._find_onnx_name_by_index(0, weights)
if isinstance(tensor, np.ndarray) and tensor.shape:
return {'bias': {'data': tensor, 'type': WeightType.PARAMETER.value}}
return {'bias': {'data': tensor, 'type': WeightType.PARAMETER.value, 'onnx_name': onnx_name}}
return dict()

@staticmethod
@@ -54,7 +55,6 @@ class AddMapper(ONNXToMindSporeMapper):

tensor = AddMapper._find_val_by_index(0, weights)
bias_shape = tensor.shape
bias_dtype = tensor.dtype
bias_location = AddMapper._find_location_by_index(0, weights)

variable_slot = "var_0"
@@ -64,11 +64,10 @@ class AddMapper(ONNXToMindSporeMapper):
inputs_in_construct.insert(bias_location, f"self.{{{variable_slot}}}_bias")

if bias_shape:
args["bias_shape"] = bias_shape
args["bias_dtype"] = bias_dtype
init_tensor = f"self.{{{variable_slot}}}_bias = " \
f"Parameter(Tensor(np.random.uniform(0, 1, {{bias_shape}}).astype(np.{{bias_dtype}})), " \
f"name=None)"
# Note: adding the weight shape to args is now deprecated due to conflicts with partial weight-sharing processing.
variable_slot_param_name = f"{variable_slot}/bias"
init_tensor = f"self.{{{variable_slot}}}_bias = {{{variable_slot_param_name}}}"

else:
args["bias_value"] = tensor.tolist()
init_tensor = f"self.{{{variable_slot}}}_bias = {{bias_value}}"
@@ -93,6 +92,10 @@ class AddMapper(ONNXToMindSporeMapper):
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
}
}
if bias_shape:
exchange_msg[variable_slot][ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value] = {
"bias": ""
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping

+10 -8  mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py

@@ -36,8 +36,9 @@ class MulMapper(ONNXToMindSporeMapper):
def _convert_trained_weights(**kwargs):
weights = kwargs.get('weights', list())
tensor = MulMapper._find_val_by_index(0, weights)
onnx_name = MulMapper._find_onnx_name_by_index(0, weights)
if isinstance(tensor, np.ndarray) and tensor.shape:
return {'w': {'data': tensor, 'type': WeightType.PARAMETER.value}}
return {'w': {'data': tensor, 'type': WeightType.PARAMETER.value, 'onnx_name': onnx_name}}
return dict()

@staticmethod
@@ -47,13 +48,12 @@ class MulMapper(ONNXToMindSporeMapper):
op = kwargs.get("operation")
args = kwargs.get("converted_params")
weights = kwargs.get("weights")
trainable_params = kwargs.get("trainable_params", dict())
trainable_params = kwargs.get('trainable_params', dict())
if not weights:
return template, exchange_msg, outputs_list, outputs_mapping

tensor = MulMapper._find_val_by_index(0, weights)
w_shape = tensor.shape
w_dtype = tensor.dtype
w_location = MulMapper._find_location_by_index(0, weights)

variable_slot = "var_0"
@@ -63,11 +63,9 @@ class MulMapper(ONNXToMindSporeMapper):
inputs_in_construct.insert(w_location, f"self.{{{variable_slot}}}_w")

if w_shape:
args["w_shape"] = w_shape
args["w_dtype"] = w_dtype
init_tensor = f"self.{{{variable_slot}}}_w = " \
f"Parameter(Tensor(np.random.uniform(0, 1, {{w_shape}}).astype(np.{{w_dtype}})), " \
f"name=None)"
# Note: adding the weight shape to args is now deprecated due to conflicts with partial weight-sharing processing.
variable_slot_param_name = f"{variable_slot}/w"
init_tensor = f"self.{{{variable_slot}}}_w = {{{variable_slot_param_name}}}"
else:
args["w_value"] = tensor.tolist()
init_tensor = f"self.{{{variable_slot}}}_w = {{w_value}}"
@@ -90,6 +88,10 @@ class MulMapper(ONNXToMindSporeMapper):
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
}
}
if w_shape:
exchange_msg[variable_slot][ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value] = {
"w": ""
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping

+12 -6  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

@@ -421,11 +421,16 @@ class OnnxDataLoader:
self._global_context.onnx_node_name_to_topo_idx[n.name] = idx

for k in self.repeated_weight:
if not self.tensors_dict.get(k).to_array().shape:
# scalar does not have shape info
continue
self.repeated_weight[k] = record_tensors[k][:]

self._global_context.onnx_nodes_collection = self._nodes_dict
self._global_context.onnx_nodes_topo_index = nodes_topo_idx
self._global_context.repeated_weights = self.repeated_weight
# For now, only process shared weights for multi-input models
if len(self.input_nodes) > 1:
self._global_context.repeated_weights = self.repeated_weight

def _parse_tensors(self):
"""Parse each onnx tensors in the model."""
@@ -500,10 +505,14 @@ class OnnxDataLoader:
# Parse ONNX Graph level info
self._parse_graph()

# 1. parse all nodes
# 1. parse all tensors
self._parse_tensors()

# 2. parse all nodes; tensors must be parsed first because nodes need tensor info
# to process node weight sharing.
self._parse_nodes()

# 2. parse value info (incl. node output shape)
# 3. parse value info (incl. node output shape)
if self._is_infer_shape:
try:
self._infer_model()
@@ -514,9 +523,6 @@ class OnnxDataLoader:
log.exception(e)
raise e

# 3. parse all tensors
self._parse_tensors()

# 4. Optimize graph to eliminate some nodes.
self._find_nodes_to_be_eliminated()


