# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define common utils."""
import json
import os
import stat
from importlib import import_module
from importlib.util import find_spec
from typing import List, Tuple, Mapping

import numpy as np

from mindspore.train.serialization import save_checkpoint

from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \
    UnknownModelError, CheckPointGenerationError, WeightMapGenerationError
from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \
    FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX, THIRD_PART_VERSION


def is_converted(operation: str):
    """
    Whether convert successful.

    An operation is considered converted when it is non-empty and no longer
    carries the ONNX operator separator in its name.

    Args:
        operation (str): Operation name.

    Returns:
        bool, true or false.
    """
    return operation and SEPARATOR_IN_ONNX_OP not in operation


def _add_outputs_of_onnx_model(model, output_nodes: List[str]):
    """
    Add output nodes of onnx model.

    Appends each requested node name to the graph's output list so that
    onnxruntime can fetch intermediate activations by name.

    Args:
        model (ModelProto): ONNX model.
        output_nodes (list[str]): Output nodes list.

    Returns:
        ModelProto, edited ONNX model.
    """
    # Imported lazily so the module loads even when onnx is absent.
    onnx = import_module("onnx")
    for opt_name in output_nodes:
        intermediate_layer_value_info = onnx.helper.ValueInfoProto()
        intermediate_layer_value_info.name = opt_name
        model.graph.output.append(intermediate_layer_value_info)
    return model


def check_dependency_integrity(*packages):
    """
    Check dependency package integrity.

    Args:
        packages (str): Importable package names to probe.

    Returns:
        bool, True only if every package imports successfully.
    """
    try:
        for pkg in packages:
            import_module(pkg)
        return True
    except ImportError:
        return False


def build_feed_dict(onnx_model, input_nodes: dict):
    """
    Build feed dict for onnxruntime.

    Random data matching each graph input's declared dtype is generated for
    the shapes supplied by the caller.

    Args:
        onnx_model (ModelProto): ONNX model whose graph inputs declare dtypes.
        input_nodes (dict): Mapping of input node name to its shape tuple.

    Returns:
        dict, node name to randomly generated numpy array.
    """
    dtype_mapping = getattr(import_module("tf2onnx.utils"), "ONNX_TO_NUMPY_DTYPE")
    input_nodes_types = {
        node.name: dtype_mapping[node.type.tensor_type.elem_type]
        for node in onnx_model.graph.input
    }
    feed_dict = {
        name: np.random.rand(*shape).astype(input_nodes_types[name])
        for name, shape in input_nodes.items()
    }
    return feed_dict


def fetch_output_from_onnx_model(model, feed_dict: dict, output_nodes: List[str]):
    """
    Fetch specific nodes output from onnx model.

    Notes:
        Only support to get output without batch dimension.

    Args:
        model (ModelProto): ONNX model.
        feed_dict (dict): Feed forward inputs.
        output_nodes (list[str]): Output nodes list.

    Returns:
        dict, nodes' output value.

    Raises:
        TypeError: If `feed_dict` or `output_nodes` has the wrong type.
    """
    if not isinstance(feed_dict, dict) or not isinstance(output_nodes, list):
        raise TypeError("`feed_dict` should be type of dict, and `output_nodes` "
                        "should be type of List[str].")

    # Expose the requested intermediate nodes as graph outputs first.
    edit_model = _add_outputs_of_onnx_model(model, output_nodes)
    ort = import_module("onnxruntime")
    sess = ort.InferenceSession(path_or_bytes=bytes(edit_model.SerializeToString()))
    fetched_res = sess.run(output_names=output_nodes, input_feed=feed_dict)
    run_result = dict()
    for idx, opt in enumerate(output_nodes):
        run_result[opt] = fetched_res[idx]
    return run_result


