Browse Source

!1318 add graph domain proto

From: @liangyongxiong1024
Reviewed-by: 
Signed-off-by:
pull/1318/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
b728d0d24d
7 changed files with 2330 additions and 0 deletions
  1. +14
    -0
      mindinsight/domain/__init__.py
  2. +14
    -0
      mindinsight/domain/graph/__init__.py
  3. +525
    -0
      mindinsight/domain/graph/base.py
  4. +45
    -0
      mindinsight/domain/graph/exceptions.py
  5. +325
    -0
      mindinsight/domain/graph/proto/ms_graph.proto
  6. +1400
    -0
      mindinsight/domain/graph/proto/ms_graph_pb2.py
  7. +7
    -0
      mindinsight/utils/constant.py

+ 14
- 0
mindinsight/domain/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+ 14
- 0
mindinsight/domain/graph/__init__.py View File

@@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

+ 525
- 0
mindinsight/domain/graph/base.py View File

@@ -0,0 +1,525 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base module."""

import os
import re
import enum
import collections

import numpy as np

from mindinsight.domain.graph.exceptions import UnknownTensorError


class MindSporeType(enum.Enum):
"""MindSpore Type."""
INT = 'int32'
UINT = 'uint'
FLOAT = 'float16'
TENSOR = 'tensor'


class DeviceType(enum.Enum):
"""Device Type."""
ASCEND = 'ascend'
GPU = 'gpu'


class DumpType(enum.Enum):
"""Dump Type."""
E2E = 'e2e'
ASYNC = 'async'


class Tensor:
"""
Tensor object of dump file.

Args:
op_id (str): Operator ID.
index (int): Index of operator inputs/outputs.
file_path (str): Absolute file path of tensor file.
"""

FEATURE = collections.namedtuple('Feature', ['type', 'id', 'io'])
VALUE = collections.namedtuple('Value', ['index', 'shape', 'dtype', 'path'])

@classmethod
def extract_shape_from_str(cls, shape_str):
"""
Extract shape from tensor file.

Args:
shape_str (str): Shape string.

Returns:
tuple, shape of tensor file.
"""
shape = tuple([int(dim.strip()) for dim in shape_str.strip('_').split('_')])
# The shape info in dump file name is (0,) which is inconsistent with the actual tensor shape.
# The shape needs to be converted to (1,).
if shape == (0,):
shape = (1,)
return shape

@classmethod
def parse_tensor_file_name(cls, file_name):
"""
Parse tensor file name.

Args:
file_name (str): Tensor file name.

Returns:
bool, indicating if node is operator.
dict, tensor file info.

Raises:
UnknownTensorError: If tensor file name can not be recognized.
"""
is_op = False
is_npy = file_name.endswith('.npy')
if re.search(r'-op\d+(_|(\.\d+\.\d+\.))(input|output)(_|\.)', file_name):
is_op = True
dump_type = DumpType.E2E
if re.search(r'-op(?P<op_id>\d+)\.(?P<stream_id>\d+)\.(?P<task_id>\d+)', file_name):
dump_type = DumpType.ASYNC

if dump_type == DumpType.ASYNC:
file_name = file_name[file_name.find('.')+1:]
if is_npy:
regex = r'_(?P<op_name>[A-Za-z0-9]+)-op(?P<op_id>\d+)' \
r'\.(?P<stream_id>\d+)\.(?P<task_id>\d+)' \
r'\.(?P<io>input|output)' \
r'\.(?P<index>\d+)' \
r'\.npy$'
else:
regex = r'_(?P<op_name>[A-Za-z0-9]+)-op(?P<op_id>\d+)' \
r'\.(?P<stream_id>\d+)\.(?P<task_id>\d+)' \
r'\.(?P<io>input|output)' \
r'\.(?P<index>\d+)' \
r'\.(?P<shape>[0-9\_]+)' \
r'\.(?P<dtype>bool|((uint|int|float)\d+))' \
r'\.(?P<format>[A-Za-z0-9\_]+)\.bin$'

else:
regex = r'--(?P<op_name>[A-Za-z0-9\_]+)-op(?P<op_id>\d+)' \
r'_(?P<io>input|output)' \
r'_(?P<index>\d+)' \
r'_shape_(?P<shape>[0-9\_]+)' \
r'_.*(?P<dtype>Bool|((UInt|Int|Float)\d+))' \
r'_(?P<format>[A-Za-z0-9\_]+)\.bin$'

else:
regex = r'^(?P<node_name>[A-Za-z0-9\.\_]+)' \
r'_(?P<io>input|output)' \
r'_(?P<index>\d+)' \
r'_shape_(?P<shape>[0-9\_]+)' \
r'_.*(?P<dtype>Bool|((UInt|Int|Float)\d+))' \
r'_(?P<format>[A-Za-z0-9\_]+)\.bin$'

pattern = re.search(regex, file_name)
if pattern is None:
raise UnknownTensorError(is_op, file_name)

info = pattern.groupdict()
info['index'] = int(info['index'])
info['shape'] = None if is_npy else cls.extract_shape_from_str(info['shape'])
info['dtype'] = None if is_npy else info['dtype'].lower()
return is_op, info

@classmethod
def scan_tensors(cls, tensor_dir):
"""
Scan tensors.

