
!1206 Support params sharing in mindconverter

From: @liuchongming74
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot 4 years ago
parent commit 6c4e1a67b2
9 changed files with 130 additions and 148 deletions
  1. +16  -55  mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  2. +3   -0   mindinsight/mindconverter/graph_based_converter/common/global_context.py
  3. +1   -0   mindinsight/mindconverter/graph_based_converter/constant.py
  4. +0   -23  mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  5. +25  -1   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  6. +38  -8   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  7. +31  -1   mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  8. +1   -57  mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py
  9. +15  -3   mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+16 -55  mindinsight/mindconverter/graph_based_converter/common/code_fragment.py

@@ -163,61 +163,6 @@ class Fragment(abc.ABC):
return self._output_shape




class CodeFragment(Fragment):
"""
Manage the variables related with code generation.

For single operation type node, the variables in `CodeLine` stands for:
```python
class Module(nn.Cell):
def __init__ (self, ...):
super(Module, self).__init__()
self.<CodeLine.declared_variable_name> = <CodeLine.operation>(<CodeLine.scalar_args>,
<CodeLine.init_trainable_params>)
self.<CodeLine.trainable_params[k].param_name> = Tensor(<CodeLine.trainable_params[k].shape>,
dtype=<CodeLine._trainable_params[k].dtype>)

def construct(self, x, ...):
<CodeLine.output_var_name> = self.<CodeLine.declared_variable_name>(<CodeLine.operation_inputs>)
...
return output
```

Args:
operation (str): Operation name in MindSpore.
actual_args (dict): Actual arg values.
settings (namedTuple): Code generation setting.

"""

def __init__(self, operation, actual_args, settings, input_shape, output_shape,
trainable_params=None, trainable_weights=None):
super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape,
settings=settings)
self._trainable_params = dict() # External weights, like Matmul.
self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d.
self._trainable_weights = trainable_weights

@property
def trainable_params(self):
"""Return the trainable parameters."""
return self._trainable_params

@property
def trainable_weights(self):
return self._trainable_weights


class ModuleFragment(Fragment):
"""Manage module type code variables."""

def __init__(self, operation, actual_args, settings, input_shape, output_shape):
super(ModuleFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape,
settings=settings)


class NewFragment:
"""
Fragment definition for MindSpore code generation.
@@ -310,6 +255,12 @@ class NewFragment:
return f"{opt}[{inner_idx}]" return f"{opt}[{inner_idx}]"
return opt return opt


@staticmethod
def create_parameter(weight_shape, weight_dtype):
"""Create a parameter code line."""
return f"Parameter(Tensor(np.random.uniform(0, 1, {weight_shape}).astype(np.{weight_dtype})), " \
f"name=None)"

def __call__(self) -> Tuple[List[str], List[str]]:
"""
Define parameter rewrite function.
@@ -334,6 +285,10 @@ class NewFragment:
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value)
return init_stats, call_stats


def register_parameter(self, var, line):
"""Append a new parameter into template."""
self._code_template[var][TemplateKeywords.INIT.value].append(line)

@staticmethod
def _rewrite(var, data, template: str) -> str:
"""
@@ -353,6 +308,12 @@ class NewFragment:
data[ExchangeMessageKeywords.VariableScope.value.INPUTS.value])
if ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value in data:
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value])
if ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value in data:
rewrite_params = {
f"{var}/{slot}": data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value][slot]
for slot in data[ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value]
}
rewrite_data.update(rewrite_params)
rewrite_data.update(data[ExchangeMessageKeywords.VariableScope.value.ARGS.value])
return template.format(**{
k: str(rewrite_data[k]) for k in rewrite_data
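
With these pieces in place, a shared weight becomes an explicit Parameter declaration in the generated network. A minimal sketch of what the new helpers produce, assuming illustrative shapes and variable names that are not taken from this commit:

    # create_parameter only renders a code string; no tensor is allocated here.
    line = NewFragment.create_parameter((1, 3, 224, 224), "float32")
    # line == 'Parameter(Tensor(np.random.uniform(0, 1, (1, 3, 224, 224)).astype(np.float32)), name=None)'

    # register_parameter appends the declaration to the fragment's init template;
    # _rewrite later resolves shared slots through the "parameters" section of the
    # exchanged data, keyed as "{var}/{slot}". The fragment and attribute name
    # below are hypothetical.
    fragment.register_parameter("conv1", f"self.conv1_weight = {line}")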


+3 -0  mindinsight/mindconverter/graph_based_converter/common/global_context.py

@@ -81,6 +81,9 @@ class GlobalContext(metaclass=Singleton):


self.outputs_storage = OutputStorage()


# Record weights name that used many times.
self.repeated_weights = dict()

def get_onnx_node_from_identifier(self, identifier):
"""Return an OnnxUtils defined node by its identifier."""
onnx_node_name = self.node_struct_to_onnx_node_map.get(identifier)


+1 -0  mindinsight/mindconverter/graph_based_converter/constant.py

@@ -107,6 +107,7 @@ class ExchangeMessageKeywords(Enum):
ARGS = "args" ARGS = "args"
WEIGHTS = "weights" WEIGHTS = "weights"
TRAINABLE_PARAMS = "trainable_params" TRAINABLE_PARAMS = "trainable_params"
PARAMETERS_DECLARED = "parameters"




BINARY_HEADER_PYTORCH_FILE = \


+0 -23  mindinsight/mindconverter/graph_based_converter/generator/module_struct.py

@@ -21,7 +21,6 @@ from mindinsight.mindconverter.graph_based_converter.generator.node_struct impor
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr


@@ -184,28 +183,6 @@ class ModuleStruct:
"initialized": self.initialized "initialized": self.initialized
}) })


def init_module_fragment(self):
"""Init the module fragment."""
if not self.initialized:
return
# check if fragment exists in global context
op = "Module{}".format(self.pattern_id)
if op == "Module-1": # reset as Main Model's op name
op = "Model"
frag = GlobalContext().get_module_fragment(op)
if frag is not None: # use exists fragment
self._fragment = frag
else:
frag = ModuleFragment(operation=op,
actual_args=None,
input_shape=None,
output_shape=None,
settings=None)
self._fragment = frag
# set fragment pattern
self._fragment.pattern = self._node_structs
GlobalContext().add_module_fragment(op, frag)

def init_args_translator(self):
"""Initialize the Args Translator for the module."""
var_name = self.ms_var_name


+25 -1  mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py

@@ -14,7 +14,10 @@
# ==============================================================================
"""Introduce some standard pattern into MindConverter."""


__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]
__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern",
"USER_DEFINED_PATTERN", "user_defined_pattern"]

from collections import OrderedDict


from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import register_module_name


@@ -22,6 +25,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern


BUILT_IN_PATTERN = dict()
USER_DEFINED_PATTERN = OrderedDict()




def is_built_in_pattern(pattern: Pattern):
@@ -75,6 +79,26 @@ def register_pattern(ptn_name, in_degree, out_degree):
return _reg




def user_defined_pattern(pattern_name: str):
"""
Register user define pattern to MindConverter.

Args:
pattern_name (str): Pattern name.
"""

def _f(ptn):
pattern = ptn()
if not pattern:
raise ValueError("`ptn` cannot be None.")
if not pattern_name:
raise ValueError("`pattern_name` cannot be None.")
USER_DEFINED_PATTERN[pattern_name] = pattern
return ptn

return _f


@register_pattern("ConvBnClip", 1, 1) @register_pattern("ConvBnClip", 1, 1)
def _conv_bn_clip(): def _conv_bn_clip():
"""Add conv-bn-clip pattern.""" """Add conv-bn-clip pattern."""


+38 -8  mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py

@@ -424,7 +424,18 @@ def _post_process_overlap(patterns) -> Dict:
return patterns




class SearchPath:
class BasePath:
"""Base class of SearchPath (auto-search) and ReplacePath (greedy-match)."""

def __init__(self, pattern, sequence: List, prev_path=None):
self.pattern = pattern
self.recursion_path = prev_path.recursion_path[:] if prev_path is not None else list()
if prev_path is not None:
self.recursion_path.append(prev_path)
self.topo_order_bef_repl = sequence


class SearchPath(BasePath):
""" """
Use SearchPath to store the search path. Use SearchPath to store the search path.


@@ -439,15 +450,9 @@ class SearchPath:


def __init__(self, pattern, sequence: List[BaseNode], prev_path=None,
graph=None, sub_graph_size: int = 2):
self.pattern = pattern
super(SearchPath, self).__init__(pattern, sequence, prev_path)
self.graph = copy.copy(prev_path.graph) if prev_path is not None \
else copy.copy(graph)
self.recursion_path = prev_path.recursion_path[:] \
if prev_path is not None else list()
if prev_path is not None:
self.recursion_path.append(prev_path)

self.topo_order_bef_repl = sequence
self.topo_order_aft_repl, self.inverted_index = self._create_new_order()
self.node_collection = dict()
self.hash_of_aft_repl = gen_hash_key(self.topo_order_aft_repl)
@@ -689,3 +694,28 @@ class SearchPath:
f"H: {self.heuristic_v}, G: {self.actual_v}, E: {self.evaluate_score()}" f"H: {self.heuristic_v}, G: {self.actual_v}, E: {self.evaluate_score()}"


return repr_str


class ReplacePath(BasePath):
"""Data struct of replacing path with greedy matching."""

def __init__(self, pattern, sequence: List, prev_path=None):
super(ReplacePath, self).__init__(pattern, sequence, prev_path)
self.topo_order_aft_repl = None

def replace(self, increment_idx):
"""
Greedy matching.

Args:
increment_idx (int): To deduplicate module name.
"""
src = ",".join(self.topo_order_bef_repl)
tgt = self.pattern.pattern
md_name = f"Module{increment_idx}"
src_aft_repl = src.replace(tgt, md_name)
if src != src_aft_repl:
self.pattern.module_name = md_name
self.topo_order_aft_repl = src_aft_repl.split(",")
return md_name
return None
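
Since the match in `replace` is purely string-based, its effect can be reproduced standalone; the node names below are illustrative:

    # Mirrors ReplacePath.replace: join the topological order into one string,
    # substitute every occurrence of the pattern, then split back into a sequence.
    topo_order = ["Conv", "BatchNormalization", "Relu", "Conv"]
    src = ",".join(topo_order)  # "Conv,BatchNormalization,Relu,Conv"
    src_aft_repl = src.replace("Conv,BatchNormalization,Relu", "Module0")
    print(src_aft_repl.split(","))  # ['Module0', 'Conv']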

+31 -1  mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py

@@ -16,6 +16,7 @@
from queue import PriorityQueue, Queue
from typing import Dict, List


from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import USER_DEFINED_PATTERN
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
pattern_fuzzy_matching
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
@@ -25,7 +26,7 @@ from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common i
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
generate_pattern, find_built_in_pattern
generate_pattern, find_built_in_pattern, ReplacePath
from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError




@@ -371,6 +372,7 @@ def _build_connection(loader):
context.precursor_table[node_name] = list(node.get_precursor_dict().keys())
context.successor_table[node_name] = list(node.get_successor_dict().keys())
context.outputs_table[node_name] = node.output_name_list

# Record the model inputs count, use it to control the search algorithm.
context.has_multi_inputs = len(loader.input_nodes) > 1
dag = DagGraph(nodes=context.node_collection.copy(),
@@ -426,6 +428,28 @@ def _add_known_module_name(search_path):
return ctx




def greedy_match(topo_order, user_defined_ptn):
"""
Greedy replace topological order with given pattern by user.

Args:
topo_order (list[str]): Topological order sequence.
user_defined_ptn (dict): User defined pattern.
"""
increment_idx = 0
prev_path = None
for md_name, ptn in user_defined_ptn:
ptn = Pattern(",".join(ptn), len(ptn), -1, -1, ptn)
ptn.known_module_name = md_name
topo_order_aft_rpl = topo_order[:] if prev_path is None else prev_path.topo_order_aft_repl
repl_path = ReplacePath(ptn, topo_order_aft_rpl, prev_path=prev_path)
module_name = repl_path.replace(increment_idx)
if module_name is not None:
increment_idx += 1
prev_path = repl_path
return prev_path
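
Because the loop unpacks (module_name, op_sequence) pairs, a usage sketch can pass the patterns as a list of tuples; the names below are illustrative, not from this commit:

    # Both occurrences of the pattern are folded in one pass, since the
    # underlying str.replace substitutes every match; increment_idx keeps
    # generated module names unique across different patterns.
    topo = ["Conv", "BatchNormalization", "Relu", "Conv", "BatchNormalization", "Relu"]
    path = greedy_match(topo, [("ConvBnRelu", ["Conv", "BatchNormalization", "Relu"])])
    # path.topo_order_aft_repl == ["Module0", "Module0"]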


@SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.") @SubGraphSearchingError.check_except("Sub-Graph pattern searching fail.")
def generate_scope_name(data_loader): def generate_scope_name(data_loader):
""" """
@@ -439,6 +463,12 @@ def generate_scope_name(data_loader):
""" """
init_dag = _build_connection(data_loader) init_dag = _build_connection(data_loader)
try: try:
if USER_DEFINED_PATTERN:
topo_order = [node for _, node in context.node_collection.items()]
repl_path = greedy_match(topo_order, USER_DEFINED_PATTERN)
topo_order_with_scope_name_list = _retrieve_scope_name(repl_path) if repl_path else flatten_graph(init_dag)
return topo_order_with_scope_name_list

result = _sub_graph_matching(init_dag, beam_width=5, sub_graph_size=6)
topo_order_with_scope_name_list = _retrieve_scope_name(result) if result else flatten_graph(init_dag)




+1 -57  mindinsight/mindconverter/graph_based_converter/third_party_graph/base.py

@@ -15,14 +15,11 @@
"""Define graph entity.""" """Define graph entity."""
import abc import abc
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy


from typing import List


from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import CodeFragment
from mindinsight.mindconverter.graph_based_converter.constant import NodeType, InputType
from mindinsight.mindconverter.graph_based_converter.mapper.base import Mapper
from mindinsight.mindconverter.graph_based_converter.constant import InputType
from mindinsight.mindconverter.common.exceptions import NodeInputTypeNotSupportError




@@ -574,56 +571,3 @@ class GraphNode(abc.ABC):
ipt_args_settings_in_construct = ', '.join((ipt_args_settings_in_construct, settings_in_construct))


return ipt_args_settings_in_construct

def param_transform(self, mapper: Mapper, variable_name):
"""
Transform param in PyTorch operation into MindSpore.

Args:
variable_name (str): Variable name.
mapper (ONNXToMindSporeMapper): Mapper between onnx operation
and MindSpore.

Returns:
dict, transformed params.
"""
if self._node_type != NodeType.OPERATION.value:
args = deepcopy(self._args_in_code)
self._args_in_code = dict()
for arg, value in args.items():
self._args_in_code[self._get_arg_name(arg, variable_name)] = value
return CodeFragment(operation="", actual_args=args, settings=None,
input_shape=self.input_shape, output_shape=self.output_shape)

if self.transformed:
raise ValueError("Already transformed.")

params = deepcopy(self._op_params)
params.update({"input_shape": self.input_shape,
"output_shape": self.output_shape})

ms_op, ms_params, ms_settings, ms_weights = mapper.convert(op_name=self.op_name,
params=params,
weights=self._weight)

if ms_op:
code_fragment = CodeFragment(operation=ms_op,
actual_args=ms_params,
settings=ms_settings,
input_shape=self.input_shape,
output_shape=self.output_shape,
trainable_params=ms_weights)
else:
code_fragment = CodeFragment(operation=self._op_name,
actual_args=self._op_params,
settings=None,
input_shape=self.input_shape,
output_shape=self.output_shape,
trainable_params=self._weight)

for arg, value in code_fragment.actual_args.items():
self._args_in_code[self._get_arg_name(arg, variable_name)] = value

self.transformed = True

return code_fragment

+15 -3  mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

@@ -262,7 +262,6 @@ class OnnxDataLoader:
self.model = onnx_model_sim
self.graph = onnx_model_sim.graph
self.nodes = onnx_model_sim.graph.node
self.batch_size = list(input_nodes.values())[0][0]
self.input_nodes = input_nodes
self.output_nodes = output_nodes
# args for init
@@ -275,6 +274,9 @@ class OnnxDataLoader:
self.tensors_dict = {} # {tensor_name: OnnxTensor}
self.value_info_dict = {} # Not contains input and output nodes


# Record the weight names used many times.
self.repeated_weight = dict()

self.node_output_shape_dict = OrderedDict() # {node_name: [int]}


# Key is edge of ONNX ir graph, value is the corresponding precursor node.
@@ -393,6 +395,7 @@ class OnnxDataLoader:
def _parse_nodes(self):
"""Parse each onnx nodes in the model."""
nodes_topo_idx = []
record_tensors = dict()
for idx, node in enumerate(self.nodes):
if not node.name:
node.name = "_".join(node.output)
@@ -408,11 +411,21 @@ class OnnxDataLoader:
self._global_context.onnx_node_inputs[n.name].append(ipt_nd)
else:
self._global_context.onnx_node_inputs[n.name] = [ipt_nd]
if ipt_nd in self.tensors_dict:
if ipt_nd not in record_tensors:
record_tensors[ipt_nd] = [node.name]
continue
record_tensors[ipt_nd].append(node.name)
self.repeated_weight.setdefault(ipt_nd, [])


self._global_context.onnx_node_name_to_topo_idx[n.name] = idx


for k in self.repeated_weight:
self.repeated_weight[k] = record_tensors[k][:]

self._global_context.onnx_nodes_collection = self._nodes_dict
self._global_context.onnx_nodes_topo_index = nodes_topo_idx
self._global_context.repeated_weights = self.repeated_weight


def _parse_tensors(self):
"""Parse each onnx tensors in the model."""
@@ -436,8 +449,7 @@ class OnnxDataLoader:
for i, s in enumerate(shape):
if 'unk' in s:
# Have to adapt user-define axis name, e.g. 'sequence', 'batch'.
shape[i] = self.batch_size if self.batch_size is not None else 1
continue
raise ValueError(f"cannot get shape of {node_opt_name}.")
if s == "scalar": if s == "scalar":
shape = SCALAR_WITHOUT_SHAPE shape = SCALAR_WITHOUT_SHAPE
continue continue
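
The bookkeeping above reduces to one rule: an initializer consumed by more than one node is a shared weight. A self-contained sketch of the same idea, using an illustrative two-node graph rather than data from this commit:

    # Map each weight tensor to the nodes that consume it; tensors with more
    # than one consumer become entries of repeated_weight (params sharing).
    tensors_dict = {"fc.weight"}  # stand-in for the parsed initializer names
    node_inputs = {
        "MatMul_0": ["x", "fc.weight"],
        "MatMul_1": ["h", "fc.weight"],
    }
    record_tensors = {}
    for node_name, inputs in node_inputs.items():
        for ipt in inputs:
            if ipt in tensors_dict:
                record_tensors.setdefault(ipt, []).append(node_name)
    repeated_weight = {w: users for w, users in record_tensors.items() if len(users) > 1}
    print(repeated_weight)  # {'fc.weight': ['MatMul_0', 'MatMul_1']}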

