# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define basic classes for generator use."""
import abc
import copy
from typing import Union, Iterable


class BaseOutput:
    """
    Define the class of output providing a universal nodes' and modules' output data collection.

    Args:
        output_mapping (tuple[tuple]): The mapping of outputs from onnx to mindspore.
            Element 0 is stored as the MindSpore provider index and element 1
            as the ONNX provider index.
    """

    def __init__(self, output_mapping) -> None:
        super().__init__()
        self.idx_in_ms_provider = output_mapping[0]
        self.idx_in_onnx_provider = output_mapping[1]

        # For multi users, key as user and value as index
        self.idx_in_ms_user = {}
        self.idx_in_onnx_user = {}

        # The following attributes to be set by who referenced this object.
        self.onnx_edge_name = None
        self.to_external = False

    @property
    def ms_user(self):
        """Return the output's users in the MindSpore (keys view of `idx_in_ms_user`)."""
        return self.idx_in_ms_user.keys()

    @property
    def onnx_user(self):
        """Return the output's users in the ONNX (keys view of `idx_in_onnx_user`)."""
        return self.idx_in_onnx_user.keys()

    def deepcopy(self):
        """Return a deepcopy of self instance."""
        return copy.deepcopy(self)


class BaseOutputManager(abc.ABC):
    """
    Base Output Manager class.

    Args:
        output_mappings (list, optional): A list of output mapping. Each mapping is
            wrapped in a `BaseOutput`. `None` is treated as "no mappings" so that
            subclasses may fill the output list themselves.
    """

    def __init__(self, output_mappings):
        # A missing mapping list yields an empty manager instead of failing on
        # iteration over None; subclasses may populate the list afterwards.
        if output_mappings is None:
            output_mappings = ()
        self._base_output_list = [BaseOutput(mapping) for mapping in output_mappings]

    @property
    def outputs(self):
        """Return the list of BaseOutput in this manager."""
        return self._base_output_list

    @outputs.setter
    def outputs(self, val: list):
        """
        Set the list of BaseOutput in this manager.

        Raises:
            TypeError: If any element of `val` is not a BaseOutput instance.
        """
        for v in val:
            if not isinstance(v, BaseOutput):
                raise TypeError(f"{self.__class__} does not accept the type {type(v)} in the list given.")
        self._base_output_list = val

    @abc.abstractmethod
    def deepcopy(self):
        """
        Return the deepcopy of this instance.

        The copy is created without running `__init__`; only the output list is
        duplicated here. Subclasses call this via `super()` and then copy their
        own attributes.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        result.outputs = [out.deepcopy() for out in self._base_output_list]
        return result


class NodeOutputManager(BaseOutputManager):
    """
    Node Output Manager class.

    Args:
        identifier (str): The identifier of the node.
        output_mappings (list, optional): A list of the output mapping. Default: None,
            which creates a manager without outputs.
    """

    def __init__(self, identifier, output_mappings=None) -> None:
        super().__init__(output_mappings)
        self.identifier = identifier

    def deepcopy(self):
        """Return a deepcopy of current instance, keeping the same identifier."""
        new_mgr = super().deepcopy()
        new_mgr.identifier = self.identifier
        return new_mgr


class ModuleOutputManager(BaseOutputManager):
    """
    Module Output Manager class.

    Args:
        identifier (str): The identifier of the module.
        base_out (Union[BaseOutput, Iterable[BaseOutput]]): A single BaseOutput or an
            iterable of BaseOutput instances managed by this module.
    """

    def __init__(self, identifier, base_out: Union[BaseOutput, Iterable[BaseOutput]]) -> None:
        # The module manager receives ready-made BaseOutput objects, so no
        # mappings are handed to the base class.
        super().__init__(None)
        self.identifier = identifier
        self._return_list_counter = 0
        if isinstance(base_out, BaseOutput):
            self._base_output_list = [base_out]
        else:
            self._base_output_list = list(base_out)

    @property
    def return_num(self):
        """Return the number of outputs to be returned."""
        return self._return_list_counter

    @return_num.setter
    def return_num(self, num: int):
        """Set the number of outputs to be returned."""
        self._return_list_counter = num

    def deepcopy(self):
        """Return a deepcopy of current instance, keeping identifier and return number."""
        new_mgr = super().deepcopy()
        new_mgr.identifier = self.identifier
        new_mgr.return_num = self._return_list_counter
        return new_mgr


class OutputStorage:
    """A class saves all outputs."""

    def __init__(self):
        # All three dicts are keyed by the ONNX edge name of an output.
        self._base_output_edge_to_instance = {}
        self._base_output_edge_to_onnx_node_name = {}
        self._base_output_edge_to_ms_identifier = {}

    @property
    def outputs_collections(self) -> dict:
        """Return the dict of edge name to output instance."""
        return self._base_output_edge_to_instance

    def onnx_name(self, output_edge) -> str:
        """Return the onnx node name registered for the edge, or None if unknown."""
        return self._base_output_edge_to_onnx_node_name.get(output_edge)

    def node_identifier(self, output_edge):
        """Return the node identifier registered for the edge, or None if unknown."""
        return self._base_output_edge_to_ms_identifier.get(output_edge)

    def add_output(self, out: BaseOutput) -> None:
        """
        Add a BaseOutput instance to the storage.

        Args:
            out (BaseOutput): The BaseOutput instance.

        Raises:
            ValueError: If the instance has no (or an empty) `onnx_edge_name`.
        """
        if not out.onnx_edge_name:
            raise ValueError("Unable to add a BaseOutput instance with unknown ONNX edge.")
        self._base_output_edge_to_instance[out.onnx_edge_name] = out

    def add_onnx_node_name(self, edge: str, onnx_node_name: str):
        """
        Add the onnx node name with the edge name.

        Args:
            edge (str): The edge name of this output.
            onnx_node_name (str): The onnx node which has the edge.
        """
        self._base_output_edge_to_onnx_node_name[edge] = onnx_node_name

    def add_ms_identifier(self, edge: str, ms_identifier: str):
        """
        Add the node identifier with the edge name.

        Args:
            edge (str): The edge name of this output.
            ms_identifier (str): The identifier of the node which has the edge.
        """
        self._base_output_edge_to_ms_identifier[edge] = ms_identifier