Args:
tensor_dir (str): Directory path where holds the tensor files.
check (lambda): Function to check tensor values.

Returns:
dict, tensor file mapping.
"""
tensor_mapping = {}
if not tensor_dir:
return tensor_mapping

file_names = os.listdir(tensor_dir)
for file_name in file_names:
full_path = os.path.join(tensor_dir, file_name)
if not re.search(r'\.(bin|npy)$', file_name) or os.path.isdir(full_path):
continue

try:
is_op, info = cls.parse_tensor_file_name(file_name)
except UnknownTensorError:
continue

if is_op:
feature = cls.FEATURE(type=info['op_name'], id=info['op_id'], io=info['io'])
else:
feature = cls.FEATURE(type='', id=info['node_name'], io=info['io'])

value = cls.VALUE(index=info['index'], shape=info['shape'], dtype=info['dtype'], path=full_path)
tensors = tensor_mapping.get(feature)
if tensors:
tensor_mapping[feature].append(value)
tensor_mapping[feature].sort(key=lambda x: x[0])
else:
tensor_mapping[feature] = [value]

return tensor_mapping

def __init__(self, op_id, index, file_path):
self.op_id = op_id
self.index = index
self.file_path = file_path

def load(self):
"""
Load tensor file.

Returns:
ndarray, tensor data.
"""
if self.file_path.endswith('.npy'):
tensor = np.load(self.file_path)
return tensor

metas = self.metas
if metas is None:
return None
dtype = getattr(np, metas['dtype'])
tensor = np.fromfile(self.file_path, dtype=dtype)
try:
tensor = tensor.reshape(metas['shape'])
except ValueError:
pass
return tensor

@property
def metas(self):
"""
Metas property.

Returns:
dict, metas extracted from tensor file name.
"""
file_name = os.path.basename(self.file_path)
try:
is_op, info = self.parse_tensor_file_name(file_name)
except UnknownTensorError:
return None

if is_op:
info.pop('op_name')
info.pop('op_id')
else:
info.pop('node_name')

if file_name.endswith('.npy'):
info.pop('dtype')
info.pop('shape')

return info

@property
def full_name(self):
"""
Full name property.

Returns:
str, full name.
"""
full_name_str, _ = os.path.basename(self.file_path).split('_output_')
return full_name_str.replace('--', '/')

@property
def scope(self):
"""
Scope property.

Returns:
str, scope.
"""
return os.path.dirname(self.full_name)

def __repr__(self):
return str({
'op_id': self.op_id,
'index': self.index,
'file_path': self.file_path,
})


class NodeType(enum.Enum):
"""Node Type."""
OPERATOR = 'operator'
PARAMETER = 'parameter'
CONSTANT = 'constant'


class InputType(enum.Enum):
"""Input Type."""
OPERATOR = 'operator'
PARAMETER = 'parameter'
CONSTANT = 'constant'
TENSOR = 'tensor'
SCALAR = 'scalar'
REFERENCE = 'reference'
NONE = 'none'


class OutputType(enum.Enum):
"""Output Type."""
NONE = 'none'
BOOL = 'bool'
INT8 = 'int8'
INT16 = 'int16'
INT32 = 'int32'
INT64 = 'int64'
UINT8 = 'uint8'
UINT16 = 'uint16'
UINT32 = 'uint32'
UINT64 = 'uint64'
FLOAT16 = 'float16'
FLOAT32 = 'float32'
FLOAT64 = 'float64'
TENSOR = 'tensor'
TUPLE = 'tuple'


class Input:
"""
Graph node input.

