|
|
|
@@ -28,7 +28,7 @@ from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer |
|
|
|
|
|
|
|
from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ |
|
|
|
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL |
|
|
|
from mindinsight.mindconverter.common.exceptions import GraphInitError |
|
|
|
from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError |
|
|
|
|
|
|
|
|
|
|
|
def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): |
|
|
|
@@ -287,6 +287,9 @@ class OnnxDataLoader: |
|
|
|
self.dynamic_reshape_node = list() |
|
|
|
self.eliminated_nodes = list() |
|
|
|
|
|
|
|
# Validate init params |
|
|
|
self._check_user_provided_info() |
|
|
|
|
|
|
|
self.initialize() |
|
|
|
|
|
|
|
@property |
|
|
|
@@ -510,6 +513,17 @@ class OnnxDataLoader: |
|
|
|
self._global_context.onnx_graph_info['graph_inputs'] = self.input_nodes.keys() |
|
|
|
self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs |
|
|
|
|
|
|
|
def _check_user_provided_info(self): |
|
|
|
"""Validate user input and output node.""" |
|
|
|
graph_inputs = [inp.name for inp in self.graph.input] |
|
|
|
graph_outputs = [out.name for out in self.graph.output] |
|
|
|
for output_node in self.output_nodes: |
|
|
|
if output_node not in graph_outputs: |
|
|
|
raise ModelLoadingError(f"Unexpected Node {output_node} detected which should not be a graph output.") |
|
|
|
for graph_inp in graph_inputs: |
|
|
|
if graph_inp not in self.input_nodes.keys(): |
|
|
|
raise ModelLoadingError(f"{graph_inp} is one of the graph inputs but user does not provide it.") |
|
|
|
|
|
|
|
def initialize(self): |
|
|
|
"""Initialize the OnnxDataLoader.""" |
|
|
|
|
|
|
|
|