|
|
|
@@ -14,6 +14,7 @@ |
|
|
|
# ============================================================================== |
|
|
|
"""Define PyTorch graph.""" |
|
|
|
import re |
|
|
|
from copy import deepcopy |
|
|
|
from typing import Dict, NoReturn |
|
|
|
|
|
|
|
from mindinsight.mindconverter.common.log import logger as log |
|
|
|
@@ -22,7 +23,8 @@ from .input_node import InputNode |
|
|
|
from .pytorch_graph_node import PyTorchGraphNode |
|
|
|
from .pytorch_graph_parser import PyTorchGraphParser |
|
|
|
|
|
|
|
from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID |
|
|
|
from ..constant import SEPARATOR_IN_SCOPE, LINK_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, SCALAR_WITHOUT_SHAPE, \ |
|
|
|
MIN_SCOPE_LENGTH |
|
|
|
from ..constant import LEFT_BUCKET, RIGHT_BUCKET |
|
|
|
|
|
|
|
NONE_SCOPE_OP = { |
|
|
|
@@ -206,6 +208,8 @@ class PyTorchGraph(Graph): |
|
|
|
) |
|
|
|
self.build_connection(node_input_name, node_name) |
|
|
|
|
|
|
|
self._unmerge_multi_ipt_opt_script() |
|
|
|
|
|
|
|
super(PyTorchGraph, self).build(input_shape=input_shape) |
|
|
|
self._collect_ipt_shape_of_each_node(feed_forward_ipt_shape) |
|
|
|
|
|
|
|
@@ -227,13 +231,84 @@ class PyTorchGraph(Graph): |
|
|
|
input_node.set_successor_nodes(node_name) |
|
|
|
self._shape_dict[ipt_nd_name] = input_node.output_shape |
|
|
|
|
|
|
|
if not self._shape_dict[node_name]: |
|
|
|
self._shape_dict[node_name] = SCALAR_WITHOUT_SHAPE |
|
|
|
|
|
|
|
ipt_shape = [] |
|
|
|
for p_nd in node.precursor_nodes: |
|
|
|
shp = self._shape_dict.get(p_nd) |
|
|
|
ipt_shape.append(tuple(shp)) |
|
|
|
ipt_shape.append(tuple(shp) if isinstance(shp, list) else shp) |
|
|
|
|
|
|
|
self._input_shape[node_name] = ipt_shape[0] if len(ipt_shape) == 1 else ipt_shape |
|
|
|
|
|
|
|
def _generate_module(self):
    """
    Group node names under every scope prefix that encloses them.

    Each key of `self._nodes_collection` is split on SEPARATOR_IN_SCOPE.
    For every proper prefix of the key, the name that is one scope level
    deeper than that prefix is recorded as a member of the prefix.

    Returns:
        dict, maps a scope prefix to the set of names directly below it.
    """
    children_of_scope = dict()
    for full_key, _ in self._nodes_collection.items():
        scope_parts = full_key.split(SEPARATOR_IN_SCOPE)
        # Keys with fewer than MIN_SCOPE_LENGTH scope levels are skipped.
        if len(scope_parts) < MIN_SCOPE_LENGTH:
            continue

        for depth in range(1, len(scope_parts)):
            parent_scope = SEPARATOR_IN_SCOPE.join(scope_parts[:depth])
            child_name = SEPARATOR_IN_SCOPE.join(scope_parts[:depth + 1])
            # Create the member set on first sight of the prefix, then add.
            children_of_scope.setdefault(parent_scope, set()).add(child_name)

    return children_of_scope
|
|
|
|
|
|
|
def _check_multi_ipt(self):
    """
    Check whether any module is fed by more than one outside node.

    For every module produced by `_generate_module`, the precursors of its
    member nodes are inspected. A precursor counts as coming from outside
    the module when neither the precursor itself nor its enclosing scope
    (its name with the last scope segment removed) is a member of the
    module.

    Returns:
        bool, True as soon as one module has more than one distinct outside
        precursor, otherwise False.
    """
    for members in self._generate_module().values():
        outside_precursors = set()
        for member_name in members:
            member = self._nodes_collection.get(member_name, None)
            # Member names without a usable node in the collection are skipped.
            if not member:
                continue

            for precursor in member.precursor_nodes:
                if precursor in members:
                    continue
                precursor_scope = SEPARATOR_IN_SCOPE.join(
                    precursor.split(SEPARATOR_IN_SCOPE)[:-1]
                )
                if precursor_scope not in members:
                    outside_precursors.add(precursor)

        if len(outside_precursors) > 1:
            return True

    return False
|
|
|
|
|
|
|
def _unmerge_multi_ipt_opt_script(self):
    """
    Flatten every scoped name when a multi-input module is detected.

    When `_check_multi_ipt` reports True, each node key, each precursor and
    successor name, and each key of `self._shape_dict` is rewritten to keep
    only its first and last scope segments, joined by SEPARATOR_IN_SCOPE.
    Nothing is changed when the check reports False.
    """
    if not self._check_multi_ipt():
        return

    def _flatten(scoped_name):
        """Keep only the first and the last scope segment of a name."""
        segments = scoped_name.split(SEPARATOR_IN_SCOPE)
        return SEPARATOR_IN_SCOPE.join((segments[0], segments[-1]))

    # Iterate over a deep copy so the live collection can be re-keyed safely;
    # the copied node instances are the ones stored back under the new keys.
    for old_key, node in deepcopy(self._nodes_collection).items():
        precursors = node.precursor_nodes
        successors = node.successor_nodes

        node.precursor_nodes = [_flatten(name) for name in deepcopy(precursors)]
        node.successor_nodes = [_flatten(name) for name in deepcopy(successors)]

        flat_key = _flatten(old_key)

        del self._nodes_collection[old_key]
        self._nodes_collection[flat_key] = node

    # NOTE(review): the shape-dict rewrite is assumed to be guarded by the
    # same multi-input check as the node rewrite - confirm.
    for old_key, shape in deepcopy(self._shape_dict).items():
        flat_key = _flatten(old_key)

        del self._shape_dict[old_key]
        self._shape_dict[flat_key] = shape
|
|
|
|
|
|
|
def sub_graph_merging(self): |
|
|
|
""" |
|
|
|
Merge split operation into one. |
|
|
|
|