Args:
input_type (InputType): Input type.
input_name (str): Input name.
"""

def __init__(self, input_type, input_name):
self.type = input_type
self.name = input_name
self.op_id = ''
self.info = None

def __repr__(self):
return str({
'type': self.type,
'name': self.name,
'op_id': self.op_id,
'info': self.info,
})


class Output:
"""
Graph node output.

Args:
output_type (OutputType): Output type.
"""

SCALAR_TYPES = (
OutputType.INT8,
OutputType.INT16,
OutputType.INT32,
OutputType.INT64,
OutputType.UINT8,
OutputType.UINT16,
OutputType.UINT32,
OutputType.UINT64,
OutputType.FLOAT16,
OutputType.FLOAT32,
)

def __init__(self, output_type):
self.type = output_type
if output_type == OutputType.NONE:
self.info = None
elif output_type == OutputType.BOOL:
self.info = dict(value=None)
elif output_type in self.SCALAR_TYPES:
self.info = dict(value=None)
elif output_type == OutputType.TENSOR:
self.info = dict(dtype='', shape=(), tensor=None)
elif output_type == OutputType.TUPLE:
self.info = dict(dtypes=[], shapes=[], tensors=[])

def __repr__(self):
return str({
'type': self.type,
'info': self.info,
})


class Source:
"""
Source address info.

Args:
file_path (str): Absolute path of source file.
line_no (int): Line number of code line in source file.
code_line (int): Code line content.
"""

def __init__(self, file_path, line_no, code_line):
self.file_path = file_path
self.line_no = line_no
self.code_line = code_line

def to_dict(self):
"""Parse to dict."""
return {
'file_path': self.file_path,
'line_no': self.line_no,
'code_line': self.code_line,
}

def __repr__(self):
return str(self.to_dict())

@classmethod
def build_stack_from_source_address(cls, source_address):
"""
Build stack from source address.

Args:
source_address (str): Source address content.

Returns:
list, list of Source objects.
"""
stack = []
for line in source_address.strip().split('\n'):
regex = r'#\sIn\sfile\s(?P<file_path>.+)\((?P<line_no>\d+)\)/(?P<code_line>.+)/'
pattern = re.search(regex, line.strip())
source = pattern.groupdict()
source['line_no'] = int(source['line_no'])
source['code_line'] = source['code_line'].strip()
stack.append(cls(**source))

return stack


class Node:
"""
Graph node.

Args:
name (str): Node name.
"""

def __init__(self, name):
self.name = name
self.output = None
self.downstream = []
self.raw = ''


class Constant(Node):
"""Constant node within graph."""

def __repr__(self):
return str({
'name': self.name,
'output': self.output,
'downstream': self.downstream,
})


class Parameter(Node):
"""Parameter node within graph."""

def __repr__(self):
return str({
'name': self.name,
'output': self.output,
'downstream': self.downstream,
})


class Operator(Node):
"""
Operator node within graph.

Args:
op_name (str): Operator name.
op_type (str): Operator type.
"""

def __init__(self, op_name, op_type):
super().__init__(op_name)
self.type = op_type
self.inputs = []
self.attrs = {}
self.full_name = ''
self.stack = []

@property
def scope(self):
"""
Scope property.

Returns:
str, scope.
"""
return os.path.dirname(self.full_name)

@property
def op_id(self):
"""
Op ID property.

Returns:
str, op ID.
"""
pattern = re.search(r'-op(?P<op_id>\d+)$', self.full_name)
if not pattern:
return self.full_name

info = pattern.groupdict()
return info['op_id']

def __repr__(self):
return str({
'name': self.name,
'type': self.type,
'inputs': self.inputs,
'output': self.output,
'downstream': self.downstream,
'attrs': self.attrs,
'full_name': self.full_name,
'op_id': self.op_id,
})


class Parser:
"""Graph file parser."""

def __init__(self, graph_data=None, tensor_dir=''):
self.graph_data = graph_data
self.tensor_dir = os.path.realpath(tensor_dir) if tensor_dir else ''

self.constants = []
self.parameters = []
self.operators = []
self.tensor_mapping = {}

def parse(self):
"""Parse."""
raise NotImplementedError

+ 45
- 0
mindinsight/domain/graph/exceptions.py View File

@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse exceptions module."""

from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.constant import GraphDomainErrors


class UnknownDataTypeError(MindInsightException):
"""Unknwn data type error."""

def __init__(self, proto_type):
super().__init__(
error=GraphDomainErrors.UNKNOWN_DATA_TYPE_ERROR,
message=str(proto_type))


class TupleGetitemIndexError(MindInsightException):
"""Tuple getitem index error."""

