diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py index cd7acaae..f5ecba40 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py @@ -28,7 +28,7 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL -from mindinsight.mindconverter.common.exceptions import GraphInitError +from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): @@ -287,6 +287,9 @@ class OnnxDataLoader: self.dynamic_reshape_node = list() self.eliminated_nodes = list() + # Validate init params + self._check_user_provided_info() + self.initialize() @property @@ -510,6 +513,17 @@ class OnnxDataLoader: self._global_context.onnx_graph_info['graph_inputs'] = self.input_nodes.keys() self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs + def _check_user_provided_info(self): + """Validate user input and output node.""" + graph_inputs = [inp.name for inp in self.graph.input] + graph_outputs = [out.name for out in self.graph.output] + for output_node in self.output_nodes: + if output_node not in graph_outputs: + raise ModelLoadingError(f"Unexpected Node {output_node} detected which should not be a graph output.") + for graph_inp in graph_inputs: + if graph_inp not in self.input_nodes.keys(): + raise ModelLoadingError(f"{graph_inp} is one of the graph inputs but user does not provide it.") + def initialize(self): """Initialize the OnnxDataLoader."""