Browse Source

bug fix

pull/1282/head
Xuan Yang xuan 4 years ago
parent
commit
5b2bc9c9ec
1 changed files with 15 additions and 1 deletions
  1. +15
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+ 15
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -28,7 +28,7 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer

from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from mindinsight.mindconverter.common.exceptions import GraphInitError
from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError


def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
@@ -287,6 +287,9 @@ class OnnxDataLoader:
self.dynamic_reshape_node = list()
self.eliminated_nodes = list()

# Validate init params
self._check_user_provided_info()

self.initialize()

@property
@@ -510,6 +513,17 @@ class OnnxDataLoader:
self._global_context.onnx_graph_info['graph_inputs'] = self.input_nodes.keys()
self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs

def _check_user_provided_info(self):
"""Validate user input and output node."""
graph_inputs = [inp.name for inp in self.graph.input]
graph_outputs = [out.name for out in self.graph.output]
for output_node in self.output_nodes:
if output_node not in graph_outputs:
raise ModelLoadingError(f"Unexpected Node {output_node} detected which should not be a graph output.")
for graph_inp in graph_inputs:
if graph_inp not in self.input_nodes.keys():
raise ModelLoadingError(f"{graph_inp} is one of the graph inputs but user does not provide it.")

def initialize(self):
"""Initialize the OnnxDataLoader."""



Loading…
Cancel
Save