def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
                              out_folder: str, report_folder: str):
    """
    Save code file and report.

    Writes, per converted model: the generated script (`<model>.py`), the
    conversion report (`report_of_<model>.txt`), the checkpoint
    (`<model>.ckpt`) and the weight map (`weight_map_of_<model>.json`).
    Files are created exclusively (O_EXCL) with restrictive permissions and
    generation fails if a file with the same name already exists.

    Args:
        model_name (str): Model name.
        code_lines (dict): Code lines.
        out_folder (str): Output folder.
        report_folder (str): Report output folder.

    Raises:
        ScriptGenerationError: If the script cannot be written.
        ReportGenerationError: If the report cannot be written.
        CheckPointGenerationError: If the checkpoint cannot be written.
        WeightMapGenerationError: If the weight map cannot be written.
    """
    flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
    modes = stat.S_IRUSR | stat.S_IWUSR
    modes_usr = stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR

    out_folder = os.path.realpath(out_folder)
    if not report_folder:
        # Default: put the report next to the generated script.
        report_folder = out_folder
    else:
        report_folder = os.path.realpath(report_folder)

    if not os.path.exists(out_folder):
        os.makedirs(out_folder, modes_usr)
    if not os.path.exists(report_folder):
        os.makedirs(report_folder, modes_usr)

    for file_name in code_lines:
        code, report, trainable_weights, weight_map = code_lines[file_name]
        code_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py"))
        report_file_path = os.path.realpath(os.path.join(report_folder, f"report_of_{model_name}.txt"))
        try:
            if os.path.exists(code_file_path):
                raise ScriptGenerationError("Code file with the same name already exists.")
            with os.fdopen(os.open(code_file_path, flags, modes), 'w') as file:
                file.write(code)
        except (IOError, FileExistsError) as error:
            # Chain the OS error so the original traceback is preserved.
            raise ScriptGenerationError(str(error)) from error

        try:
            if os.path.exists(report_file_path):
                raise ReportGenerationError("Report file with the same name already exists.")
            # Report is read-only for the owner once written.
            with os.fdopen(os.open(report_file_path, flags, stat.S_IRUSR), "w") as rpt_f:
                rpt_f.write(report)
        except (IOError, FileExistsError) as error:
            raise ReportGenerationError(str(error)) from error

        ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt"))
        try:
            if os.path.exists(ckpt_file_path):
                raise CheckPointGenerationError("Checkpoint file with the same name already exists.")
            save_checkpoint(trainable_weights, ckpt_file_path)
        except TypeError as error:
            raise CheckPointGenerationError(str(error)) from error

        weight_map_path = os.path.realpath(os.path.join(out_folder, f"weight_map_of_{model_name}.json"))
        try:
            if os.path.exists(weight_map_path):
                raise WeightMapGenerationError("Weight map file with the same name already exists.")
            with os.fdopen(os.open(weight_map_path, flags, stat.S_IRUSR), 'w') as map_f:
                weight_map_json = {f"{model_name}": weight_map}
                json.dump(weight_map_json, map_f)
        except (IOError, FileExistsError) as error:
            raise WeightMapGenerationError(str(error)) from error


def onnx_satisfied():
    """
    Validate ONNX , ONNXRUNTIME, ONNXOPTIMIZER installation.

    Returns:
        bool, True if all three packages are importable.
    """
    if not find_spec("onnx") or not find_spec("onnxruntime") or not find_spec("onnxoptimizer"):
        return False
    return True


def lib_version_satisfied(current_ver: str, mini_ver_limited: str,
                          newest_ver_limited: str = ""):
    """
    Check python lib version whether is satisfied.

    Notes:
        Version number must be format of x.x.x, e.g. 1.1.0.

    Args:
        current_ver (str): Current lib version.
        mini_ver_limited (str): Mini lib version.
        newest_ver_limited (str): Newest lib version.

    Returns:
        bool, true or false.

    Raises:
        ValueError: If any version string is not of format x.x.x with
            numeric components.
    """
    required_version_number_len = 3

    def _version_tuple(ver: str):
        """Parse 'x.x.x' into a tuple of ints for correct numeric comparison."""
        parts = ver.split(".")
        if len(parts) != required_version_number_len:
            raise ValueError("Version number must be format of x.x.x.")
        try:
            return tuple(int(part) for part in parts)
        except ValueError:
            raise ValueError("Version number must be format of x.x.x.") from None

    # Compare numerically, not lexicographically: as strings,
    # "1.10.0" < "1.9.0" would be True, which is wrong.
    current = _version_tuple(current_ver)
    if current < _version_tuple(mini_ver_limited):
        return False
    if newest_ver_limited and current > _version_tuple(newest_ver_limited):
        return False
    return True