def __init__(self, op_name, index_name):
super().__init__(
error=GraphDomainErrors.TUPLE_GETITEM_INDEX_ERROR,
message=f'op : {op_name}, index: {index_name}')


class UnknownTensorError(MindInsightException):
"""Unknwn tensor error."""

def __init__(self, is_op, file_name):
super().__init__(
error=GraphDomainErrors.UNKNOWN_TENSOR_ERROR,
message=f'is_op : {is_op}, file_name: {file_name}')

+ 325
- 0
mindinsight/domain/graph/proto/ms_graph.proto View File

@@ -0,0 +1,325 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

syntax = "proto2";

package mindinsight.domain.graph;

// Versioning
enum Version {
// unknown version
UNKNOWWN_VERSION = 0;

// Initial version (IR VERSION 1), published on Sep 23, 2019
IR_VERSION = 0x0000000000000001;
}

// Data type definition
enum DataType {
DT_UNDEFINED = 0;
// Basic types.
DT_BOOL = 1; // bool

DT_INT8 = 2; // int8_t
DT_INT16 = 3; // int16_t
DT_INT32 = 4; // int32_t
DT_INT64 = 5; // int64_t

DT_UINT8 = 6; // uint8_t
DT_UINT16 = 7; // uint16_t
DT_UINT32 = 8; // uint32_t
DT_UINT64 = 9; // uint64_t

DT_FLOAT16 = 10; // float 16
DT_FLOAT32 = 11; // float 32
DT_FLOAT64 = 12; // float 64

DT_STRING = 13; // string
DT_TENSOR = 14; // tensor
DT_GRAPH = 15; // graph

// list type
DT_BOOLS = 16; // list of bool

DT_INTS8 = 17; // list of int8_t
DT_INTS16 = 18; // list of int16_t
DT_INTS32 = 19; // list of int32_t
DT_INTS64 = 20; // list of int64_t

DT_UINTS8 = 21; // list of uint8_t
DT_UINTS16 = 22; // list of uint16_t
DT_UINTS32 = 23; // list of uint32_t
DT_UINTS64 = 24; // list of uint64_t

DT_FLOATS16 = 25; // list of float16
DT_FLOATS32 = 26; // list of float32
DT_FLOATS64 = 27; // list of float64

DT_STRINGS = 28; // list of string
DT_TENSORS = 29; // list of tensor
DT_GRAPHS = 30; // list of graph

DT_TUPLE = 31; // tuple
DT_LIST = 32; // list
DT_DICT = 33; // dictionary

// other types
DT_NONE = 34; // None
DT_SYM_INST = 35; // Symbolic Key Instance

// type related type
DT_BASE_INT = 36; // type generic int
DT_BASE_UINT = 37; // type generate unsigned int
DT_BASE_FLOAT = 38; // type generate float
DT_TYPE = 39; // type type
DT_ANYTHING = 40; // type anything
DT_REFKEY = 41; // type refkey
DT_REF = 42; // type ref
}

// Value definition for attribute value or parameter default value
message ValueProto {
// data type of value
optional DataType dtype = 1; // discriminator that indicates which field below is in use

// Exactly ONE of the following fields must be present for this version of the IR
optional bool bool_val = 2; // bool
optional int64 int_val = 3; // int
optional uint64 uint_val = 4; // uint
optional float float_val = 5; // float
optional double double_val = 6; // double
optional string str_val = 7; // string
optional TensorProto tensor_val = 8; // tensor value
optional GraphProto graph = 9; // graph

repeated bool bool_vals = 10; // list of bool
repeated int64 int_vals = 11; // list of int
repeated uint64 uint_vals = 12; // list of uint
repeated float float_vals = 13; // list of float
repeated double double_vals = 14; // list of double
repeated string str_vals = 15; // list of string
repeated TensorProto tensor_vals = 16; // list of tensor value
repeated GraphProto graphs = 17; // list of graph

// tuple or list
repeated ValueProto values = 18; // tuple, list of value

// dictionary
repeated NamedValueProto dict_val = 19; // dictionary info

// filed for type type
optional TypeProto type_val = 20; // type type info
}

message AttributeProto {
optional string name = 1; // attribute name
optional ValueProto value = 2; // attribute value
}

message NamedValueProto {
optional string key = 1; // attribute name
optional ValueProto value = 2; // attribute value
}

