Browse Source

add graph parser

pull/1325/head
liangyongxiong 4 years ago
parent
commit
8fd5ebb453
3 changed files with 982 additions and 0 deletions
  1. +406
    -0
      mindinsight/domain/graph/pb_parser.py
  2. +93
    -0
      mindinsight/domain/graph/query.py
  3. +483
    -0
      mindinsight/domain/graph/utils.py

+ 406
- 0
mindinsight/domain/graph/pb_parser.py View File

@@ -0,0 +1,406 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PB parser module."""

from mindinsight.domain.graph.proto import ms_graph_pb2 as graph_proto
from mindinsight.domain.graph.base import MindSporeType, InputType, OutputType
from mindinsight.domain.graph.base import Input, Output, Tensor, Source, Constant, Parameter, Operator, Parser
from mindinsight.domain.graph.exceptions import UnknownDataTypeError, TupleGetitemIndexError


class PBParser(Parser):
"""Protobuf file parser."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.proto = graph_proto
self.dtype_mapping = {
self.proto.DT_BOOL: 'bool',
self.proto.DT_INT8: 'int8',
self.proto.DT_INT16: 'int16',
self.proto.DT_INT32: 'int32',
self.proto.DT_INT64: 'int64',
self.proto.DT_UINT8: 'uint8',
self.proto.DT_UINT16: 'uint16',
self.proto.DT_UINT32: 'uint32',
self.proto.DT_UINT64: 'uint64',
self.proto.DT_FLOAT16: 'float16',
self.proto.DT_FLOAT32: 'float32',
self.proto.DT_FLOAT64: 'float64',
self.proto.DT_TENSOR: 'tensor',
self.proto.DT_TUPLE: 'tuple',
self.proto.DT_BASE_INT: MindSporeType.INT,
self.proto.DT_BASE_UINT: MindSporeType.UINT,
self.proto.DT_BASE_FLOAT: MindSporeType.FLOAT,
}
self.int_types = (
self.proto.DT_INT8,
self.proto.DT_INT16,
self.proto.DT_INT32,
self.proto.DT_INT64,
)
self.uint_types = (
self.proto.DT_UINT8,
self.proto.DT_UINT16,
self.proto.DT_UINT32,
self.proto.DT_UINT64,
)
self.float_types = (
self.proto.DT_FLOAT16,
self.proto.DT_FLOAT32,
self.proto.DT_FLOAT64,
)

def _parse_constants(self, pb_constant):
"""
Parse constants.

Args:
pb_constant (Protobuf): Constant node.
"""
constant = Constant(pb_constant.key)
constant.raw = str(pb_constant)

if pb_constant.value.dtype in self.int_types:
constant.output = Output(OutputType(self.dtype_mapping[pb_constant.value.dtype]))
constant.output.info['value'] = pb_constant.value.int_val
elif pb_constant.value.dtype in self.float_types:
constant.output = Output(OutputType(self.dtype_mapping[pb_constant.value.dtype]))
constant.output.info['value'] = pb_constant.value.float_val
elif pb_constant.value.dtype == self.proto.DT_TENSOR:
constant.output = Output(OutputType.TENSOR)
if pb_constant.value.tensor_val.data_type:
constant.output.info['dtype'] = self.dtype_mapping[pb_constant.value.tensor_val.data_type]
constant.output.info['shape'] = tuple(pb_constant.value.tensor_val.dims)

feature = Tensor.FEATURE(type='', id=constant.name, io='output')
values = self.tensor_mapping.get(feature, [])
if len(values) == 1:
value = values[0]
constant.output.info['tensor'] = Tensor(constant.name, 0, value.path)
else:
constant.output = Output(OutputType.NONE)

self.constants.append(constant)

def _parse_parameters(self, pb_parameter):
"""
Parse parameters.

Args:
pb_parameter (Protobuf): Parameter node.
"""
parameter = Parameter(pb_parameter.name)
parameter.raw = str(pb_parameter)

