!1301 add async dump tensor convertor

From: @yelihua
Reviewed-by: @wenkai_dist, @lixiaohui33
Signed-off-by: @lixiaohui33
Branch: pull/1301/MERGE
Committed by mindspore-ci-bot, 4 years ago
Commit: e2eafa11b2
3 changed files with 522 additions and 2 deletions
  1. mindinsight/debugger/dump/__init__.py (+19, -0)
  2. mindinsight/debugger/dump/convert.py (+501, -0)
  3. mindinsight/debugger/stream_cache/data_loader.py (+2, -2)

mindinsight/debugger/dump/__init__.py (+19, -0)

@@ -0,0 +1,19 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Dump Module Introduction.

This module provides Python APIs to parse dump data directory.
"""

mindinsight/debugger/dump/convert.py (+501, -0)

@@ -0,0 +1,501 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse tensor files from async dump structure."""

import os
import stat
import sys
from collections import namedtuple
from importlib import import_module
from pathlib import Path

import numpy as np

PARSE_ARGS_FIELDS = ['dump_path', 'format', 'output_path', 'output_file_type',
                     'input', 'output', 'shape',
                     'custom_script_path', 'dump_version']


class ArgsParser(namedtuple("ArgsParser", PARSE_ARGS_FIELDS)):
    """Args Parser object."""

    __slots__ = ()

    def __new__(cls, **kwargs):
        new_kwargs = {field: kwargs.get(field) for field in PARSE_ARGS_FIELDS}
        new_kwargs['dump_version'] = kwargs.get('dump_version', '2.0')
        return super().__new__(cls, **new_kwargs)
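
# A minimal construction sketch (values are hypothetical). ArgsParser mimics the
# parsed-argument namespace that the toolkit's FormatConversionMain expects:
#
#     args = ArgsParser(dump_path='/tmp/async_dump/0', format='NCHW',
#                       output_path='/tmp/converted', output_file_type='npy')
#     assert args.dump_version == '2.0'  # default filled in by __new__
#     assert args.input is None          # unset fields default to None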


def load_hisi_tools(msaccucmp_path=None):
    """
    Load the HISI tools.

    Args:
        msaccucmp_path (Path): The path object of the msaccucmp file. Default: None.

    Returns:
        tuple, the utils module, the common module and the shape_conversion.FormatConversionMain
            class from the toolkit package.
    """
    msaccucmp_path = get_msaccucmp_path(msaccucmp_path)
    hisi_tool_path = msaccucmp_path.parent
    if str(hisi_tool_path) not in sys.path:
        sys.path.append(str(hisi_tool_path))
    try:
        hisi_utils = import_module('utils')
        hisi_common = import_module('common')
        hisi_format_conversion = import_module('shape_conversion').FormatConversionMain
    except ModuleNotFoundError:
        raise ModuleNotFoundError(f'Failed to load the HISI tools under {hisi_tool_path}.')
    return hisi_utils, hisi_common, hisi_format_conversion
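
# A hedged usage note: the three returned modules live in the toolkit directory
# added to sys.path above, so callers can unpack them directly (as DirConvert
# does below):
#
#     hisi_utils, hisi_common, format_conversion_cls = load_hisi_tools()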


def get_msaccucmp_path(msaccucmp_path=None):
    """
    Get the path of the HISI msaccucmp file.

    Args:
        msaccucmp_path (str): The path of `msaccucmp.py` or `msaccucmp.pyc`. Default: None.

    Returns:
        Path, the msaccucmp file path object.
    """
    if msaccucmp_path is not None:
        msaccucmp_path = Path(msaccucmp_path).resolve()
        if not msaccucmp_path.exists():
            raise FileNotFoundError(f"File {msaccucmp_path} doesn't exist. Please check the input value.")
        return msaccucmp_path
    # search for the msaccucmp file under $ASCEND_AICPU_PATH
    ascend_aicpu_path = os.environ.get('ASCEND_AICPU_PATH')
    if not ascend_aicpu_path:
        raise FileNotFoundError("Failed to find the ASCEND_AICPU_PATH environment variable. Please make sure "
                                "you have installed the run packages and set the environment correctly.")
    ascend_aicpu_path = Path(ascend_aicpu_path).resolve()
    msaccucmp_files = list(ascend_aicpu_path.rglob('msaccucmp.py*'))
    if not msaccucmp_files:
        raise FileNotFoundError(f"Failed to find msaccucmp.py or msaccucmp.pyc under {ascend_aicpu_path}. Please "
                                f"make sure you have installed the toolkit package successfully.")
    return msaccucmp_files[0]
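
# An illustrative lookup, assuming a hypothetical toolkit install prefix:
#
#     os.environ['ASCEND_AICPU_PATH'] = '/usr/local/Ascend/ascend-toolkit/latest'
#     msaccucmp = get_msaccucmp_path()  # first msaccucmp.py* found under the prefix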


class DumpRootDirConverter:
    """Convert the async dump data under the dump root directory into host format."""

    def __init__(self, data_loader, msaccucmp_path=None):
        self.data_loader = data_loader
        self.dump_data_dir = Path(data_loader.get_net_dir())
        self.failed_summary_file = self.dump_data_dir.joinpath('convert_failed_files_summary.txt')
        self._hisi_tools = load_hisi_tools(msaccucmp_path)
        self.check_async_dir()

    def check_async_dir(self):
        """Check that this directory was dumped asynchronously on Ascend."""
        is_sync = self.data_loader.get_sync_flag()
        if is_sync:
            raise ValueError(f"The data under {str(self.dump_data_dir)} is not dumped asynchronously.")

    def convert(self):
        """Convert the dump data under the root dump directory from device format to host format."""
        source_iterations = self.dump_data_dir.glob('device_[0-9]*/*_graph_[0-9]*/[0-9]*/[0-9]*/')
        failed_lines = []
        if self.failed_summary_file.is_file():
            self.failed_summary_file.unlink()
        for iter_path in source_iterations:
            dump_path = str(iter_path)
            res = DirConvert(dump_path=dump_path, output_path=dump_path, hisi_tools=self._hisi_tools).convert()
            failed_lines.extend(res)
            # the tensor format is added to the file name during conversion

        if failed_lines:
            self.save_failed_lines(failed_lines)
        return failed_lines

    def save_failed_lines(self, failed_lines):
        """Save the failed lines to a summary file."""
        with self.failed_summary_file.open('w') as handler:
            for line in failed_lines:
                handler.write(line + '\n')
        self.failed_summary_file.chmod(stat.S_IRUSR)
        hisi_utils = self._hisi_tools[0]
        hisi_utils.print_info_log(f"The failed summary has been saved to {str(self.failed_summary_file)}.")
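
# A minimal usage sketch (the dump root path is hypothetical; the loader must
# point at an async dump directory, otherwise check_async_dir raises ValueError):
#
#     from mindinsight.debugger.stream_cache.data_loader import DataLoader
#     failed = DumpRootDirConverter(DataLoader('/tmp/async_dump')).convert()
#     # `failed` lists any tensors that still could not be converted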


class DirConvert:
    """Convert the async dump data under one directory into host format."""

    def __init__(self, dump_path, output_path, target_format='NCHW', output_file_type='npy', hisi_tools=None):
        self.args_parser = ArgsParser(dump_path=dump_path,
                                      format=target_format,
                                      output_path=output_path,
                                      output_file_type=output_file_type)
        self.output_path = Path(output_path).absolute()
        self.failed_file_path = self.output_path.joinpath('convert_failed_file_list.txt')
        self.hisi_utils, self.hisi_common, self.hisi_format_conversion = load_hisi_tools() \
            if hisi_tools is None else hisi_tools

    def _is_npy_target(self):
        """Check if the output file type is npy."""
        return self.args_parser.output_file_type == 'npy'

    def clean_old_files(self):
        """Clean old files."""
        # clean the failed file record
        if self.failed_file_path.is_file():
            self.failed_file_path.unlink()
        # clean old converted data
        old_data_files = self.output_path.glob(f'*.{self.args_parser.output_file_type}')
        for file in old_data_files:
            file.unlink()

    def convert(self):
        """Convert the async dump data under dump_path to target_format and save it under output_path."""
        conversion = self.hisi_format_conversion(self.args_parser)
        self.clean_old_files()
        failed_lines = []
        ret = conversion.convert_format()
        self.rename_generated_npy_file()
        if ret != self.hisi_utils.VECTOR_COMPARISON_NONE_ERROR:
            self.hisi_utils.print_info_log(
                f"Begin to convert the failed operators recorded in {str(self.failed_file_path)} one by one.")
            failed_lines = self.convert_failed_tensors()
        else:
            self.hisi_utils.print_info_log(
                f"All tensors under {self.args_parser.dump_path} have been converted and saved to "
                f"{self.output_path} successfully.")
        return failed_lines

    def rename_generated_npy_file(self):
        """Rename the npy files generated by the HISI tool to the MindSpore file name format."""
        # before the rename, the file name format is:
        # {op_type}.{op_name_with_scope}.{task_id}(.{stream_id}).{timestamp}.{tensor_type}.{slot}.{shape}.npy
        # after the rename, the file name format is:
        # {op_type}.{op_name}.{task_id}(.{stream_id}).{timestamp}.{tensor_type}.{slot}.{format}.npy
        if not self._is_npy_target():
            return
        self.hisi_utils.print_info_log(
            f"Start to rename the npy files under {self.output_path}.")
        target_format = self.args_parser.format
        old_data_files = self.output_path.glob('*.npy')
        for file in old_data_files:
            name_splits = file.name.split('.')
            name_splits[1] = name_splits[1].split('_')[-1]
            name_splits[-2] = target_format
            new_file_name = '.'.join(name_splits)
            file.rename(file.with_name(new_file_name))
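
    # An illustrative rename (the file name is hypothetical): a toolkit-generated
    #     Conv2D.Default_network_conv1_Conv2D.10.1.1612340000000.output.0.1x32x224x224.npy
    # becomes, with the scope dropped and the shape field replaced by the target format,
    #     Conv2D.Conv2D.10.1.1612340000000.output.0.NCHW.npy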

    def convert_failed_tensors(self):
        """Convert the failed tensors recorded in the failed file list."""
        failed_lines = []
        if not self.failed_file_path.is_file():
            return failed_lines
        with self.failed_file_path.open() as handler:
            failed_line = handler.readline().strip('\n')
            while failed_line:
                try:
                    self.convert_operator_by_failed_line(failed_line)
                except (ValueError, OSError, AttributeError) as err:
                    self.hisi_utils.print_error_log(f'Failed to convert {failed_line} to host format. \n{str(err)}')
                    failed_lines.append(failed_line)
                failed_line = handler.readline().strip('\n')
        if failed_lines:
            self.hisi_utils.print_error_log(f"Failed to convert: {failed_lines}")
        self.hisi_utils.print_info_log("Finished converting the failed operators to host format.")
        return failed_lines
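
    # Each line of convert_failed_file_list.txt is expected to look like
    # (hypothetical path) `/path/to/op_dump_file,input:0,output:1`: the source
    # operator file followed by the tensor slots that failed to convert, which
    # is exactly what convert_operator_by_failed_line parses below.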

    def convert_operator_by_failed_line(self, failed_line):
        """Convert an operator according to one failed line."""
        fields = failed_line.split(',')
        if len(fields) > 1:
            op_file = fields[0]
            op_data = self.hisi_utils.parse_dump_file(op_file, self.args_parser.dump_version)
            missing_tensors = fields[1:]
            for missing_tensor in missing_tensors:
                tensor_type, idx = missing_tensor.split(':')
                idx = int(idx)
                tensor = op_data.input[idx] if tensor_type == 'input' else op_data.output[idx]
                dump_data_array = self.get_tensor_numpy_value(tensor)
                self.save_tensor_file(op_file, tensor_type, idx, tensor, dump_data_array)

    def get_tensor_numpy_value(self, tensor):
        """Convert a tensor from device format to host format."""
        dump_data_array = self.hisi_utils.deserialize_dump_data_to_array(tensor)
        array = dump_data_array.reshape(tensor.shape.dim)
        return array

    def save_tensor_file(self, op_file, tensor_type, idx, tensor, dump_data_array):
        """
        Save the tensor file.

        Args:
            op_file (str): Source operator file path.
            tensor_type (str): The tensor type of the operator, `input` or `output`.
            idx (int): Tensor slot index.
            tensor (TensorProto): Tensor data in proto format.
            dump_data_array (numpy.array): Tensor data in numpy format.
        """
        op_name = os.path.basename(op_file)
        # shorten the op_name to meet the Linux file name length limit
        op_name = self._remove_scope_in_op_name(op_name)
        if self._is_npy_target():
            self._save_tensor_in_npy(op_name, tensor_type, idx, tensor, dump_data_array)
        else:
            self._save_tensor_in_bin(op_name, tensor_type, idx, tensor, dump_data_array)

    @staticmethod
    def _remove_scope_in_op_name(op_name):
        """Remove the scope in an operator name."""
        name_splits = op_name.split('.')
        node_name = name_splits[1]
        name_splits[1] = node_name.split('_')[-1]
        return '.'.join(name_splits)
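
    # e.g. a hypothetical 'Conv2D.Default_network_conv1_Conv2D.10.1.161234.output.0'
    # becomes 'Conv2D.Conv2D.10.1.161234.output.0', keeping the name within the
    # Linux 255-character file name limit.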

    def _save_tensor_in_npy(self, op_name, tensor_type, idx, tensor, dump_data_array):
        """
        Save the tensor file in `npy` format.

        Args:
            op_name (str): Operator name without scope.
            tensor_type (str): The tensor type of the operator, `input` or `output`.
            idx (int): Tensor slot index.
            tensor (TensorProto): Tensor data in proto format.
            dump_data_array (numpy.array): Tensor data in numpy format.
        """
        out_file_name = "%s.%s.%d.%s.npy" % (
            op_name,
            tensor_type,
            idx,
            self.hisi_common.get_format_string(tensor.format)
        )
        out_path = os.path.join(self.args_parser.output_path, out_file_name)
        np.save(out_path, dump_data_array)

    def _save_tensor_in_bin(self, op_name, tensor_type, idx, tensor, dump_data_array):
        """
        Save the tensor file in `bin` format.

        Args:
            op_name (str): Operator name without scope.
            tensor_type (str): The tensor type of the operator, `input` or `output`.
            idx (int): Tensor slot index.
            tensor (TensorProto): Tensor data in proto format.
            dump_data_array (numpy.array): Tensor data in numpy format.
        """
        out_file_name = "%s.%s.%d.%s.%s.bin" % (
            op_name,
            tensor_type,
            idx,
            self.hisi_utils.get_string_from_list(dump_data_array.shape, 'x'),
            self.hisi_common.get_format_string(tensor.format),
        )
        out_path = os.path.join(self.args_parser.output_path, out_file_name)
        dump_data_array.tofile(out_path)
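
# A hedged per-directory sketch (the iteration path is hypothetical); passing the
# same directory as dump_path and output_path writes the converted files next to
# the source data, as DumpRootDirConverter.convert() does:
#
#     iter_dir = '/tmp/async_dump/device_0/net_graph_0/0/0'
#     failed = DirConvert(dump_path=iter_dir, output_path=iter_dir).convert()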


class FileMapping:
    """Map an operator pattern to its tensor files."""

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.output_path = Path(data_loader.get_net_dir()).absolute()

    def find_tensor_file(self, pattern, device_ids=None, iterations=None):
        """
        Find tensor files.

        Args:
            pattern (str): File name pattern.
            device_ids (Union[None, list[int]]): Filter condition of device ids. Default: None.
            iterations (Union[None, list[int]]): Filter condition of iteration ids. Default: None.

        Returns:
            OpPathManager, the operator file paths of all filtered devices.
        """
        op_path = OpPathManager(pattern)
        if device_ids is None:
            device_dirs = self.output_path.glob('device_[0-9]*')
        else:
            device_dirs = []
            for device_id in device_ids:
                device_dirs.append(self.output_path.joinpath(f'device_{device_id}'))

        for device_dir in device_dirs:
            op_device_obj = self.find_tensor_file_per_device(pattern, device_dir, iterations)
            op_path.add(op_device_obj)
        return op_path

    def find_tensor_file_per_device(self, pattern, device_dir, iterations):
        """
        Find tensor files under one device directory.

        Args:
            pattern (str): File name pattern.
            device_dir (Union[Path, str]): Device directory path.
            iterations (Union[None, list[int]]): Filter condition of iteration ids.

        Returns:
            OpDevicePath, the operator file path object of one device.
        """
        device_dir = Path(device_dir)
        # the device directory name is like `device_{device_id}`
        device_id = int(device_dir.name.split('_')[-1])
        op_device_obj = OpDevicePath(device_id)

        def _find_by_iter_dirs(dirs):
            for iter_dir in dirs:
                op_path_per_iter = self.find_tensor_file_per_iter(pattern, iter_dir)
                op_device_obj.add(op_path_per_iter)

        if iterations is None:
            iter_dirs = device_dir.glob('*_graph_[0-9]*/[0-9]*/[0-9]*')
            _find_by_iter_dirs(iter_dirs)
        else:
            for iteration in iterations:
                iter_dirs = device_dir.glob(f'*_graph_[0-9]*/[0-9]*/{iteration}')
                _find_by_iter_dirs(iter_dirs)
        return op_device_obj

    @staticmethod
    def find_tensor_file_per_iter(pattern, iter_dir):
        """
        Find tensor files per iteration directory.

        Args:
            pattern (str): File name pattern.
            iter_dir (Union[Path, str]): Iteration path.

        Returns:
            OpPath, the operator file path object of one iteration.
        """
        dir_path = Path(iter_dir)

        def _get_file_generator(tensor_type):
            return dir_path.glob(f'*{pattern}.*{tensor_type}.[0-9]*.npy')

        in_gen = _get_file_generator('input')
        out_gen = _get_file_generator('output')
        iteration = int(dir_path.name)
        op_path_obj = OpPath(iteration, in_gen, out_gen)
        return op_path_obj
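
    # For pattern 'Conv2D', the glob above matches converted file names such as
    # the hypothetical `Conv2D.Conv2D.10.1.1612340000000.input.0.NCHW.npy` and
    # `Conv2D.Conv2D.10.1.1612340000000.output.0.NCHW.npy`.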


class OpPathManager:
    """The manager of the tensor files of one operator."""

    def __init__(self, pattern, op_full_name=None):
        self.pattern = pattern
        self.op_full_name = op_full_name
        self._op_path = {}

    @property
    def devices(self):
        """Get the list of device ids in cache."""
        return list(self._op_path.keys())

    def add(self, op_device_path):
        """Add an OpDevicePath object."""
        self._op_path[op_device_path.device_id] = op_device_path

    def device(self, device_id):
        """Get the OpDevicePath object according to device id."""
        return self._op_path.get(device_id)

    def to_dict(self):
        """Get the operator files of all devices in dict format."""
        res = {}
        for device_id, op_path in self._op_path.items():
            key = f'device_{device_id}'
            res[key] = op_path.to_dict()
        return res


class OpDevicePath:
    """The operator file object of a specific device."""

    def __init__(self, device_id):
        self._device_id = device_id
        # record the operator path objects of different iterations,
        # in the format <int, OpPath>
        self._op_path = {}

    @property
    def device_id(self):
        """The property of device id."""
        return self._device_id

    @property
    def iterations(self):
        """Get the list of iterations in cache."""
        return list(self._op_path.keys())

    def iteration(self, iteration):
        """Get the OpPath object according to iteration."""
        return self._op_path.get(iteration)

    def add(self, op_path):
        """Add an OpPath object."""
        self._op_path[op_path.iteration] = op_path

    def to_dict(self):
        """Get the operator files of one device in dict format."""
        res = {}
        for iteration, op_path in self._op_path.items():
            res[iteration] = op_path.to_dict()
        return res


class OpPath:
    """The operator file object of a specific iteration."""

    def __init__(self, iteration, input_gen, output_gen):
        self._iter = iteration
        self._input_files = None
        self._input_gen = input_gen
        self._output_files = None
        self._output_gen = output_gen

    @staticmethod
    def _convert_path_gen_to_list(path_gen):
        """Convert a generator from Path.glob to a list of strings."""
        return [str(path) for path in path_gen]

    @property
    def input(self):
        """The list of input tensor file paths."""
        if self._input_files is None:
            self._input_files = self._convert_path_gen_to_list(self._input_gen)
        return self._input_files

    @property
    def output(self):
        """The list of output tensor file paths."""
        if self._output_files is None:
            self._output_files = self._convert_path_gen_to_list(self._output_gen)
        return self._output_files

    @property
    def iteration(self):
        """The iteration of the tensor files."""
        return self._iter

    def to_dict(self):
        """Get the operator files of one iteration in dict format."""
        res = {
            'input': self.input,
            'output': self.output
        }
        return res
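
Taken together, a hedged end-to-end sketch of the new APIs (the dump root and the
operator pattern are hypothetical; it assumes an Ascend async dump directory and
an installed toolkit that provides msaccucmp):

    from mindinsight.debugger.dump.convert import DumpRootDirConverter, FileMapping
    from mindinsight.debugger.stream_cache.data_loader import DataLoader

    loader = DataLoader('/tmp/async_dump')

    # Convert every device_*/..._graph_*/<graph_id>/<iteration> directory to host format.
    failed = DumpRootDirConverter(loader).convert()

    # Then look up the converted .npy files of one operator.
    mapping = FileMapping(loader)
    op_paths = mapping.find_tensor_file('Conv2D', device_ids=[0], iterations=[0])
    print(op_paths.to_dict())  # {'device_0': {0: {'input': [...], 'output': [...]}}}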

mindinsight/debugger/stream_cache/data_loader.py (+2, -2)

@@ -24,7 +24,7 @@ from mindinsight.debugger.common.utils import DumpSettings
class DataLoader:
    """The DataLoader object provides interface to load graphs and device information from base_dir."""
    def __init__(self, base_dir):
-       self._debugger_base_dir = base_dir
+       self._debugger_base_dir = os.path.realpath(base_dir)
        self._graph_protos = []
        self._device_info = {}
        self._step_num = {}
@@ -46,7 +46,7 @@ class DataLoader:
        if dump_config.get(DumpSettings.E2E_DUMP_SETTINGS.value) and \
                dump_config[DumpSettings.E2E_DUMP_SETTINGS.value]['enable']:
            self._is_sync = True
-           self._net_dir = os.path.join(self._debugger_base_dir, self._net_name)
+           self._net_dir = os.path.realpath(os.path.join(self._debugger_base_dir, self._net_name))
        elif dump_config.get(DumpSettings.ASYNC_DUMP_SETTINGS.value) and \
                dump_config[DumpSettings.ASYNC_DUMP_SETTINGS.value]['enable']:
            self._is_sync = False

