| @@ -28,7 +28,7 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer | |||||
| from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | ||||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | ||||
| from mindinsight.mindconverter.common.exceptions import GraphInitError | |||||
| from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError | |||||
| def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | ||||
| @@ -287,6 +287,9 @@ class OnnxDataLoader: | |||||
| self.dynamic_reshape_node = list() | self.dynamic_reshape_node = list() | ||||
| self.eliminated_nodes = list() | self.eliminated_nodes = list() | ||||
| # Validate init params | |||||
| self._check_user_provided_info() | |||||
| self.initialize() | self.initialize() | ||||
| @property | @property | ||||
| @@ -510,6 +513,17 @@ class OnnxDataLoader: | |||||
| self._global_context.onnx_graph_info['graph_inputs'] = self.input_nodes.keys() | self._global_context.onnx_graph_info['graph_inputs'] = self.input_nodes.keys() | ||||
| self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs | self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs | ||||
| def _check_user_provided_info(self): | |||||
| """Validate user input and output node.""" | |||||
| graph_inputs = [inp.name for inp in self.graph.input] | |||||
| graph_outputs = [out.name for out in self.graph.output] | |||||
| for output_node in self.output_nodes: | |||||
| if output_node not in graph_outputs: | |||||
| raise ModelLoadingError(f"Unexpected Node {output_node} detected which should not be a graph output.") | |||||
| for graph_inp in graph_inputs: | |||||
| if graph_inp not in self.input_nodes.keys(): | |||||
| raise ModelLoadingError(f"{graph_inp} is one of the graph inputs but user does not provide it.") | |||||
| def initialize(self): | def initialize(self): | ||||
| """Initialize the OnnxDataLoader.""" | """Initialize the OnnxDataLoader.""" | ||||