
Fix bugs in conversion of CV models.

tags/v1.2.0-rc1
liuchongming, 4 years ago
commit 8531be1e68
10 changed files with 41 additions and 23 deletions
  1. +1 -1   mindinsight/mindconverter/graph_based_converter/common/utils.py
  2. +6 -5   mindinsight/mindconverter/graph_based_converter/framework.py
  3. +7 -4   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py
  4. +5 -2   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py
  5. +3 -2   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  6. +7 -3   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  7. +4 -2   mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  8. +2 -1   mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py
  9. +4 -2   mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  10. +2 -1  mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py

mindinsight/mindconverter/graph_based_converter/common/utils.py (+1, -1)

@@ -81,7 +81,7 @@ def build_feed_dict(onnx_model, input_nodes: dict):
         for node in onnx_model.graph.input
     }
     feed_dict = {
-        name: np.random.rand(*shape).astype(input_nodes_types[name.split(":")[0]])
+        name: np.random.rand(*shape).astype(input_nodes_types[name])
         for name, shape in input_nodes.items()
     }
     return feed_dict
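For context, the fixed comprehension looks up the dtype by the plain input name. A minimal standalone sketch of the idea (the input names, shapes, and dtypes below are made up for illustration, not taken from a real model):

    import numpy as np

    # Hypothetical inputs: the converter would derive names/shapes from
    # `--input_nodes`/`--shape` and dtypes from the ONNX graph inputs.
    input_nodes = {"input.1": (1, 3, 224, 224)}
    input_nodes_types = {"input.1": np.float32}

    feed_dict = {
        name: np.random.rand(*shape).astype(input_nodes_types[name])
        for name, shape in input_nodes.items()
    }
    print({name: (arr.shape, arr.dtype) for name, arr in feed_dict.items()})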


mindinsight/mindconverter/graph_based_converter/framework.py (+6, -5)

@@ -83,7 +83,7 @@ def torch_installation_validation(func):
         type, inner function.
     """

-    def _f(graph_path: str, sample_shape: tuple, input_nodes: str, output_nodes: str,
+    def _f(graph_path: str, input_nodes: dict, output_nodes: List[str],
            output_folder: str, report_folder: str = None):
         # Check whether pytorch is installed.
         error_info = None
@@ -119,7 +119,7 @@ def torch_installation_validation(func):
             _print_error(error)
             sys.exit(0)

-        func(graph_path=graph_path, sample_shape=sample_shape,
+        func(graph_path=graph_path,
             input_nodes=input_nodes, output_nodes=output_nodes,
             output_folder=output_folder, report_folder=report_folder)
@@ -265,11 +265,12 @@ def main_graph_base_converter(file_config):
     if not file_config.get("shape"):
         raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")

-    if graph_path.endswith("pth") and not file_config['input_nodes'] and \
-            file_config.get("shape") and len(file_config.get("shape")) == 1:
+    if graph_path.endswith("pth") and not file_config.get("input_nodes", []) and \
+            file_config.get("shape") and len(file_config.get("shape", ())) == 1:
         file_config['input_nodes'] = ["input.1"]

-    if len(file_config['shape']) != len(file_config['input_nodes']) != len(set(file_config['input_nodes'])):
+    if len(file_config['shape']) != len(file_config.get("input_nodes", [])) != len(
+            set(file_config.get("input_nodes", []))):
         raise BadParamError("`--shape` and `--input_nodes` must have the same length, "
                             "and no redundant node in `--input_nodes`.")



mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/common.py (+7, -4)

@@ -17,12 +17,13 @@
 __all__ = ["context",
            "gen_hash_key",
            "DagGraph",
-           "MAX_OUT_DEGREE",
+           "MAX_DEGREE",
            "cal_matching_score",
            "ACCEPTABLE_RESULT_COUNT",
            "MINI_FREQUENCY",
            "SATISFIED_SCORE",
-           "MAX_ITERATION_DEPTH"]
+           "MAX_ITERATION_DEPTH_OF_MULTI_IPT",
+           "MAX_ITERATION_DEPTH_OF_SINGLE_IPT"]

 import math
 import copy
@@ -32,9 +33,10 @@ from typing import List

 from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode

-MAX_OUT_DEGREE = 1
+MAX_DEGREE = 1
 MINI_FREQUENCY = 0.07
-MAX_ITERATION_DEPTH = 16
+MAX_ITERATION_DEPTH_OF_MULTI_IPT = 16
+MAX_ITERATION_DEPTH_OF_SINGLE_IPT = 8
 SATISFIED_SCORE = 0.74
 ACCEPTABLE_RESULT_COUNT = 32
 PTN_COVERAGE_THRESHOLD = 0.65
@@ -127,6 +129,7 @@ class AlgorithmContext:
     precursor_table = {}
     successor_table = {}
     outputs_table = {}
+    has_multi_inputs = False

     def set_init_node_collection(self, nd_col):
         """Init node_collection."""


mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py (+5, -2)

@@ -21,7 +21,6 @@ class Pattern:

     def __init__(self, pattern, pattern_length, in_degree, out_degree, ptn_items: list = None):
         self.pattern = pattern
-        self.count = 0
         self.start_index = []
         self.end_index = []
         self.module_name = None
@@ -37,6 +36,11 @@ class Pattern:
         self.additional_score = 0
         self.known_module_name = None

+    @property
+    def count(self):
+        """Count of the pattern."""
+        return len(self.start_index)
+
     def insert(self, idx, seq_len):
         """
         Insert a new position.
@@ -49,7 +53,6 @@ class Pattern:
                 return
         self.start_index.append(idx)
         self.end_index.append(idx + seq_len)
-        self.count += 1

     def __str__(self):
         """Override `str()` method."""


mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py (+3, -2)

@@ -21,7 +21,7 @@ from typing import Dict, List, Callable, Union
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \
     is_built_in_pattern
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \
-    MAX_OUT_DEGREE, cal_matching_score
+    MAX_DEGREE, cal_matching_score
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
@@ -390,7 +390,7 @@ def generate_pattern(topo_order: List[BaseNode], dag: DagGraph,
                               dag=dag)

         in_degree, out_degree, _, _ = _get_pattern_degree(found_sequence, dag)
-        if out_degree > MAX_OUT_DEGREE:
+        if out_degree > MAX_DEGREE or (not context.has_multi_inputs and in_degree > MAX_DEGREE):
             cur_idx += 1
             continue
@@ -419,6 +419,7 @@ def _post_process_overlap(patterns) -> Dict:
                 patterns[name].start_index.pop(idx)
                 patterns[name].end_index.pop(idx)
                 continue
+            prev_end = patterns[name].end_index[idx]
             idx += 1
     return patterns
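The added line fixes the overlap post-processing: previously prev_end was not advanced for occurrences that were kept, so later overlapping occurrences could survive. A list-based sketch of the corrected loop (it assumes occurrences are already sorted by start index and works on plain lists rather than Pattern objects):

    def drop_overlapping(start_index, end_index):
        """Discard occurrences that overlap a previously kept one."""
        prev_end = -1
        idx = 0
        while idx < len(start_index):
            if start_index[idx] < prev_end:
                start_index.pop(idx)
                end_index.pop(idx)
                continue
            # Advance the boundary for every occurrence that is kept.
            prev_end = end_index[idx]
            idx += 1
        return start_index, end_index


    print(drop_overlapping([0, 2, 5], [3, 4, 8]))  # ([0, 5], [3, 8])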



mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py (+7, -3)

@@ -17,9 +17,9 @@ from queue import PriorityQueue
 from typing import Dict, List

 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
-    ACCEPTABLE_RESULT_COUNT
+    ACCEPTABLE_RESULT_COUNT, MAX_ITERATION_DEPTH_OF_SINGLE_IPT
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \
-    MAX_ITERATION_DEPTH, SATISFIED_SCORE
+    MAX_ITERATION_DEPTH_OF_MULTI_IPT, SATISFIED_SCORE
 from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
 from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
 from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
@@ -37,7 +37,9 @@ def _is_satisfied(path):
     Returns:
         bool, True or False.
     """
-    if len(path.recursion_path) > MAX_ITERATION_DEPTH:
+    recursion_depth = MAX_ITERATION_DEPTH_OF_MULTI_IPT if context.has_multi_inputs \
+        else MAX_ITERATION_DEPTH_OF_SINGLE_IPT
+    if len(path.recursion_path) > recursion_depth:
         return True
     candidate_eval = any([is_pattern_satisfied(p, path) for p in path.new_pattern.values()])
     if not path.new_pattern or not candidate_eval:
@@ -262,6 +264,8 @@ def _build_connection(loader):
         context.successor_table[node_name] = list(node.get_successor_dict().keys())
         context.outputs_table[node_name] = node.output_name_list

+    # Record the model inputs count, use it to control the search algorithm.
+    context.has_multi_inputs = len(loader.input_nodes) > 1
     dag = DagGraph(nodes=context.node_collection.copy(),
                    precursor=context.precursor_table.copy(),
                    successor=context.successor_table.copy())
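Put together, the searcher now budgets its recursion depth by whether the model has more than one input; single-input graphs get a tighter limit to keep the search cheap. A hedged sketch of the selection logic:

    MAX_ITERATION_DEPTH_OF_MULTI_IPT = 16
    MAX_ITERATION_DEPTH_OF_SINGLE_IPT = 8


    def recursion_depth_limit(has_multi_inputs):
        """Pick the depth budget for the sub-graph search."""
        return (MAX_ITERATION_DEPTH_OF_MULTI_IPT if has_multi_inputs
                else MAX_ITERATION_DEPTH_OF_SINGLE_IPT)


    # `has_multi_inputs` is derived from the loader's input nodes, as above.
    input_nodes = ["input.1"]
    print(recursion_depth_limit(len(input_nodes) > 1))  # 8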


mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py (+4, -2)

@@ -202,6 +202,8 @@ class OnnxGraph(Graph):
         else:
             onnx_model = PyTorchGraphParser.parse(graph_path, **kwargs)
         onnx_inputs = [onnx_input.name for onnx_input in onnx_model.graph.input]
-        if input_nodes not in onnx_inputs:
-            raise ModelNotSupportError(f"input nodes({input_nodes}) is not in model inputs ({onnx_inputs}).")
+        for ipt in input_nodes:
+            if ipt not in onnx_inputs:
+                raise ModelNotSupportError(f"input nodes({input_nodes}) is not "
+                                           f"in model inputs ({onnx_inputs}).")
         return onnx_model
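The original check tested the whole input_nodes list for membership in the list of model input names, which can never succeed once there is more than one input node; the fix validates each requested input individually. A standalone sketch (the names are invented; ValueError stands in for ModelNotSupportError):

    def check_input_nodes(input_nodes, onnx_inputs):
        """Raise if any requested input node is missing from the model inputs."""
        for ipt in input_nodes:
            if ipt not in onnx_inputs:
                raise ValueError(f"input nodes({input_nodes}) is not "
                                 f"in model inputs ({onnx_inputs}).")


    check_input_nodes(["input.1"], ["input.1", "input.2"])    # passes
    # check_input_nodes(["input.3"], ["input.1", "input.2"])  # would raise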

mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py (+2, -1)

@@ -93,7 +93,8 @@ class OnnxSimplify:
         self._constant_nodes = copy.deepcopy(const_nodes)

     @ModelNotSupportError.check_except(
-        "Error occurs in loading model, please check your model or runtime environment integrity."
+        "Error occurs when loading model with given params, please check `--shape`, "
+        "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity."
     )
     def _onnx_infer(self, infer_inputs_shape):
         """


mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py (+4, -2)

@@ -27,7 +27,8 @@ class PyTorchGraphParser(GraphParser):

     @classmethod
     @ModelNotSupportError.check_except(
-        "Error occurs in loading model, please check your model or runtime environment integrity."
+        "Error occurs when loading model with given params, please check `--shape`, "
+        "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity."
     )
     def parse(cls, model_path: str, **kwargs):
         """
@@ -47,8 +48,9 @@ class PyTorchGraphParser(GraphParser):
             raise error

         try:
+            sample_shape = list(kwargs.get("input_nodes").values())[0]
             onnx_model_sim = cls._convert_pytorch_graph_to_onnx(
-                model_path, kwargs['sample_shape'], opset_version=11)
+                model_path, sample_shape, opset_version=11)
             return onnx_model_sim

         except ModuleNotFoundError:
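Since the public entry point no longer passes sample_shape, the PyTorch parser now recovers the tracing shape from the input_nodes mapping; with a single-input model the first (and only) shape is used. A small hedged sketch of that lookup (the mapping is invented):

    # Hypothetical kwargs as the converter would pass them to parse().
    kwargs = {"input_nodes": {"input.1": (1, 3, 224, 224)}}

    sample_shape = list(kwargs.get("input_nodes").values())[0]
    print(sample_shape)  # (1, 3, 224, 224)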


mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py (+2, -1)

@@ -27,7 +27,8 @@ class TFGraphParser(GraphParser):

     @classmethod
     @ModelNotSupportError.check_except(
-        "Error occurs in loading model, please check your model or runtime environment integrity."
+        "Error occurs when loading model with given params, please check `--shape`, "
+        "`--input_nodes`, `--output_nodes`, `--model_file` or runtime environment integrity."
     )
     def parse(cls, model_path: str, **kwargs):
         """