output = Output(OutputType.TENSOR)
if pb_parameter.type.tensor_type.elem_type == self.proto.DT_UNDEFINED:
output.type = OutputType.NONE
else:
output.info['dtype'] = self.dtype_mapping[pb_parameter.type.tensor_type.elem_type]
output.info['shape'] = tuple([dim.size for dim in pb_parameter.type.tensor_type.shape.dim])

feature = Tensor.FEATURE(type='', id=parameter.name, io='output')
values = self.tensor_mapping.get(feature, [])
if len(values) == 1:
value = values[0]
output.info['tensor'] = Tensor(parameter.name, 0, value.path)

parameter.output = output
self.parameters.append(parameter)

def _get_proto_value(self, value):
"""
Get proto value.

Args:
value (Protobuf): Protobuf value.

Returns:
any, proto value.

Raises:
UnknownDataTypeError: If data type of protobuf value can not be recognized.
"""
if value.dtype == self.proto.DT_UNDEFINED:
return None
if value.dtype == self.proto.DT_TYPE:
return self.dtype_mapping[value.type_val.data_type]
if value.dtype == self.proto.DT_BOOL:
return value.bool_val
if value.dtype == self.proto.DT_STRING:
return value.str_val
if value.dtype in self.uint_types:
return value.uint_val
if value.dtype in self.int_types:
return value.int_val
if value.dtype in self.float_types:
return value.float_val
if value.dtype in (self.proto.DT_LIST, self.proto.DT_TUPLE):
value_items = []
for value_item in value.values:
value_items.append(self._get_proto_value(value_item))
if value.dtype == self.proto.DT_TUPLE:
value_items = tuple(value_items)
return value_items
raise UnknownDataTypeError(value.dtype)

def _find_operator_output_tensors(self, operator):
"""
Find operator output tensors.

Args:
operator (Operator): Operator object.
"""
output = operator.output
if operator.type in ('tuple_getitem', 'make_tuple') or output.type in (OutputType.BOOL, OutputType.NONE):
return

is_load_op = False
if operator.full_name.find('-op') == -1:
is_load_op = True
feature = Tensor.FEATURE(type='', id=operator.full_name, io='output')
else:
_, op_id = operator.full_name.split('-op')
feature = Tensor.FEATURE(type=operator.type, id=op_id, io='output')

values = self.tensor_mapping.get(feature)
if not values:
return

if output.type == OutputType.TENSOR and len(values) == 1:
value = values[0]
if is_load_op:
output.info['tensor'] = Tensor(operator.full_name, 0, value.path)
else:
output.info['tensor'] = Tensor(op_id, value.index, value.path)
elif output.type == OutputType.TUPLE and len(values) == len(output.info['dtypes']):
for value in values:
output.info['tensors'][value.index] = Tensor(op_id, value.index, value.path)

def _process_operator_input(self, operator, input_index, input_types):
"""
Process operator input.

Args:
operator (Operator): Operator.
input_index (int): Input index.
input_types (dict): Input types.
"""
op_input = operator.inputs[input_index]
node = input_types[op_input.type].get(op_input.name)
if not node:
return
node.downstream.append(operator.op_id)
if op_input.type == InputType.OPERATOR:
op_input.op_id = node.op_id

op_input.info = {}

if op_input.type in (InputType.PARAMETER, InputType.CONSTANT):
op_input.info = node.output.info.copy() if node.output.info else None
return

if operator.type in ('tuple_getitem', 'TupleGetItem'):
op_input.info['dtype'] = OutputType.TUPLE
return

if node.output.type == OutputType.TENSOR:
op_input.info = node.output.info.copy()
else:
op_input.info['dtype'] = node.output.type

if node.full_name.find('-op') == -1:
op_id = node.full_name
feature = Tensor.FEATURE(type='', id=node.full_name, io='output')
else:
_, op_id = node.full_name.split('-op')
feature = Tensor.FEATURE(type=node.type, id=op_id, io='input')

values = self.tensor_mapping.get(feature)
if values and len(values) == len(operator.inputs) and op_input.info['tensor'] is None:
value = values[input_index]
op_input.info['tensor'] = Tensor(op_id, value.index, value.path)