// Defines a tensor shape.
message TensorShapeProto {
// One dimension of the tensor.
message Dimension {
// Size of the tensor in that dimension.
// This value must be >= -1, but values of -1 are reserved for "unknown"
// shapes (values of -1 mean "unknown" dimension).
optional int64 size = 1;

// Optional name of the tensor dimension.
optional string name = 2;
};

repeated Dimension dim = 1;
}

// Types for graph input(parameter) and output
message TypeProto {

message Tensor {
// This field MUST have a valid DataType value except DT_TENSOR
optional DataType elem_type = 1;
optional TensorShapeProto shape = 2; // for scalar, this field is not set
}

// tuple type
message Sequence {
// The type and optional shape of elements of the tuple.
repeated TypeProto elem_types = 1;
};

// data type
optional DataType data_type = 1;

oneof value {
// The type of a tensor.
Tensor tensor_type = 2;

// The type of a tuple.
Sequence sequence_type = 3;
}
}

// Defines information on graph parameters, including the name, the type, and
// the default value of parameter if exists.
message ParameterProto {
optional string name = 1; // parameter name
optional TypeProto type = 2; // parameter type
optional ValueProto default_val = 3; // default value of parameter if exists
}

// Defines graph output information
message OutputProto {
optional string name = 1; // output node name
optional TypeProto type = 2; // output node type
}

// Define node input information
message InputProto {
enum EdgeType {
DATA_EDGE = 0; // data edge
CONTROL_EDGE = 1; // control edge
}

optional string name = 1;
optional EdgeType type = 2;
}

// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated InputProto input = 1; // namespace Value
optional string name = 2; // namespace Value

// The symbolic identifier of the Operator to execute.
optional string op_type = 3; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
optional string scope = 4; // namespace Domain

// Additional named attributes.
repeated AttributeProto attribute = 5;

// Optional type info of this node
optional TypeProto output_type = 6;

// other fields for debug
optional uint64 output_i = 7;

// full name with scope
optional string full_name = 8;

// The corresponding source code for this node.
optional string source_address = 9;
}

// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto.
message ModelProto {
// ir version
optional int64 ir_version = 1;

// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
optional string domain = 2;

// The version of the graph encoded. See Version enum below.
optional int64 model_version = 3;

// The parameterized graph that is evaluated to execute the model.
optional GraphProto graph = 4;

// metadata info of operators
optional OperatorSetProto metadata_operators = 5;
};

message OperatorProto {
optional string name = 1; // used as key, must be distinct
optional bytes config = 2; // operator config info
optional bytes obj_info = 3; // operator related object info, e.g. content of operator binary or name
};

message OperatorSetProto {
repeated OperatorProto operators = 1;
optional string summary = 2; // summary info of operators, e.g. file position of operators file
}

// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;

// The name of the graph.
optional string name = 2; // namespace Graph

// The parameters(inputs) and outputs of the graph.
repeated ParameterProto parameters = 3;
repeated OutputProto outputs = 4;

// Constants used in this graph
repeated NamedValueProto const_vals = 5;
}

// Tensors
//
// A serialized tensor value.
message TensorProto {
// The node name of the tensor.
optional string node_name = 1;

// The slot of the tensor in its node.
optional string slot = 2;

// The serialized tensor content.
optional bytes tensor_content = 3;

// The shape of the tensor.
repeated int64 dims = 4;

// The data type of the tensor.
// This field MUST have a valid DataType value except DT_TENSOR
optional DataType data_type = 5;

// If the tensor content transferring is finished.
optional bool finished = 6;

// The iteration of the tensor. Supported: "prev" or leave empty.
optional string iter = 7;

// If the tensor name should be truncated.
optional bool truncate = 8;
}

+ 1400
- 0
mindinsight/domain/graph/proto/ms_graph_pb2.py
File diff suppressed because it is too large
View File


+ 7
- 0
mindinsight/utils/constant.py View File

@@ -102,3 +102,10 @@ class OptimizerErrors(Enum):
OPTIMIZER_TERMINATE = 4 OPTIMIZER_TERMINATE = 4
CONFIG_PARAM_ERROR = 5 CONFIG_PARAM_ERROR = 5
HYPER_CONFIG_ENV_ERROR = 6 HYPER_CONFIG_ENV_ERROR = 6


class GraphDomainErrors(Enum):
"""Enum definition for graph domain errors."""
UNKNOWN_DATA_TYPE_ERROR = 1
TUPLE_GETITEM_INDEX_ERROR = 2
UNKNOWN_TENSOR_ERROR = 3

Loading…
Cancel
Save