| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Graph based scripts converter definition.""" | """Graph based scripts converter definition.""" | ||||
| __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] | |||||
| from .framework import graph_based_converter_pytorch_to_ms | from .framework import graph_based_converter_pytorch_to_ms | ||||
| from .framework import graph_based_converter_tf_to_ms | from .framework import graph_based_converter_tf_to_ms | ||||
| __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] | |||||
| @@ -13,16 +13,14 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Hierarchical tree module.""" | """Hierarchical tree module.""" | ||||
| __all__ = ["HierarchicalTreeFactory"] | |||||
| import re | import re | ||||
| from mindinsight.mindconverter.common.log import logger as log | from mindinsight.mindconverter.common.log import logger as log | ||||
| from .hierarchical_tree import HierarchicalTree | from .hierarchical_tree import HierarchicalTree | ||||
| from ..third_party_graph.onnx_graph_node import OnnxGraphNode | from ..third_party_graph.onnx_graph_node import OnnxGraphNode | ||||
| __all__ = [ | |||||
| "HierarchicalTreeFactory" | |||||
| ] | |||||
| from ...common.exceptions import NodeInputMissing, TreeNodeInsertFail | from ...common.exceptions import NodeInputMissing, TreeNodeInsertFail | ||||
| @@ -13,8 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Mapper module.""" | """Mapper module.""" | ||||
| from .base import ONNXToMindSporeMapper | |||||
| __all__ = [ | |||||
| "ONNXToMindSporeMapper" | |||||
| ] | |||||
| __all__ = ["ONNXToMindSporeMapper"] | |||||
| from .base import ONNXToMindSporeMapper | |||||
| @@ -577,12 +577,7 @@ class SearchPath: | |||||
| self.graph.precursor_table[s_nd] = p_nodes | self.graph.precursor_table[s_nd] = p_nodes | ||||
| def evaluate_score(self): | def evaluate_score(self): | ||||
| """ | |||||
| Evaluate path score. | |||||
| Expression = 0.7 * (0.1 * bonus + 0.9 * repl_ratio) + 0.3 * H | |||||
| = 0.07 * bonus + 0.63 * repl_ratio + 0.3 * H | |||||
| """ | |||||
| """Evaluate path score.""" | |||||
| return .7 * self.actual_v + .3 * self.heuristic_v | return .7 * self.actual_v + .3 * self.heuristic_v | ||||
| def _cal_merged_module_length(self, ptn): | def _cal_merged_module_length(self, ptn): | ||||
| @@ -13,6 +13,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================== | # ============================================================================== | ||||
| """Graph associated definition module.""" | """Graph associated definition module.""" | ||||
| __all__ = ["GraphFactory", "PyTorchGraphNode"] | |||||
| from .base import Graph | from .base import Graph | ||||
| from .pytorch_graph import PyTorchGraph | from .pytorch_graph import PyTorchGraph | ||||
| from .pytorch_graph_node import PyTorchGraphNode | from .pytorch_graph_node import PyTorchGraphNode | ||||
| @@ -44,9 +47,3 @@ class GraphFactory: | |||||
| output_nodes=output_nodes, sample_shape=sample_shape) | output_nodes=output_nodes, sample_shape=sample_shape) | ||||
| return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape) | return PyTorchGraph.load(model_path=graph_path, sample_shape=sample_shape) | ||||
| __all__ = [ | |||||
| "GraphFactory", | |||||
| "PyTorchGraphNode", | |||||
| ] | |||||