def _parse_operators(self, pb_operator):
"""
Parse operator.

Args:
pb_operator (Protobuf): Operator node.

Raises:
UnknownDataTypeError: If data type of protobuf value can not be recognized.
"""
operator = Operator(pb_operator.name, pb_operator.op_type)
operator.full_name = pb_operator.full_name
operator.raw = str(pb_operator)
self.operators.append(operator)

# parse source code
if getattr(pb_operator, 'source_address', None):
operator.stack = Source.build_stack_from_source_address(pb_operator.source_address)

# parse attrs
for attr in pb_operator.attribute:
operator.attrs[attr.name] = self._get_proto_value(attr.value)
if attr.name in ('input_names', 'output_names'):
operator.attrs[attr.name] = list(operator.attrs[attr.name])

# parse inputs
cst_mapping = dict((cst.name, cst) for cst in self.constants)
param_mapping = dict((param.name, param) for param in self.parameters)
op_mapping = dict((op.name, op) for op in self.operators)
for pb_input in pb_operator.input:
input_type = InputType.REFERENCE
if pb_input.name in cst_mapping:
input_type = InputType.CONSTANT
elif pb_input.name in param_mapping:
input_type = InputType.PARAMETER
elif pb_input.name in op_mapping:
input_type = InputType.OPERATOR

op_input = Input(input_type, pb_input.name)
if input_type == InputType.OPERATOR:
op_input.op_id = op_mapping[op_input.name].op_id

operator.inputs.append(op_input)

# parse output
proto_output = pb_operator.output_type
if proto_output.data_type in (self.proto.DT_UNDEFINED, self.proto.DT_NONE):
output = Output(OutputType.NONE)
elif proto_output.data_type == self.proto.DT_BOOL:
output = Output(OutputType.BOOL)
elif proto_output.data_type == self.proto.DT_TENSOR:
output = Output(OutputType.TENSOR)
output.info['dtype'] = self.dtype_mapping[proto_output.tensor_type.elem_type]
if proto_output.tensor_type.shape:
output.info['shape'] = tuple([dim.size for dim in proto_output.tensor_type.shape.dim])
elif proto_output.data_type == self.proto.DT_TUPLE:
output = Output(OutputType.TUPLE)
for elem_type in proto_output.sequence_type.elem_types:
dtype = self._get_tuple_item_dtype(elem_type)
output.info['dtypes'].append(dtype)
output.info['shapes'].append(None)
output.info['tensors'].append(None)
else:
raise UnknownDataTypeError(proto_output.data_type)

operator.output = output
self._find_operator_output_tensors(operator)

def _get_tuple_item_dtype(self, elem_type):
"""
Get tuple item dtype.

Args:
elem_type (TypeProto) : TypeProto of tuple item operator.

Returns:
any, tuple item dtype.
"""
if elem_type.tensor_type.elem_type != self.proto.DT_UNDEFINED:
return self.dtype_mapping[elem_type.tensor_type.elem_type]

return self.dtype_mapping[elem_type.data_type]

def _get_tuple_getitem_index(self, operator, constant_mapping):
"""
Get tuple_getitem index.

Args:
operator (Operator) : Operator.
constant_mapping (dict) : Constant mapping.

Returns:
int, tuple_getitem index.

Raises:
TupleGetitemIndexError: If tuple_getitem index error occurs.
"""
if operator.inputs[1].type == InputType.CONSTANT:
constant = constant_mapping[operator.inputs[1].name]
if constant.output.type in Output.SCALAR_TYPES:
return int(constant.output.info['value'])
raise TupleGetitemIndexError(operator.name, constant.name)

if operator.inputs[1].type == InputType.SCALAR:
return int(operator.inputs[1].name)
raise TupleGetitemIndexError(operator.name, f'{operator.inputs[1].name}')

def _post_process(self):
"""Post-process."""
constant_mapping = self.get_constants()
parameter_mapping = self.get_parameters()
operator_mapping = dict((operator.name, operator) for operator in self.operators)
input_types = {
InputType.CONSTANT: constant_mapping,
InputType.PARAMETER: parameter_mapping,
InputType.OPERATOR: operator_mapping,
}

transition_operator_types = ('Squeeze', 'Reshape', 'ExpandDims', 'Flatten')

