|
|
|
@@ -23,6 +23,7 @@ from typing import List, Tuple, Mapping |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from mindinsight.mindconverter.common.log import logger as log |
|
|
|
from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \ |
|
|
|
CheckPointGenerationError, WeightMapGenerationError, ModelLoadingError, OnnxModelSaveError |
|
|
|
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, FrameworkType, \ |
|
|
|
@@ -67,7 +68,8 @@ def check_dependency_integrity(*packages): |
|
|
|
for pkg in packages: |
|
|
|
import_module(pkg) |
|
|
|
return True |
|
|
|
except ImportError: |
|
|
|
except ImportError as e: |
|
|
|
log.exception(e) |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
@@ -120,7 +122,6 @@ def fetch_output_from_onnx_model(model, model_path: str, feed_dict: dict, output |
|
|
|
os.remove(tmp_file) |
|
|
|
raise OnnxModelSaveError("Onnx model save failed, {}".format(str(error))) |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
sess = ort.InferenceSession(path_or_bytes=tmp_file) |
|
|
|
fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict) |
|
|
|
|