|
|
|
@@ -32,6 +32,7 @@ from mindinsight.mindconverter.common.exceptions import GraphInitError, SourceFi |
|
|
|
BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError, \ |
|
|
|
BadParamError |
|
|
|
from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
check_common_dependency_integrity = partial(check_dependency_integrity, |
|
|
|
"onnx", "onnxruntime", "onnxoptimizer") |
|
|
|
@@ -175,11 +176,28 @@ def graph_based_converter_onnx_to_ms(graph_path: str, |
|
|
|
output_folder (str): Output folder. |
|
|
|
report_folder (str): Report output folder path. |
|
|
|
""" |
|
|
|
graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes) |
|
|
|
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) |
|
|
|
model_name = _extract_model_name(graph_path) |
|
|
|
code_fragments = generator_inst.generate() |
|
|
|
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) |
|
|
|
progress = tqdm(range(3), leave=True, dynamic_ncols=True) |
|
|
|
# The father bootstrap of onnx_to_ms progress begins. |
|
|
|
for i in progress: |
|
|
|
if i == 0: |
|
|
|
progress.set_description('MindConverter Starting, it may take a moment') |
|
|
|
graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes) |
|
|
|
child_process = tqdm(range(1), desc="{:17}".format("Graph Parsing")) |
|
|
|
# The child bootstrap of graph parsing progress begins. |
|
|
|
for _ in child_process: |
|
|
|
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) |
|
|
|
elif i == 1: |
|
|
|
progress.set_description("{:17}".format("Generating codes")) |
|
|
|
child_process = tqdm(range(2), desc="{:17}".format("Generating codes")) |
|
|
|
# The child bootstrap of generating codes progress begins. |
|
|
|
for j in child_process: |
|
|
|
if j == 0: |
|
|
|
model_name = _extract_model_name(graph_path) |
|
|
|
else: |
|
|
|
code_fragments = generator_inst.generate() |
|
|
|
else: |
|
|
|
progress.set_description("{:17}".format("Saving code file")) |
|
|
|
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) |
|
|
|
# Release global context. |
|
|
|
GlobalContext.release() |
|
|
|
|
|
|
|
@@ -204,12 +222,29 @@ def graph_based_converter_tf_to_ms(graph_path: str, |
|
|
|
""" |
|
|
|
# Close unnecessary log. |
|
|
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
|
|
|
|
|
log_console.info("MindConverter Starting, it may take a moment") |
|
|
|
graph_obj = GraphFactory.init(graph_path, input_nodes=input_nodes, output_nodes=output_nodes) |
|
|
|
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) |
|
|
|
model_name = _extract_model_name(graph_path) |
|
|
|
code_fragments = generator_inst.generate() |
|
|
|
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) |
|
|
|
progress = tqdm(range(3), leave=True, dynamic_ncols=True) |
|
|
|
# The father bootstrap of tf_to_ms progress begins. |
|
|
|
for i in progress: |
|
|
|
if i == 0: |
|
|
|
progress.set_description("{:17}".format("Graph Parsing")) |
|
|
|
child_process = tqdm(range(1), desc="{:17}".format("Graph Parsing")) |
|
|
|
# The child bootstrap of graph parsing progress begins. |
|
|
|
for _ in child_process: |
|
|
|
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) |
|
|
|
elif i == 1: |
|
|
|
progress.set_description("{:17}".format("Generating codes")) |
|
|
|
child_process = tqdm(range(2), desc="{:17}".format("Generating codes")) |
|
|
|
# The child bootstrap of generating codes progress begins. |
|
|
|
for j in child_process: |
|
|
|
if j == 0: |
|
|
|
model_name = _extract_model_name(graph_path) |
|
|
|
else: |
|
|
|
code_fragments = generator_inst.generate() |
|
|
|
else: |
|
|
|
progress.set_description("{:17}".format("Saving code file")) |
|
|
|
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) |
|
|
|
# Release global context. |
|
|
|
GlobalContext.release() |
|
|
|
|
|
|
|
|