for operator in self.operators:
if operator.type in ('make_tuple', 'MakeTuple'):
continue

if operator.type in ('tuple_getitem', 'TupleGetItem'):
index = self._get_tuple_getitem_index(operator, constant_mapping)
output = operator_mapping[operator.inputs[0].name].output
if output.type == OutputType.TUPLE and len(output.info['tensors']) > index:
operator.output.info['tensor'] = output.info['tensors'][index]

elif operator.type in transition_operator_types and operator.output.info['tensor'] is None:
op_input = operator.inputs[0]
if op_input.type in input_types:
node = input_types[op_input.type][op_input.name]
operator.output.info['tensor'] = node.output.info['tensor']

elif operator.type == 'Depend':
if operator.output.type in (OutputType.NONE, OutputType.BOOL) \
or (operator.output.type == OutputType.TENSOR and operator.output.info['tensor'] is None):
op_input = operator.inputs[0]
if op_input.type == InputType.OPERATOR:
node = operator_mapping[op_input.name]
operator.output = node.output

if operator.type == 'Assign' and len(operator.inputs) == 3:
operator.inputs = operator.inputs[1:]

for input_index, op_input in enumerate(operator.inputs):
if op_input.type in input_types:
self._process_operator_input(operator, input_index, input_types)

def parse(self):
"""Parse."""
self.tensor_mapping = Tensor.scan_tensors(self.tensor_dir)

# parse constants
for pb_constant in self.graph_data.const_vals:
self._parse_constants(pb_constant)

# parse parameters:
for pb_parameter in self.graph_data.parameters:
self._parse_parameters(pb_parameter)

# parse operators
for pb_operator in self.graph_data.node:
self._parse_operators(pb_operator)

# post-process
self._post_process()

+ 93
- 0
mindinsight/domain/graph/query.py View File

@@ -0,0 +1,93 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Query module."""

import re
from functools import reduce


class BaseStackOperator:
"""Base stack operator."""

def __init__(self, stack=None):
self.stack = stack or []


class StackQuery:
"""Stack query."""

def __init__(self, operators):
self.operators = operators

def all(self):
"""
Retrieve all operators.

Returns:
list, all operators.
"""
return self.operators

def get(self, index=0):
"""
Retrieve one operator.

Args:
index (int): Operator index, default is 0.

Returns:
Operator, single operator.
"""
if 0 <= index < len(self.operators):
return self.operators[index]
return None

def filter(self, qs, use_regex=False):
"""
Filter operators with query.

Args:
qs (str): Query string.
use_regex (bool): Indicates if qs is regex.

Returns:
StackQuery, cloned object.
"""
if use_regex:
func = lambda x: bool(re.search(qs, x))
else:
func = lambda x: x.find(qs) > -1

operators = []
for operator in self.operators:
if not operator.stack:
continue
stack_contents = [
f'{source.file_path}:{source.line_no}\n{source.code_line}'
for source in operator.stack
]
if reduce(lambda x, y: x or y, map(func, stack_contents)):
operators.append(operator)

return self.clone(operators)

def clone(self, operators):
"""
Clone query object.

Returns:
StackQuery, cloned object.
"""
return StackQuery(operators)

+ 483
- 0
mindinsight/domain/graph/utils.py View File

@@ -0,0 +1,483 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse utils module."""

import os

import xlsxwriter

from mindinsight.domain.graph.base import InputType, OutputType


class Toolkit:
"""Toolkit."""

PLACEHOLDER = '-'

def __init__(self, dump_dir, constants, parameters, operators):
self.dump_dir = dump_dir
self.constants = constants
self.parameters = parameters
self.operators = operators

def export_xlsx(self, file_path):
"""
Export graph data to Excel file.

Args:
file_path (str) : Excel file path.
"""
target_dir = os.path.dirname(file_path)
if not os.path.isdir(target_dir):
print(f'Directory {target_dir} not exists')
return

workbook = xlsxwriter.Workbook(file_path)

