| @@ -502,14 +502,7 @@ class OnnxDataLoader: | |||||
| def _parse_graph(self): | def _parse_graph(self): | ||||
| """Parse ONNX Graph Info For usage in generator.""" | """Parse ONNX Graph Info For usage in generator.""" | ||||
| graph_inputs = [inp.name for inp in self.graph.input] | |||||
| graph_outputs = [out.name for out in self.graph.output] | graph_outputs = [out.name for out in self.graph.output] | ||||
| for output_node in self.output_nodes: | |||||
| if output_node not in graph_outputs: | |||||
| raise ValueError(f"Unexpected Node {output_node} detected which should not be a graph output.") | |||||
| for graph_inp in graph_inputs: | |||||
| if graph_inp not in self.input_nodes.keys(): | |||||
| raise ValueError(f"{graph_inp} is one of the graph inputs but user does not provide it.") | |||||
| self._global_context.onnx_graph_info['graph_inputs'] = self.input_nodes.keys() | self._global_context.onnx_graph_info['graph_inputs'] = self.input_nodes.keys() | ||||
| self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs | self._global_context.onnx_graph_info['graph_outputs'] = graph_outputs | ||||
| @@ -521,7 +514,7 @@ class OnnxDataLoader: | |||||
| if output_node not in graph_outputs: | if output_node not in graph_outputs: | ||||
| raise ModelLoadingError(f"Unexpected Node {output_node} detected which should not be a graph output.") | raise ModelLoadingError(f"Unexpected Node {output_node} detected which should not be a graph output.") | ||||
| for graph_inp in graph_inputs: | for graph_inp in graph_inputs: | ||||
| if graph_inp not in self.input_nodes.keys(): | |||||
| if graph_inp not in self.input_nodes: | |||||
| raise ModelLoadingError(f"{graph_inp} is one of the graph inputs but user does not provide it.") | raise ModelLoadingError(f"{graph_inp} is one of the graph inputs but user does not provide it.") | ||||
| def initialize(self): | def initialize(self): | ||||