def get_dict_key_by_value(val, dic):
    """
    Return the first appeared key of a dictionary by given value.

    Args:
        val (Any): Value of the key.
        dic (dict): Dictionary to be checked.

    Returns:
        Any, key of the given value, or None when no key maps to it.
    """
    for d_key, d_val in dic.items():
        if d_val == val:
            return d_key
    return None


def convert_bytes_string_to_string(bytes_str):
    """
    Convert a byte string to string by utf-8.

    Args:
        bytes_str (bytes): A bytes string.

    Returns:
        str, a str with utf-8 encoding.
    """
    if isinstance(bytes_str, bytes):
        return bytes_str.decode('utf-8')
    # Non-bytes input is passed through unchanged.
    return bytes_str


def get_framework_type(model_path):
    """
    Get framework type.

    `.onnx` files are treated as PyTorch-originated models; otherwise the
    file header is inspected for the PyTorch binary signature, then the
    suffix is checked for a TensorFlow model.

    Args:
        model_path (str): Path of the model file.

    Returns:
        str, a FrameworkType value.

    Raises:
        UnknownModelError: If the model file cannot be read.
    """
    if model_path.endswith('.onnx'):
        return FrameworkType.PYTORCH.value

    try:
        with open(model_path, 'rb') as f:
            if f.read(BINARY_HEADER_PYTORCH_BITS) == BINARY_HEADER_PYTORCH_FILE:
                framework_type = FrameworkType.PYTORCH.value
            elif os.path.basename(model_path).split(".")[-1].lower() == TENSORFLOW_MODEL_SUFFIX:
                framework_type = FrameworkType.TENSORFLOW.value
            else:
                framework_type = FrameworkType.UNKNOWN.value
    except IOError as error:
        error_msg = "Get UNSUPPORTED model."
        unknown_error = UnknownModelError(error_msg)
        log.error(str(unknown_error))
        raise unknown_error from error
    return framework_type


def reset_init_or_construct(template, variable_slot, new_data, scope):
    """
    Reset init statement.

    Clears the existing statements of `scope` under `variable_slot` in the
    template and replaces them with `new_data`.

    Args:
        template (dict): Code template keyed by variable slot.
        variable_slot (str): Slot identifier inside the template.
        new_data (list): Replacement statements.
        scope (str): Section to reset (e.g. init or construct).

    Returns:
        dict, the mutated template.
    """
    template[variable_slot][scope].clear()
    template[variable_slot][scope] += new_data
    return template


def replace_string_in_list(str_list: list, original_str: str, target_str: str):
    """
    Replace a string in a list by provided string.

    Args:
        str_list (list): A list contains the string to be replaced.
        original_str (str): The string to be replaced.
        target_str (str): The replacement of string.

    Returns,
        list, the original list with replaced string.
    """
    return [s.replace(original_str, target_str) for s in str_list]


def get_third_part_lib_validation_error_info(lib_list):
    """
    Get error info when not satisfying third part lib validation.

    Builds a human-readable requirement sentence like
    "onnx(>=1.8.0), tf2onnx(>=1.7.1) and onnxruntime(==1.5.2)".

    Args:
        lib_list (list[str]): Names of required third-party libraries.

    Returns:
        str, the joined requirement description.
    """
    error_info = None
    link_str = ', '
    for idx, lib in enumerate(lib_list):
        if idx == len(lib_list) - 1:
            # Last item is joined with " and " instead of ", ".
            link_str = ' and '
        lib_version_required = THIRD_PART_VERSION[lib]
        if len(lib_version_required) == 2:
            lib_version_required_min = lib_version_required[0]
            lib_version_required_max = lib_version_required[1]
            if lib_version_required_min == lib_version_required_max:
                info = f"{lib}(=={lib_version_required_min})"
            else:
                info = f"{lib}(>={lib_version_required_min} and <{lib_version_required_max})"
        else:
            info = f"{lib}(>={lib_version_required[0]})"
        if not error_info:
            error_info = info
        else:
            error_info = link_str.join((error_info, info))
    return error_info