# text_v_align: 1-top, 2-middle, 3-bottom
# text_h_align: 1-left, 2-center, 3-right
styles = dict(
header_left_fmt=workbook.add_format(dict(
text_v_align=2, text_h_align=1,
font_color='#000000', bg_color='#d9d9d9',
bold=True,
)),
header_center_fmt=workbook.add_format(dict(
text_v_align=2, text_h_align=2,
font_color='#000000', bg_color='#d9d9d9',
bold=True,
)),
content_left_fmt=workbook.add_format(dict(
text_v_align=2, text_h_align=1,
text_wrap=False,
)),
content_center_fmt=workbook.add_format(dict(
text_v_align=2, text_h_align=2,
text_wrap=False,
)),
content_wrapped_fmt=workbook.add_format(dict(
text_v_align=2, text_h_align=1,
text_wrap=True,
)),
)

self._add_info_worksheet(workbook, styles)
self._add_constant_worksheet(workbook, styles)
self._add_parameter_worksheet(workbook, styles)
self._add_operator_worksheet(workbook, styles)
self._add_statistics_worksheet(workbook, styles)
self._add_source_worksheet(workbook, styles)

for worksheet in workbook.sheetnames.values():
worksheet.freeze_panes(1, 0)
worksheet.freeze_panes(1, 1)

workbook.close()

def _convert_column_indices(self, metas):
"""
Convert column metas into indices mapping.

Args:
metas (list): Column metas.

Returns:
dict, holds the indicess of columns.
"""
mapping = {}
for index, (name, _, _) in enumerate(metas):
mapping[name] = index
return mapping

def _add_info_worksheet(self, workbook, styles):
"""
Add info worksheet.

Args:
workbook (WorkBook): Excel workbook.
styles (dict): Workbook styles.
"""
worksheet = workbook.add_worksheet('info')

# column metas contain column names, styles and widths
column_metas = [
('argument', styles['header_center_fmt'], 20),
('value', styles['header_left_fmt'], 150),
]
for index, (column, fmt, width) in enumerate(column_metas):
worksheet.set_column(index, index, width)
worksheet.write(0, index, column, fmt)
worksheet.autofilter(0, 0, 0, len(column_metas) - 1)

indices = self._convert_column_indices(column_metas)
worksheet.write(1, indices.get('argument'), 'dump-dir', styles['content_center_fmt'])
worksheet.write(1, indices.get('value'), self.dump_dir or '', styles['content_left_fmt'])

def _get_operator_input_info(self, operator, input_types):
"""
Add operator worksheet.

Args:
operator (Operator): Operator.
input_types (dict): Input types.

Returns:
dict, input info content.
"""
input_content = ''
input_dtype_content = ''
input_shape_content = ''

for op_input in operator.inputs:
if op_input.type == InputType.OPERATOR:
op = input_types[InputType.OPERATOR][op_input.op_id]
if op.type == 'Load':
input_content += f'{op.type}_{op.name}' + '\n'
else:
input_content += f'{op.type}_{op.op_id}' + '\n'
if op_input.info:
input_dtype_content += str(op_input.info['dtype']) + '\n'
input_shape_content += str(op_input.info.get('shape') or Toolkit.PLACEHOLDER) + '\n'
else:
input_dtype_content += Toolkit.PLACEHOLDER + '\n'
input_shape_content += Toolkit.PLACEHOLDER + '\n'
elif op_input.type == InputType.PARAMETER:
input_content += op_input.name + '\n'
param = input_types[InputType.PARAMETER][op_input.name]
if param.output:
input_dtype_content += param.output.info['dtype'] + '\n'
input_shape_content += str(param.output.info.get('shape') or Toolkit.PLACEHOLDER) + '\n'
else:
input_dtype_content += Toolkit.PLACEHOLDER + '\n'
input_shape_content += Toolkit.PLACEHOLDER + '\n'
elif op_input.type == InputType.CONSTANT:
input_content += op_input.name + '\n'
cst = input_types[InputType.CONSTANT][op_input.name]
if cst.output.type == OutputType.TENSOR:
input_dtype_content += cst.output.info.get('dtype') or Toolkit.PLACEHOLDER + '\n'
input_shape_content += str(cst.output.info.get('shape') or Toolkit.PLACEHOLDER) + '\n'
else:
input_dtype_content += Toolkit.PLACEHOLDER + '\n'
input_shape_content += Toolkit.PLACEHOLDER + '\n'
else:
input_content += op_input.name + '\n'
input_dtype_content += Toolkit.PLACEHOLDER + '\n'
input_shape_content += Toolkit.PLACEHOLDER + '\n'

