diff --git a/mindinsight/mindconverter/graph_based_converter/framework.py b/mindinsight/mindconverter/graph_based_converter/framework.py index e1d822f8..9364cd14 100644 --- a/mindinsight/mindconverter/graph_based_converter/framework.py +++ b/mindinsight/mindconverter/graph_based_converter/framework.py @@ -147,8 +147,8 @@ def tf_installation_validation(func): type, inner function. """ - def _f(graph_path: str, sample_shape: tuple, output_folder: str, report_folder: str = None, - input_nodes: str = None, output_nodes: str = None): + def _f(graph_path: str, input_nodes: dict, output_nodes: List[str], + output_folder: str, report_folder: str): not_integral_error = RuntimeIntegrityError( f"TensorFlow, " f"{get_third_part_lib_validation_error_info(['tf2onnx', 'onnx', 'onnxruntime', 'onnxoptimizer'])} " @@ -171,9 +171,9 @@ def tf_installation_validation(func): _print_error(not_integral_error) sys.exit(0) - func(graph_path=graph_path, sample_shape=sample_shape, - output_folder=output_folder, report_folder=report_folder, - input_nodes=input_nodes, output_nodes=output_nodes) + func(graph_path=graph_path, + input_nodes=input_nodes, output_nodes=output_nodes, + output_folder=output_folder, report_folder=report_folder) return _f