|
|
|
@@ -502,14 +502,7 @@ class OnnxDataLoader: |
|
|
|
|
|
|
|
def _parse_graph(self): |
|
|
|
"""Parse ONNX Graph Info For usage in generator.""" |
|
|
|
graph_inputs = [inp.name for inp in self.graph.input] |
|
|
|
graph_outputs = [out.name for out in self.graph.output] |
|
|
|
for output_node in self.output_nodes: |
|
|
|
if output_node not in graph_outputs: |
|
|
|
raise ValueError(f"Unexpected Node {output_node} detected which should not be a graph output.") |
|
|
|
for graph_inp in graph_inputs: |
|
|
|
if graph_inp not in self.input_nodes.keys(): |
|
|
|
raise ValueError(f"{graph_inp} is one of the graph inputs but user does not provide it.") |
|
|
|
self._global_context.onnx_graph_info['graph_inputs'] = self.input_nodes.keys() |
|
|
|
self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs |
|
|
|
|
|
|
|
@@ -521,7 +514,7 @@ class OnnxDataLoader: |
|
|
|
if output_node not in graph_outputs: |
|
|
|
raise ModelLoadingError(f"Unexpected Node {output_node} detected which should not be a graph output.") |
|
|
|
for graph_inp in graph_inputs: |
|
|
|
if graph_inp not in self.input_nodes.keys(): |
|
|
|
if graph_inp not in self.input_nodes: |
|
|
|
raise ModelLoadingError(f"{graph_inp} is one of the graph inputs but user does not provide it.") |
|
|
|
|
|
|
|
def initialize(self): |
|
|
|
|