return {
'input': input_content.strip(),
'input_dtype': input_dtype_content.strip(),
'input_shape': input_shape_content.strip(),
}

def _add_operator_worksheet(self, workbook, styles):
"""
Add operator worksheet.

Args:
workbook (WorkBook): Excel workbook.
styles (dict): Workbook styles.
"""
constant_mapping = dict((constant.name, constant) for constant in self.constants)
parameter_mapping = dict((parameter.name, parameter) for parameter in self.parameters)
operator_mapping = dict((operator.op_id, operator) for operator in self.operators)
input_types = {
InputType.CONSTANT: constant_mapping,
InputType.PARAMETER: parameter_mapping,
InputType.OPERATOR: operator_mapping,
}

worksheet = workbook.add_worksheet('operator')

# column metas contain column names, styles and widths
column_metas = [
('operator', styles['header_left_fmt'], 30),
('input', styles['header_left_fmt'], 30),
('input_dtype', styles['header_left_fmt'], 20),
('input_shape', styles['header_left_fmt'], 25),
('output_dtype', styles['header_left_fmt'], 20),
('output_shape', styles['header_left_fmt'], 25),
('downstream', styles['header_left_fmt'], 30),
('name', styles['header_center_fmt'], 10),
('attrs', styles['header_left_fmt'], 30),
('full_name', styles['header_left_fmt'], 20),
('device_id', styles['header_left_fmt'], 20),
('graph_name', styles['header_left_fmt'], 30),
('stack', styles['header_left_fmt'], 150),
]
for index, (column, fmt, width) in enumerate(column_metas):
worksheet.set_column(index, index, width)
worksheet.write(0, index, column, fmt)
worksheet.autofilter(0, 0, 0, len(column_metas) - 1)

indices = self._convert_column_indices(column_metas)
for index, operator in enumerate(self.operators):
if operator.type == 'Load':
operator_content = f'{operator.type}_{operator.name}'
else:
operator_content = f'{operator.type}_{operator.op_id}'

worksheet.write(index + 1, indices.get('operator'), operator_content, styles['content_left_fmt'])

if operator.type == 'make_tuple':
worksheet.write(index + 1, indices.get('device_id'), operator.device_id, styles['content_left_fmt'])
worksheet.write(index + 1, indices.get('graph_name'), operator.graph_name, styles['content_left_fmt'])
continue

input_info = self._get_operator_input_info(operator, input_types)
worksheet.write(index + 1, indices.get('input'), input_info['input'], styles['content_wrapped_fmt'])
worksheet.write(
index + 1, indices.get('input_dtype'),
input_info['input_dtype'], styles['content_wrapped_fmt'])
worksheet.write(
index + 1, indices.get('input_shape'),
input_info['input_shape'], styles['content_wrapped_fmt'])

output_dtype_content = ''
output_shape_content = ''
if operator.output and operator.output.type == OutputType.TENSOR:
output_dtype_content = operator.output.info['dtype']
output_shape_content = str(operator.output.info['shape'])
elif operator.output and operator.output.type == OutputType.TUPLE:
output_dtype_content = '\n'.join([
Toolkit.PLACEHOLDER if dtype is None else dtype
for dtype in operator.output.info['dtypes']
])
output_shape_content = '\n'.join([
Toolkit.PLACEHOLDER if shape is None else str(shape)
for shape in operator.output.info['shapes']
])
worksheet.write(
index + 1, indices.get('output_dtype'),
output_dtype_content, styles['content_wrapped_fmt'])
worksheet.write(
index + 1, indices.get('output_shape'),
output_shape_content, styles['content_wrapped_fmt'])

downstream_content = ''
for op_id in operator.downstream:
op = operator_mapping[op_id]
downstream_content += f'{op.type}_{op.op_id}' + '\n'
worksheet.write(
index + 1, indices.get('downstream'),
downstream_content.strip(), styles['content_wrapped_fmt'])

worksheet.write(index + 1, indices.get('name'), operator.name, styles['content_center_fmt'])
worksheet.write(index + 1, indices.get('attrs'), str(operator.attrs), styles['content_left_fmt'])
worksheet.write(index + 1, indices.get('full_name'), operator.full_name, styles['content_left_fmt'])
worksheet.write(index + 1, indices.get('device_id'), operator.device_id, styles['content_left_fmt'])
worksheet.write(index + 1, indices.get('graph_name'), operator.graph_name, styles['content_left_fmt'])

stack_content = ''
for source in operator.stack:
stack_content += f'{source.file_path}:{source.line_no}\n{source.code_line}\n'
worksheet.write(index + 1, indices.get('stack'), stack_content.strip(), styles['content_wrapped_fmt'])

def _add_parameter_worksheet(self, workbook, styles):
"""
Add parameter worksheet.

Args:
workbook (WorkBook): Excel workbook.
styles (dict): Workbook styles.
"""
worksheet = workbook.add_worksheet('parameter')

# column metas contain column names, styles and widths
column_metas = [
('name', styles['header_left_fmt'], 50),
('output_dtype', styles['header_left_fmt'], 20),
('output_shape', styles['header_left_fmt'], 25),
('downstream', styles['header_left_fmt'], 30),
('device_id', styles['header_left_fmt'], 20),
('graph_name', styles['header_left_fmt'], 30),
]
for index, (column, fmt, width) in enumerate(column_metas):
worksheet.set_column(index, index, width)
worksheet.write(0, index, column, fmt)
worksheet.autofilter(0, 0, 0, len(column_metas) - 1)

indices = self._convert_column_indices(column_metas)
operator_mapping = dict((operator.op_id, operator) for operator in self.operators)
for index, parameter in enumerate(self.parameters):
worksheet.write(index + 1, indices.get('name'), parameter.name, styles['content_left_fmt'])
worksheet.write(
index + 1, indices.get('output_dtype'),
parameter.output.info['dtype'], styles['content_left_fmt'])
worksheet.write(
index + 1, indices.get('output_shape'),
str(parameter.output.info['shape']), styles['content_left_fmt'])

downstream_nodes = [operator_mapping[op_id] for op_id in parameter.downstream]
downstream_content = ''
for op in downstream_nodes:
if op.type == 'Load':
downstream_content += f'{op.type}_{op.name}' + '\n'
else:
downstream_content += f'{op.type}_{op.op_id}' + '\n'
worksheet.write(
index + 1, indices.get('downstream'),
downstream_content.strip(), styles['content_wrapped_fmt'])

worksheet.write(index + 1, indices.get('device_id'), parameter.device_id, styles['content_left_fmt'])
worksheet.write(index + 1, indices.get('graph_name'), parameter.graph_name, styles['content_left_fmt'])

def _add_constant_worksheet(self, workbook, styles):
"""
Add constant worksheet.

Args:
workbook (WorkBook): Excel workbook.
styles (dict): Workbook styles.
"""
worksheet = workbook.add_worksheet('constant')

# column metas contain column names, styles and widths
column_metas = [
('name', styles['header_left_fmt'], 10),
('value', styles['header_left_fmt'], 30),
('downstream', styles['header_left_fmt'], 30),
('device_id', styles['header_left_fmt'], 20),
('graph_name', styles['header_left_fmt'], 30),
]
for index, (column, fmt, width) in enumerate(column_metas):
worksheet.set_column(index, index, width)
worksheet.write(0, index, column, fmt)
worksheet.autofilter(0, 0, 0, len(column_metas) - 1)

indices = self._convert_column_indices(column_metas)
operator_mapping = dict((operator.op_id, operator) for operator in self.operators)
for index, constant in enumerate(self.constants):
worksheet.write(index + 1, indices.get('name'), constant.name, styles['content_left_fmt'])

if constant.output.type == OutputType.NONE:
value_content = 'NONE'
elif constant.output.type == OutputType.TENSOR:
value_content = 'TENSOR'
else:
value_content = constant.output.info['value']
worksheet.write(index + 1, indices.get('value'), value_content, styles['content_left_fmt'])

downstream_nodes = [operator_mapping[op_id] for op_id in constant.downstream]
downstream_content = ''
for op in downstream_nodes:
if op.type == 'Load':
downstream_content += f'{op.type}_{op.name}' + '\n'
else:
downstream_content += f'{op.type}_{op.op_id}' + '\n'
worksheet.write(
index + 1, indices.get('downstream'),
downstream_content.strip(), styles['content_wrapped_fmt'])

worksheet.write(index + 1, indices.get('device_id'), constant.device_id, styles['content_left_fmt'])
worksheet.write(index + 1, indices.get('graph_name'), constant.graph_name, styles['content_left_fmt'])

def _add_statistics_worksheet(self, workbook, styles):
"""
Add statistics worksheet.

Args:
workbook (WorkBook): Excel workbook.
styles (dict): Workbook styles.
"""
worksheet = workbook.add_worksheet('statistics')

# column metas contain column names, styles and widths
column_metas = [
('operator', styles['header_left_fmt'], 30),
('count', styles['header_center_fmt'], 20),
]
for index, (column, fmt, width) in enumerate(column_metas):
worksheet.set_column(index, index, width)
worksheet.write(0, index, column, fmt)
worksheet.autofilter(0, 0, 0, len(column_metas) - 1)

operator_type_set = set()
for operator in self.operators:
operator_type_set.add(operator.type)

operator_types = sorted(list(operator_type_set))
stats = dict(zip(operator_types, [0]*len(operator_types)))
for operator in self.operators:
stats[operator.type] += 1

indices = self._convert_column_indices(column_metas)
for index, operator_type in enumerate(operator_types):
worksheet.write(index + 1, indices.get('operator'), operator_type, styles['content_left_fmt'])
worksheet.write(index + 1, indices.get('count'), stats[operator_type], styles['content_center_fmt'])

def _add_source_worksheet(self, workbook, styles):
"""
Add source worksheet.

Args:
workbook (WorkBook): Excel workbook.
styles (dict): Workbook styles.
"""
worksheet = workbook.add_worksheet('source')

# column metas contain column names, styles and widths
column_metas = [
('stack', styles['header_left_fmt'], 150),
('operator', styles['header_left_fmt'], 30),
('full_name', styles['header_left_fmt'], 20),
('device_id', styles['header_left_fmt'], 20),
('graph_name', styles['header_left_fmt'], 30),
]
for index, (column, fmt, width) in enumerate(column_metas):
worksheet.set_column(index, index, width)
worksheet.write(0, index, column, fmt)
worksheet.autofilter(0, 0, 0, len(column_metas) - 1)

source_mapping = {}
for operator in self.operators:
if not operator.stack:
continue
stack = [f'{source.file_path}:{source.line_no}\n{source.code_line}' for source in operator.stack]
key = '\n'.join(stack)
if key in source_mapping:
source_mapping[key].append(operator)
else:
source_mapping[key] = [operator]

row = 0
indices = self._convert_column_indices(column_metas)
for key in source_mapping:
operators = source_mapping[key]
operators.sort(key=lambda x: int(x.op_id))

if len(operators) == 1:
worksheet.write(row + 1, indices.get('stack'), key, styles['content_wrapped_fmt'])
else:
worksheet.merge_range(
row + 1, indices.get('stack'),
row+len(operators), 0, key, styles['content_wrapped_fmt'])

for index, operator in enumerate(operators):
operator_content = f'{operator.type}_{operator.op_id}'
worksheet.write(
row + index + 1, indices.get('operator'),
operator_content, styles['content_left_fmt'])
worksheet.write(
row + index + 1, indices.get('full_name'),
operator.full_name, styles['content_left_fmt'])
worksheet.write(
row + index + 1, indices.get('device_id'),
operator.device_id, styles['content_left_fmt'])
worksheet.write(
row + index + 1, indices.get('graph_name'),
operator.graph_name, styles['content_left_fmt'])

row += len(operators)

Loading…
Cancel
Save