Compare commits

...

42 Commits
master ... r0.3

Author SHA1 Message Date
  mindspore-ci-bot 488cfaefa2 !761 enhance float cmp in tests.lineagemgr in r0.3, update securec repository link 5 years ago
  luopengting 3a503bd25a enhance float cmp in tests.lineagemgr, fix probabilistic failure in st 5 years ago
  mindspore-ci-bot a0b6053d57 !221 Fix the bug of selecting train jobs in comparison dashboard 5 years ago
  ph ccbc457631 fix all select issue 5 years ago
  mindspore-ci-bot 64154380ec !213 Modify profiler api comments for the docs 5 years ago
  wangyue01 4218da33d0 Fix profiler api comments for the docs 5 years ago
  mindspore-ci-bot 359a80bb80 !210 Add set context rule for Profiler example 5 years ago
  mindspore-ci-bot 5a138ce198 !190 fix the bug that when the profiler parameter subgraph is Default or Gradients, the profiler analyse will raise an exception 5 years ago
  mindspore-ci-bot b46f83e0ca !192 fix the aicpu profiling file does not exist thrown the exception 5 years ago
  wangyue01 4cf140c97b Add set context rule in Profiler example 5 years ago
  mindspore-ci-bot 8f8d12f4cf !205 fix bugs of multi-scalars comparision under concurrency [r0.3] 5 years ago
  mindspore-ci-bot e8430261f4 !196 record valid user_defined items when raise_Exception is False 5 years ago
  mindspore-ci-bot c7d19107de !209 add parameter validation for train job caches api [r0.3] 5 years ago
  mindspore-ci-bot de795dac9b !206 fix bug: [Bug]{MI]{Profiler]device_target not verified in profilling 5 years ago
  mindspore-ci-bot cd2285dd7b !172 fix bugs of summary watcher when discovering profiler directory 5 years ago
  liangyongxiong 52bbe55ecf add paramater validation for train job caches api 5 years ago
  mindspore-ci-bot 0059d26eaf !186 update readme for mindconverter 5 years ago
  mindspore-ci-bot 16fcb03446 !187 more clear report on mindconverter 5 years ago
  mindspore-ci-bot 34a9026ea8 !204 fixed bug of the profiler page 5 years ago
  mindspore-ci-bot 917895dfd7 !203 UI fix model trace and compare plate issue 5 years ago
  WeibiaoYu 8310927932 fix bug: [Bug]{MI]{Profiler]device_target not verified in profilling 5 years ago
  mindspore-ci-bot 9222301212 !183 change version number to 0.3.0 5 years ago
  mindspore-ci-bot 744eb94ea5 !197 add 0.3.0-alpha release note 5 years ago
  mindspore-ci-bot f44799c414 !201 fix bug: [Profiler]Profiler not read the actual device id in training 5 years ago
  mindspore-ci-bot 3cd77e8b53 !180 lineagemgr: fix EvalLineage comments 5 years ago
  quyongxiu1 59af274f28 update readme 5 years ago
  quyongxiu1 5e49d578f4 convert report fix r0.3 5 years ago
  liangyongxiong eaa4497da6 fix bug of multi-scalars comparision under concurrency 5 years ago
  WeibiaoYu e11183b2d6 Fix issue: Profiler not read the actual device id in training 5 years ago
  WeiFeng 87e03965e5 profiler 5 years ago
  chenchao99 b381439349 fix the bug that when the profiler parameter subgraph is Default or Gradients, the profiler analyse will raise an exception 5 years ago
  ph bd33e306b4 fix issue 5 years ago
  kouzhenzhong d5b397bf69 add 0.3.0-alpha release note 5 years ago
  liangyongxiong cbd50af8ad fix bugs of summary watcher when discovering profiler directory 5 years ago
  kouzhenzhong 6d63c2a688 change version to 0.3.0 5 years ago
  kouzhenzhong b7c77d55ad lineagemgr: fix EvalLineage comments 5 years ago
  luopengting d15c81c699 remove invalid item in user_defined info, record valid items 5 years ago
  mindspore-ci-bot c4fc9bfb4f !195 fix CI issue 5 years ago
  ph 167ab2ed6e fix ci issue 5 years ago
  root 817dccee1d fix the exception for file don't exist 5 years ago
  mindspore-ci-bot e1d627ff77 !176 fix the float compare in lineage ut 5 years ago
  luopengting 618f8c8ccf fix the float compare because ci env update 5 years ago
31 changed files with 685 additions and 321 deletions
Unified View
1. .gitmodules (+1, -1)
2. RELEASE.md (+30, -1)
3. mindinsight/_version.py (+1, -1)
4. mindinsight/datavisual/data_transform/summary_watcher.py (+5, -6)
5. mindinsight/datavisual/processors/scalars_processor.py (+41, -20)
6. mindinsight/datavisual/processors/train_task_manager.py (+12, -0)
7. mindinsight/lineagemgr/collection/model/model_lineage.py (+3, -2)
8. mindinsight/lineagemgr/common/validator/validate.py (+7, -4)
9. mindinsight/mindconverter/README.md (+22, -13)
10. mindinsight/mindconverter/converter.py (+103, -37)
11. mindinsight/profiler/README.md (+10, -8)
12. mindinsight/profiler/__init__.py (+1, -1)
13. mindinsight/profiler/analyser/analyser.py (+2, -0)
14. mindinsight/profiler/parser/aicpu_data_parser.py (+2, -1)
15. mindinsight/profiler/profiling.py (+31, -7)
16. mindinsight/ui/babel.config.js (+1, -1)
17. mindinsight/ui/package.json (+0, -1)
18. mindinsight/ui/src/components/multiselectGroup.vue (+18, -13)
19. mindinsight/ui/src/locales/zh-cn.json (+1, -1)
20. mindinsight/ui/src/store.js (+4, -0)
21. mindinsight/ui/src/views/train-manage/data-traceback.vue (+79, -45)
22. mindinsight/ui/src/views/train-manage/model-traceback.vue (+66, -38)
23. mindinsight/ui/src/views/train-manage/profiler.vue (+11, -17)
24. tests/st/func/lineagemgr/api/test_model_api.py (+26, -24)
25. tests/st/func/lineagemgr/cache/test_lineage_cache.py (+5, -6)
26. tests/st/func/lineagemgr/collection/model/test_model_lineage.py (+16, -6)
27. tests/ut/lineagemgr/querier/event_data.py (+30, -23)
28. tests/ut/lineagemgr/querier/test_querier.py (+20, -20)
29. tests/ut/lineagemgr/querier/test_query_model.py (+37, -23)
30. tests/ut/profiler/analyser/test_analyser_aicore_detail.py (+26, -1)
31. tests/utils/tools.py (+74, -0)

.gitmodules (+1, -1)

@@ -1,3 +1,3 @@
 [submodule "third_party/securec"]
 path = third_party/securec
-url = https://gitee.com/openeuler/bounds_checking_function.git
+url = https://gitee.com/openeuler/libboundscheck.git

RELEASE.md (+30, -1)

@@ -1,5 +1,34 @@
 ## MindInsight

+# Release 0.3.0-alpha
+
+## Major Features and Improvements
+* Profiling
+* Provide easy to use apis for profiling start/stop and profiling data analyse (on Ascend only).
+* Provide operators performance display and analysis on MindInsight UI.
+* Large scale network computation graph visualization.
+* Optimize summary record implementation and improve its performance.
+* Improve lineage usability
+* Optimize lineage display and enrich tabular operation.
+* Decouple lineage callback from `SummaryRecord`.
+* Support scalar compare of multiple runs.
+* Scripts conversion from other frameworks
+* Support for converting PyTorch scripts within TorchVision to MindSpore scripts automatically.
+
+## Bugfixes
+* Fix pb files loaded problem when files are modified at the same time ([!53](https://gitee.com/mindspore/mindinsight/pulls/53)).
+* Fix load data thread stuck in `LineageCacheItemUpdater` ([!114](https://gitee.com/mindspore/mindinsight/pulls/114)).
+* Fix samples from previous steps erased due to tags size too large problem ([!86](https://gitee.com/mindspore/mindinsight/pulls/86)).
+* Fix image and histogram event package error ([!1143](https://gitee.com/mindspore/mindspore/pulls/1143)).
+* Equally distribute histogram ignoring actual step number to avoid large white space ([!66](https://gitee.com/mindspore/mindinsight/pulls/66)).
+
+## Thanks to our Contributors
+Thanks goes to these wonderful people:
+
+Chao Chen, Congli Gao, Ye Huang, Weifeng Huang, Zhenzhong Kou, Hongzhang Li, Longfei Li, Yongxiong Liang, Pengting Luo, Yanming Miao, Gongchang Ou, Yongxiu Qu, Hui Pan, Luyu Qiu, Junyan Qin, Kai Wen, Weining Wang, Yue Wang, Zhuanke Wu, Yifan Xia, Weibiao Yu, Ximiao Yu, Ting Zhao, Jianfeng Zhu.
+
+Contributions of any kind are welcome!
+
 # Release 0.2.0-alpha

 ## Major Features and Improvements
@@ -14,7 +43,7 @@ Now you can use [`HistogramSummary`](https://www.mindspore.cn/api/zh-CN/master/a
 * Fix unsafe functions and duplication files and redundant codes ([!14](https://gitee.com/mindspore/mindinsight/pulls/14)).
 * Fix sha256 checksum missing bug ([!24](https://gitee.com/mindspore/mindinsight/pulls/24)).
 * Fix graph bug when node name is empty ([!34](https://gitee.com/mindspore/mindinsight/pulls/34)).
-* Fix start/stop command exit-code incorrect ([!44](https://gitee.com/mindspore/mindinsight/pulls/44)).
+* Fix start/stop command error code incorrect ([!44](https://gitee.com/mindspore/mindinsight/pulls/44)).

 ## Thanks to our Contributors
 Thanks goes to these wonderful people:


mindinsight/_version.py (+1, -1)

@@ -14,4 +14,4 @@
 # ============================================================================
 """Mindinsight version module."""

-VERSION = '0.2.0'
+VERSION = '0.3.0'

mindinsight/datavisual/data_transform/summary_watcher.py (+5, -6)

@@ -257,12 +257,11 @@ class SummaryWatcher:
                 'mtime': mtime,
             }
-            if relative_path not in summary_dict:
-                summary_dict[relative_path] = {
-                    'ctime': ctime,
-                    'mtime': mtime,
-                    'profiler': profiler,
-                }
+            summary_dict[relative_path] = {
+                'ctime': ctime,
+                'mtime': mtime,
+                'profiler': profiler,
+            }

     def is_summary_directory(self, summary_base_dir, relative_path):
         """


mindinsight/datavisual/processors/scalars_processor.py (+41, -20)

@@ -16,8 +16,10 @@
 from urllib.parse import unquote

 from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
+from mindinsight.datavisual.common.log import logger
 from mindinsight.datavisual.utils.tools import if_nan_inf_to_none
 from mindinsight.datavisual.common.exceptions import ScalarNotExistError
+from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.processors.base_processor import BaseProcessor


@@ -71,25 +73,44 @@ class ScalarsProcessor(BaseProcessor):


scalars = [] scalars = []
for train_id in train_ids: for train_id in train_ids:
for tag in tags:
try:
tensors = self._data_manager.list_tensors(train_id, tag)
except ParamValueError:
continue

scalar = {
'train_id': train_id,
'tag': tag,
'values': [],
}

for tensor in tensors:
scalar['values'].append({
'wall_time': tensor.wall_time,
'step': tensor.step,
'value': if_nan_inf_to_none('scalar_value', tensor.value),
})

scalars.append(scalar)
scalars += self._get_train_scalars(train_id, tags)

return scalars

def _get_train_scalars(self, train_id, tags):
"""
Get scalar data for given train_id and tags.

Args:
train_id (str): Specify train job ID.
tags (list): Specify list of tags.

Returns:
list[dict], a list of dictionaries containing the `wall_time`, `step`, `value` for each scalar.
"""
scalars = []
for tag in tags:
try:
tensors = self._data_manager.list_tensors(train_id, tag)
except ParamValueError:
continue
except TrainJobNotExistError:
logger.warning('Can not find the given train job in cache.')
return []

scalar = {
'train_id': train_id,
'tag': tag,
'values': [],
}

for tensor in tensors:
scalar['values'].append({
'wall_time': tensor.wall_time,
'step': tensor.step,
'value': if_nan_inf_to_none('scalar_value', tensor.value),
})

scalars.append(scalar)


return scalars return scalars
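For reference, the refactor above only moves the per-train-job loop into `_get_train_scalars`; the payload shape it builds is unchanged. A minimal sketch of that shape follows, with a made-up train ID, tag and values (not taken from the diff):

```python
# Hypothetical payload shape after this change; train_id, tag and the numeric
# values below are invented for illustration.
scalars = [
    {
        'train_id': './run_1',   # assumed run directory
        'tag': 'loss',           # assumed scalar tag
        'values': [
            {'wall_time': 1589000000.0, 'step': 1, 'value': 2.31},
            {'wall_time': 1589000060.0, 'step': 2, 'value': None},  # NaN/Inf mapped to None by if_nan_inf_to_none
        ],
    },
]
```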

mindinsight/datavisual/processors/train_task_manager.py (+12, -0)

@@ -14,6 +14,7 @@
 # ============================================================================
 """Train task manager."""

+from mindinsight.utils.exceptions import ParamTypeError
 from mindinsight.datavisual.common.log import logger
 from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.enums import PluginNameEnum
@@ -141,9 +142,20 @@ class TrainTaskManager(BaseProcessor):

         Returns:
             dict, indicates train job ID and its current cache status.
+
+        Raises:
+            ParamTypeError, if the given train_ids parameter is not in valid type.
         """
+        if not isinstance(train_ids, list):
+            logger.error("train_ids must be list.")
+            raise ParamTypeError('train_ids', list)
+
         cache_result = []
         for train_id in train_ids:
+            if not isinstance(train_id, str):
+                logger.error("train_id must be str.")
+                raise ParamTypeError('train_id', str)
+
             try:
                 train_job = self._data_manager.get_train_job(train_id)
             except exceptions.TrainJobNotExistError:


mindinsight/lineagemgr/collection/model/model_lineage.py (+3, -2)

@@ -236,6 +236,7 @@ class EvalLineage(Callback):
""" """
Collect lineage of an evaluation job. Collect lineage of an evaluation job.


Args:
summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which
is used to record the summary value(see mindspore.train.summary.SummaryRecord), is used to record the summary value(see mindspore.train.summary.SummaryRecord),
or a log dir(as a `str`) to be passed to `LineageSummary` to create or a log dir(as a `str`) to be passed to `LineageSummary` to create
@@ -284,8 +285,8 @@ class EvalLineage(Callback):
self.lineage_summary = LineageSummary(self.lineage_log_dir) self.lineage_summary = LineageSummary(self.lineage_log_dir)


self.user_defined_info = user_defined_info self.user_defined_info = user_defined_info
if user_defined_info:
validate_user_defined_info(user_defined_info)
if self.user_defined_info:
validate_user_defined_info(self.user_defined_info)


except MindInsightException as err: except MindInsightException as err:
log.error(err) log.error(err)


mindinsight/lineagemgr/common/validator/validate.py (+7, -4)

@@ -410,7 +410,7 @@ def validate_path(summary_path):


 def validate_user_defined_info(user_defined_info):
     """
-    Validate user defined info.
+    Validate user defined info, delete the item if its key is in lineage.

     Args:
         user_defined_info (dict): The user defined info.
@@ -437,10 +437,13 @@

     field_map = set(FIELD_MAPPING.keys())
     user_defined_keys = set(user_defined_info.keys())
-    all_keys = field_map | user_defined_keys
+    insertion = list(field_map & user_defined_keys)

-    if len(field_map) + len(user_defined_keys) != len(all_keys):
-        raise LineageParamValueError("There are some keys have defined in lineage.")
+    if insertion:
+        for key in insertion:
+            user_defined_info.pop(key)
+        raise LineageParamValueError("There are some keys have defined in lineage. "
+                                     "Duplicated key(s): %s. " % insertion)


 def validate_train_id(relative_path):
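To make the new duplicate-key handling easier to follow, here is a hedged stand-alone sketch; `check_user_defined_info` and `LINEAGE_FIELDS` are simplified stand-ins for the real validator and `FIELD_MAPPING`:

```python
# Simplified stand-in for validate_user_defined_info: keys that collide with
# lineage fields are removed from the dict and reported in the raised error.
LINEAGE_FIELDS = {'loss', 'learning_rate'}  # assumed subset of FIELD_MAPPING keys

def check_user_defined_info(user_defined_info):
    duplicated = [key for key in user_defined_info if key in LINEAGE_FIELDS]
    for key in duplicated:
        user_defined_info.pop(key)
    if duplicated:
        raise ValueError('Duplicated key(s) already defined in lineage: %s' % duplicated)

info = {'loss': 0.1, 'my_metric': 32}
try:
    check_user_defined_info(info)
except ValueError as err:
    print(err)   # reports the colliding key 'loss'
print(info)      # {'my_metric': 32} -- the colliding key was dropped
```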


mindinsight/mindconverter/README.md (+22, -13)

@@ -2,13 +2,6 @@


 MindConverter is a tool that converting PyTorch scripts to MindSpore scripts. With minial manual editing and the guidance from conversion reports, users may easily migrate their model from PyTorch framework to MindSpore.

-
-### System Requirements
-
-* PyTorch v1.5.0
-* MindSpore v0.2.0
-
 ### Installation

 This tool is part of MindInsight and accessible to users after installing MindInsight, no extra installation is needed.
@@ -24,8 +17,6 @@ mindconverter commandline usage:
 mindconverter [-h] [--version] --in_file IN_FILE [--output OUTPUT]
               [--report REPORT]

-MindConverter CLI entry point (version: 0.2.0)
-
 optional arguments:
   -h, --help show this help message and exit
   --version show program's version number and exit
@@ -36,13 +27,31 @@ optional arguments:
              directorys
 ```

-Usage example:
+#### Use example:
+
+We have a collection of PyTorch model scripts
+```buildoutcfg
+~$ ls
+models
+~$ ls models
+alexnet.py resnet.py vgg.py
+```
+
+Then we set the PYTHONPATH environment variable and convert alexnet.py
+```buildoutcfg
+~$ export PYTHONPATH=~/models
+~$ mindconverter --in_file models/alexnet.py
+```
+
+Then we will see a conversion report and the output MindSpore script
 ```buildoutcfg
-export PYTHONPATH=~/my_pt_proj/models
-mindconverter --in_file lenet.py
+~$ ls
+alexnet_report.txt models output
+~$ ls output
+alexent.py
 ```

-Since the conversion is not 100% flawless, we encourage users to checkout the reports when fixing issues of the converted scripts.
+Since the conversion is not 100% flawless, we encourage users to checkout the report when fixing issues of the converted script.


 ### Unsupported Situation #1


mindinsight/mindconverter/converter.py (+103, -37)

@@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall from mindinsight.mindconverter.forward_call import ForwardCall


LINE_NO_INDEX_DIFF = 1



class Converter: class Converter:
"""Convert class""" """Convert class"""
@@ -197,6 +199,7 @@ class Converter:
raise ValueError('"(" not found, {} should work with "("'.format(call_name)) raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = self.find_right_parentheses(code, left) right = self.find_right_parentheses(code, left)
end = right end = right

expr = code[start:end + 1] expr = code[start:end + 1]
args_str = code[left:right + 1] args_str = code[left:right + 1]


@@ -336,6 +339,96 @@ class Converter:
mapping.update(convert_fun(*args)) mapping.update(convert_fun(*args))
return mapping return mapping


@staticmethod
def get_code_start_line_num(source_lines):
"""
Get the start code line number exclude comments.

Args:
source_lines (list[str]): Split results of original code.

Returns:
int, the start line number.
"""
stack = []
index = 0
for i, line in enumerate(source_lines):
if line.strip().startswith('#'):
continue
if line.strip().startswith('"""'):
if not line.endswith('"""\n'):
stack.append('"""')
continue
if line.strip().startswith("'''"):
if not line.endswith("'''\n"):
stack.append("'''")
continue
if line.endswith('"""\n') or line.endswith("'''\n"):
stack.pop()
continue
if line.strip() != '' and not stack:
index = i
break
return index

def update_code_and_convert_info(self, code, mapping):
"""
Replace code according to mapping, and update convert info.

Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.

Returns:
str, the replaced code.
"""

for key, value in mapping.items():
code = code.replace(key, value)

source_lines = code.splitlines(keepends=True)
start_line_number = self.get_code_start_line_num(source_lines)
add_import_infos = ['import mindspore\n',
'import mindspore.nn as nn\n',
'import mindspore.ops.operations as P\n']
for i, add_import_info in enumerate(add_import_infos):
source_lines.insert(start_line_number + i, add_import_info)
self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip())

insert_count = len(add_import_infos)
line_diff = insert_count - LINE_NO_INDEX_DIFF

for i in range(start_line_number + insert_count, len(source_lines)):
line = source_lines[i]

if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'):
new_line = '# ' + line
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip())
if line.strip().startswith('class') and '(nn.Module)' in line:
new_line = line.replace('nn.Module', 'nn.Cell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff)
if line.strip().startswith('def forward('):
new_line = line.replace('forward', 'construct')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff)
if 'nn.Linear' in line:
new_line = line.replace('nn.Linear', 'nn.Dense')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff)
if '(nn.Sequential)' in line:
new_line = line.replace('nn.Sequential', 'nn.SequentialCell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff)
if 'nn.init.' in line:
new_line = line.replace('nn.init', 'pass # nn.init')
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init')

code = ''.join(source_lines)
return code

def convert(self, import_name, output_dir, report_dir): def convert(self, import_name, output_dir, report_dir):
""" """
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
@@ -346,10 +439,10 @@ class Converter:
report_dir (str): The path to save report file. report_dir (str): The path to save report file.
""" """
logger.info("Start converting %s", import_name) logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name)
start_info = '[Start Convert]\n'
module_info = 'The module is {}.\n'.format(import_name)


import_mod = importlib.import_module(import_name) import_mod = importlib.import_module(import_name)

srcfile = inspect.getsourcefile(import_mod) srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile) logger.info("Script file is %s", srcfile)


@@ -358,40 +451,14 @@ class Converter:


# replace python function under nn.Module # replace python function under nn.Module
mapping = self.get_mapping(import_mod, forward_list) mapping = self.get_mapping(import_mod, forward_list)

code = inspect.getsource(import_mod) code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)

code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code

self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'

code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')

self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
self.convert_info += 'import sentence on torch as follows are annotated:\n'
self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'

self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
code = self.update_code_and_convert_info(code, mapping)
convert_info_split = self.convert_info.splitlines(keepends=True)
convert_info_split = sorted(convert_info_split)
convert_info_split.insert(0, start_info)
convert_info_split.insert(1, module_info)
convert_info_split.append('[Convert Over]')
self.convert_info = ''.join(convert_info_split)


dest_file = os.path.join(output_dir, os.path.basename(srcfile)) dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
@@ -428,7 +495,6 @@ def _path_split(file):


Returns: Returns:
list[str], list of file tail list[str], list of file tail

""" """
file_dir, name = os.path.split(file) file_dir, name = os.path.split(file)
if file_dir: if file_dir:
@@ -456,6 +522,6 @@ def main(files_config):
module_name = '.'.join(in_file_split) module_name = '.'.join(in_file_split)
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])


in_module = files_config['in_module']
in_module = files_config.get('in_module')
if in_module: if in_module:
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])

mindinsight/profiler/README.md (+10, -8)

@@ -12,16 +12,18 @@ The Profiler enables users to:
 To enable profiling on MindSpore, the MindInsight Profiler apis should be added to the script:

 1. Import MindInsight Profiler
+   ```
    from mindinsight.profiler import Profiler
-2. Initialize the Profiler before training
+   ```
+2. Initialize the Profiler after set context, and before the network initialization.

    Example:
+   context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=int(os.environ["DEVICE_ID"]))
    profiler = Profiler(output_path="./data", is_detail=True, is_show_op_path=False, subgraph='All')
-   Parameters including:
+   net = Net()
+   Parameters of Profiler including:
    subgraph (str): Defines which subgraph to monitor and analyse, can be 'all', 'Default', 'Gradients'.
    is_detail (bool): Whether to show profiling data for op_instance level, only show optype level if False.
@@ -31,9 +33,9 @@ To enable profiling on MindSpore, the MindInsight Profiler apis should be added
    will deal with all op if null.
    optypes_not_deal (list): Op type names, the data of which optype will not be collected and analysed.

-3. Call Profiler.analyse() at the end of the program
+3. Call ```Profiler.analyse()``` at the end of the program

-Profiler.analyse() will collect profiling data and generate the analysis results.
+```Profiler.analyse()``` will collect profiling data and generate the analysis results.

 After training, we can open MindInsight UI to analyse the performance.
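Putting the three steps together, a minimal end-to-end sketch; `Net`, `Model` and `get_dataset` are placeholders for the user's own network, model wrapper and data pipeline:

```python
import os
from mindspore import context
from mindinsight.profiler import Profiler

# 1. Set the execution context first, as the added rule above requires.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    device_id=int(os.environ["DEVICE_ID"]))

# 2. Initialize the Profiler after set_context and before the network is built.
profiler = Profiler(output_path="./data", is_detail=True,
                    is_show_op_path=False, subgraph='all')

net = Net()              # placeholder: user network
model = Model(net)       # placeholder: model wrapper
dataset = get_dataset()  # placeholder: training dataset
model.train(2, dataset)

# 3. Collect the profiling data and generate the analysis results.
profiler.analyse()
```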




mindinsight/profiler/__init__.py (+1, -1)

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """
-Profiler Module Introduction
+Profiler Module Introduction.

 This module provides Python APIs to enable the profiling of MindSpore neural networks.
 Users can import the mindinsight.profiler.Profiler, initialize the Profiler object to start profiling,


mindinsight/profiler/analyser/analyser.py (+2, -0)

@@ -124,6 +124,8 @@ class AicoreDetailAnalyser(BaseAnalyser):
         result = []
         for op_type in op_type_order:
             detail_infos = type_detail_cache.get(op_type)
+            if detail_infos is None:
+                continue
             detail_infos.sort(key=lambda item: item[2], reverse=True)
             result.extend(detail_infos)




mindinsight/profiler/parser/aicpu_data_parser.py (+2, -1)

@@ -15,6 +15,7 @@
""" """
The parser for AI CPU preprocess data. The parser for AI CPU preprocess data.
""" """
import os


from tabulate import tabulate from tabulate import tabulate


@@ -50,7 +51,7 @@ class DataPreProcessParser:
def execute(self): def execute(self):
"""Execute the parser, get result data, and write it to the output file.""" """Execute the parser, get result data, and write it to the output file."""


if self._source_file_name is None:
if not os.path.exists(self._source_file_name):
logger.info("Did not find the aicpu profiling source file") logger.info("Did not find the aicpu profiling source file")
return return




mindinsight/profiler/profiling.py (+31, -7)

@@ -37,19 +37,21 @@ class Profiler:
""" """
Performance profiling API. Performance profiling API.


Enable MindSpore users to profile the neural network.
Enable MindSpore users to profile the performance of neural network.


Args: Args:
subgraph (str): Defines which subgraph to monitor and analyse, can be 'all', 'Default', 'Gradients'.
subgraph (str): Define which subgraph to monitor and analyse, can be 'all', 'Default', 'Gradients'.
is_detail (bool): Whether to show profiling data for op_instance level, only show optype level if False. is_detail (bool): Whether to show profiling data for op_instance level, only show optype level if False.
is_show_op_path (bool): Whether to save the full path for each op instance. is_show_op_path (bool): Whether to save the full path for each op instance.
output_path (str): Output data path. output_path (str): Output data path.
optypes_to_deal (list): Op type names, the data of which optype should be collected and analysed,
will deal with all op if null.
optypes_not_deal (list): Op type names, the data of which optype will not be collected and analysed.
optypes_to_deal (list[str]): Op type names, the data of which optype should be collected and analysed,
will deal with all op if null.
optypes_not_deal (list[str]): Op type names, the data of which optype will not be collected and analysed.


Examples: Examples:
>>> from mindinsight.profiler import Profiler >>> from mindinsight.profiler import Profiler
>>> context.set_context(mode=context.GRAPH_MODE, device_target=“Ascend”,
>>> device_id=int(os.environ["DEVICE_ID"]))
>>> profiler = Profiler(subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data') >>> profiler = Profiler(subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data')
>>> model = Model(train_network) >>> model = Model(train_network)
>>> dataset = get_dataset() >>> dataset = get_dataset()
@@ -64,10 +66,30 @@ class Profiler:


def __init__(self, subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data', def __init__(self, subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data',
optypes_to_deal='', optypes_not_deal='Variable', job_id=""): optypes_to_deal='', optypes_not_deal='Variable', job_id=""):
dev_id = os.getenv('DEVICE_ID')

# get device_id and device_target
device_target = ""
try:
import mindspore.context as context
dev_id = str(context.get_context("device_id"))
device_target = context.get_context("device_target")
except ImportError:
logger.error("Profiling: fail to import context from mindspore.")
except ValueError as err:
logger.error("Profiling: fail to get context %s", err.message)

if not dev_id:
dev_id = str(os.getenv('DEVICE_ID'))
if not dev_id: if not dev_id:
dev_id = "0" dev_id = "0"
logger.error("Fail to get DEVICE_ID, use 0 instead.") logger.error("Fail to get DEVICE_ID, use 0 instead.")

if device_target and device_target != "Davinci" \
and device_target != "Ascend":
msg = ("Profiling: unsupport backend: %s" \
% device_target)
raise RuntimeError(msg)

self._dev_id = dev_id self._dev_id = dev_id
self._container_path = os.path.join(self._base_profiling_container_path, dev_id) self._container_path = os.path.join(self._base_profiling_container_path, dev_id)
data_path = os.path.join(self._container_path, "data") data_path = os.path.join(self._container_path, "data")
@@ -88,7 +110,7 @@ class Profiler:
         except ImportError:
             logger.error("Profiling: fail to import context from mindspore.")
         except ValueError as err:
-            logger.err("Profiling: fail to set context", err.message)
+            logger.error("Profiling: fail to set context, %s", err.message)

         os.environ['AICPU_PROFILING_MODE'] = 'true'
         os.environ['PROFILING_DIR'] = str(self._container_path)
@@ -107,6 +129,8 @@ class Profiler:


Examples: Examples:
>>> from mindinsight.profiler import Profiler >>> from mindinsight.profiler import Profiler
>>> context.set_context(mode=context.GRAPH_MODE, device_target=“Ascend”,
>>> device_id=int(os.environ["DEVICE_ID"]))
>>> profiler = Profiler(subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data') >>> profiler = Profiler(subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data')
>>> model = Model(train_network) >>> model = Model(train_network)
>>> dataset = get_dataset() >>> dataset = get_dataset()


mindinsight/ui/babel.config.js (+1, -1)

@@ -18,7 +18,7 @@ module.exports = {
   [
     '@vue/app',
     {
-      polyfills: ['es6.promise', 'es6.symbol'],
+      polyfills: ['es.promise', 'es.symbol'],
     },
   ],
 ],


mindinsight/ui/package.json (+0, -1)

@@ -24,7 +24,6 @@
"@intlify/vue-i18n-loader": "0.6.1", "@intlify/vue-i18n-loader": "0.6.1",
"@vue/cli-service": "4.1.0", "@vue/cli-service": "4.1.0",
"@vue/cli-plugin-babel": "4.1.0", "@vue/cli-plugin-babel": "4.1.0",
"babel-core": "6.26.0",
"babel-eslint": "10.0.3", "babel-eslint": "10.0.3",
"eslint": "6.6.0", "eslint": "6.6.0",
"eslint-config-google": "0.13.0", "eslint-config-google": "0.13.0",


mindinsight/ui/src/components/multiselectGroup.vue (+18, -13)

@@ -162,6 +162,7 @@ export default {
listSelectAll() { listSelectAll() {
this.operateSelectAll = !this.operateSelectAll; this.operateSelectAll = !this.operateSelectAll;
this.multiSelectedItemNames = {}; this.multiSelectedItemNames = {};
this.selectedNumber = 0;
// Setting the status of list items // Setting the status of list items
if (this.operateSelectAll) { if (this.operateSelectAll) {
if (this.isLimit) { if (this.isLimit) {
@@ -171,7 +172,7 @@ export default {
break; break;
} }
const listItem = this.checkListArr[i]; const listItem = this.checkListArr[i];
if (listItem.show) {
if ((listItem.show && !listItem.checked) || listItem.checked) {
listItem.checked = true; listItem.checked = true;
this.multiSelectedItemNames[listItem.label] = true; this.multiSelectedItemNames[listItem.label] = true;
this.selectedNumber++; this.selectedNumber++;
@@ -216,14 +217,17 @@ export default {
} }
this.valiableSearchInput = this.searchInput; this.valiableSearchInput = this.searchInput;
this.multiSelectedItemNames = {}; this.multiSelectedItemNames = {};
this.selectedNumber = 0;
let itemSelectAll = true; let itemSelectAll = true;
// Filter the tags that do not meet the conditions in the operation bar and hide them // Filter the tags that do not meet the conditions in the operation bar and hide them
this.checkListArr.forEach((listItem) => { this.checkListArr.forEach((listItem) => {
if (listItem.checked) {
this.multiSelectedItemNames[listItem.label] = true;
this.selectedNumber++;
}
if (reg.test(listItem.label)) { if (reg.test(listItem.label)) {
listItem.show = true; listItem.show = true;
if (listItem.checked) {
this.multiSelectedItemNames[listItem.label] = true;
} else {
if (!listItem.checked) {
itemSelectAll = false; itemSelectAll = false;
} }
} else { } else {
@@ -232,7 +236,7 @@ export default {
}); });
// Update the selected status of the Select All button // Update the selected status of the Select All button
if (this.isLimit && !itemSelectAll) { if (this.isLimit && !itemSelectAll) {
itemSelectAll = this.selectedNumber >= this.limitNum;
itemSelectAll = this.selectedNumber >= this.limitNum || this.selectedNumber >= this.checkListArr.length;
} }
this.operateSelectAll = itemSelectAll; this.operateSelectAll = itemSelectAll;
this.$emit('selectedChange', this.multiSelectedItemNames); this.$emit('selectedChange', this.multiSelectedItemNames);
@@ -271,7 +275,7 @@ export default {
} }
}); });
if (this.isLimit && !itemSelectAll) { if (this.isLimit && !itemSelectAll) {
itemSelectAll = this.selectedNumber >= this.limitNum;
itemSelectAll = this.selectedNumber >= this.limitNum || this.selectedNumber >= this.checkListArr.length;
} }
this.operateSelectAll = itemSelectAll; this.operateSelectAll = itemSelectAll;
// Return a dictionary containing selected items. // Return a dictionary containing selected items.
@@ -309,23 +313,24 @@ export default {
const loopCount = this.checkListArr.length; const loopCount = this.checkListArr.length;
for (let i = 0; i < loopCount; i++) { for (let i = 0; i < loopCount; i++) {
const listItem = this.checkListArr[i]; const listItem = this.checkListArr[i];
if (reg.test(listItem.label)) {
listItem.show = true;
if (listItem.checked) {
if (this.selectedNumber >= this.limitNum) { if (this.selectedNumber >= this.limitNum) {
listItem.checked = false; listItem.checked = false;
itemSelectAll = false;
} else if (listItem.checked) {
} else {
this.multiSelectedItemNames[listItem.label] = true; this.multiSelectedItemNames[listItem.label] = true;
this.selectedNumber++; this.selectedNumber++;
} else {
itemSelectAll = false;
} }
}
if (reg.test(listItem.label)) {
listItem.show = true;
} else { } else {
listItem.show = false; listItem.show = false;
} }
} }
if (!itemSelectAll && this.selectedNumber >= this.limitNum) {
if (this.selectedNumber >= this.limitNum || this.selectedNumber >= this.checkListArr.length) {
itemSelectAll = true; itemSelectAll = true;
} else {
itemSelectAll = false;
} }
} else { } else {
this.checkListArr.forEach((listItem) => { this.checkListArr.forEach((listItem) => {


mindinsight/ui/src/locales/zh-cn.json (+1, -1)

@@ -78,7 +78,7 @@
"userDefinedLabel": "User Defined", "userDefinedLabel": "User Defined",
"hyperLabel": "Hyper", "hyperLabel": "Hyper",
"otherLabel": "其他", "otherLabel": "其他",
"remarkTips": "提示:终止服务后备注及tag信息将被清除"
"remarkTips": "提示:终止服务后备注及tag将被清除"
}, },
"dataTraceback": { "dataTraceback": {
"details": "详情", "details": "详情",


mindinsight/ui/src/store.js (+4, -0)

@@ -32,6 +32,7 @@ export default new Vuex.Store({
       : 3,
     // multiSelevtGroup component count
     multiSelectedGroupCount: 0,
+    tableId: 0,
   },
   mutations: {
     // set cancelTokenArr
@@ -72,6 +73,9 @@ export default new Vuex.Store({
     multiSelectedGroupComponentNum(state) {
       state.multiSelectedGroupCount++;
     },
+    increaseTableId(state) {
+      state.tableId++;
+    },
   },
   actions: {},
 });

mindinsight/ui/src/views/train-manage/data-traceback.vue (+79, -45)

@@ -24,11 +24,12 @@ limitations under the License.
type="primary" type="primary"
size="mini" size="mini"
plain plain
v-show="(summaryDirList&&!summaryDirList.length)||(totalSeries&&totalSeries.length)">
v-show="(summaryDirList && !summaryDirList.length)||(totalSeries && totalSeries.length)">
{{ $t('modelTraceback.showAllData') }} {{ $t('modelTraceback.showAllData') }}
</el-button> </el-button>
<div class="select-container" <div class="select-container"
v-show="totalSeries&&totalSeries.length&&(!summaryDirList||(summaryDirList&&summaryDirList.length))">
v-show="totalSeries && totalSeries.length &&
(!summaryDirList || (summaryDirList && summaryDirList.length))">
<div class="display-column"> <div class="display-column">
{{$t('modelTraceback.displayColumn')}} {{$t('modelTraceback.displayColumn')}}
</div> </div>
@@ -50,17 +51,17 @@ limitations under the License.
<button type="text" <button type="text"
@click="allSelect" @click="allSelect"
class="select-all-button" class="select-all-button"
:class="[selectCheckAll?'checked-color':'button-text',
basearr.length>checkOptions.length ? 'btn-disabled' : '']"
:disabled="basearr.length>checkOptions.length">
:class="[selectCheckAll ? 'checked-color' : 'button-text',
basearr.length > checkOptions.length ? 'btn-disabled' : '']"
:disabled="basearr.length > checkOptions.length">
{{ $t('public.selectAll')}} {{ $t('public.selectAll')}}
</button> </button>
<button type="text" <button type="text"
@click="deselectAll" @click="deselectAll"
class="deselect-all-button" class="deselect-all-button"
:class="[!selectCheckAll?'checked-color':'button-text',
basearr.length>checkOptions.length ? 'btn-disabled' : '']"
:disabled="basearr.length>checkOptions.length">
:class="[!selectCheckAll ? 'checked-color' : 'button-text',
basearr.length > checkOptions.length ? 'btn-disabled' : '']"
:disabled="basearr.length > checkOptions.length">
{{ $t('public.deselectAll')}} {{ $t('public.deselectAll')}}
</button> </button>
</div> </div>
@@ -69,7 +70,7 @@ limitations under the License.
:label="item.label" :label="item.label"
:value="item.value" :value="item.value"
:disabled="item.disabled" :disabled="item.disabled"
:title="item.disabled?$t('modelTraceback.mustExist'):''">
:title="item.disabled ? $t('modelTraceback.mustExist') : ''">
</el-option> </el-option>
</el-select> </el-select>
</div> </div>
@@ -79,7 +80,8 @@ limitations under the License.
<div id="data-echart" <div id="data-echart"
v-show="showEchartPic && !echartNoData"></div> v-show="showEchartPic && !echartNoData"></div>
<div class="echart-nodata-container" <div class="echart-nodata-container"
v-show="!showEchartPic && showTable"></div>
v-show="!showEchartPic && showTable && !(summaryDirList && !summaryDirList.length)">
</div>
<div class="btns-container" <div class="btns-container"
v-show="!echartNoData && showTable"> v-show="!echartNoData && showTable">
<el-button type="primary" <el-button type="primary"
@@ -103,7 +105,7 @@ limitations under the License.
<el-table ref="table" <el-table ref="table"
:data="table.data" :data="table.data"
tooltip-effect="light" tooltip-effect="light"
height="calc(100% - 54px)"
height="calc(100% - 40px)"
row-key="summary_dir" row-key="summary_dir"
@selection-change="handleSelectionChange" @selection-change="handleSelectionChange"
@sort-change="tableSortChange"> @sort-change="tableSortChange">
@@ -116,8 +118,8 @@ limitations under the License.
:key="key" :key="key"
:prop="key" :prop="key"
:label="table.columnOptions[key].label" :label="table.columnOptions[key].label"
:sortable="sortArray.includes(table.columnOptions[key].label)?'custom':false"
:fixed="table.columnOptions[key].label===text?true:false"
:sortable="sortArray.includes(table.columnOptions[key].label) ? 'custom' : false"
:fixed="table.columnOptions[key].label === text?true:false"
min-width="200" min-width="200"
show-overflow-tooltip> show-overflow-tooltip>
<template slot="header" <template slot="header"
@@ -151,7 +153,7 @@ limitations under the License.
</el-table-column> </el-table-column>
<!-- remark column --> <!-- remark column -->
<el-table-column fixed="right" <el-table-column fixed="right"
width="310">
width="260">
<template slot="header"> <template slot="header">
<div> <div>
<div class="label-text">{{$t('public.remark')}}</div> <div class="label-text">{{$t('public.remark')}}</div>
@@ -208,10 +210,10 @@ limitations under the License.
<div> <div>
<div class="icon-image-container"> <div class="icon-image-container">
<div class="icon-image" <div class="icon-image"
:class="[item.number===scope.row.tag && scope.row.showIcon?'icon-border':'']"
:class="[item.number === scope.row.tag && scope.row.showIcon ? 'icon-border' : '']"
v-for="item in imageList" v-for="item in imageList"
:key="item.number" :key="item.number"
@click="iconValueChange(scope.row,item.number,$event)">
@click="iconValueChange(scope.row, item.number, $event)">
<img :src="item.iconAdd"> <img :src="item.iconAdd">
</div> </div>
</div> </div>
@@ -243,33 +245,34 @@ limitations under the License.
</template> </template>
</el-table-column> </el-table-column>
</el-table> </el-table>
<div>
<div class="hide-count"
v-show="recordsNumber-showNumber">
{{ $t('modelTraceback.totalHide').replace(`{n}`,(recordsNumber-showNumber))}}
</div>
<div class="pagination-container">
<el-pagination @current-change="handleCurrentChange" <el-pagination @current-change="handleCurrentChange"
:current-page="pagination.currentPage" :current-page="pagination.currentPage"
:page-size="pagination.pageSize" :page-size="pagination.pageSize"
:layout="pagination.layout" :layout="pagination.layout"
:total="pagination.total"> :total="pagination.total">
</el-pagination> </el-pagination>
<div class="hide-count"
v-show="recordsNumber-showNumber">
{{ $t('modelTraceback.totalHide').replace(`{n}`, (recordsNumber-showNumber))}}
</div>
<div class="clear"></div>
</div> </div>
</div> </div>
<div v-show="((!lineagedata.serData || !lineagedata.serData.length) && initOver) <div v-show="((!lineagedata.serData || !lineagedata.serData.length) && initOver)
||(echartNoData&&(lineagedata.serData&&!!lineagedata.serData.length))"
||(echartNoData && (lineagedata.serData && !!lineagedata.serData.length))"
class="no-data-page"> class="no-data-page">
<div class="no-data-img"> <div class="no-data-img">
<img :src="require('@/assets/images/nodata.png')" <img :src="require('@/assets/images/nodata.png')"
alt="" /> alt="" />
<p class="no-data-text" <p class="no-data-text"
v-show="!summaryDirList||(summaryDirList&&summaryDirList.length)&&!lineagedata.serData">
v-show="!summaryDirList || (summaryDirList && summaryDirList.length) && !lineagedata.serData">
{{ $t('public.noData') }} {{ $t('public.noData') }}
</p> </p>
<div v-show="echartNoData&&(lineagedata.serData&&!!lineagedata.serData.length)">
<div v-show="echartNoData && (lineagedata.serData && !!lineagedata.serData.length)">
<p class="no-data-text">{{ $t('dataTraceback.noDataFound') }}</p> <p class="no-data-text">{{ $t('dataTraceback.noDataFound') }}</p>
</div> </div>
<div v-show="summaryDirList&&!summaryDirList.length">
<div v-show="summaryDirList && !summaryDirList.length">
<p class="no-data-text">{{ $t('dataTraceback.noDataFound') }}</p> <p class="no-data-text">{{ $t('dataTraceback.noDataFound') }}</p>
<p class="no-data-text"> <p class="no-data-text">
{{ $t('dataTraceback.click') }} {{ $t('dataTraceback.click') }}
@@ -494,7 +497,7 @@ export default {
obj.iconAdd = require('@/assets/images/icon' + obj.number + '.svg'); obj.iconAdd = require('@/assets/images/icon' + obj.number + '.svg');
this.imageList.push(obj); this.imageList.push(obj);
} }
document.title = this.$t('summaryManage.dataTraceback') + '-MindInsight';
document.title = `${this.$t('summaryManage.dataTraceback')}-MindInsight`;
document.addEventListener('click', this.blurFloat, true); document.addEventListener('click', this.blurFloat, true);
this.$nextTick(() => { this.$nextTick(() => {
this.init(); this.init();
@@ -527,8 +530,8 @@ export default {
return; return;
} }
row.showIcon = true; row.showIcon = true;
const e = window.event;
document.getElementById('icon-dialog').style.top = e.clientY + 'px';
document.getElementById('icon-dialog').style.top =
window.event.clientY + 'px';
}, },


iconValueChange(row, num, event) { iconValueChange(row, num, event) {
@@ -575,6 +578,13 @@ export default {
*/ */


clearIcon(row) { clearIcon(row) {
const classWrap = event.path.find((item) => {
return item.className === 'icon-dialog';
});
const classArr = classWrap.querySelectorAll('.icon-border');
classArr.forEach((item) => {
item.classList.remove('icon-border');
});
row.showIcon = false; row.showIcon = false;
this.iconValue = 0; this.iconValue = 0;
row.tag = 0; row.tag = 0;
@@ -848,7 +858,7 @@ export default {
} }
this.initChart(); this.initChart();
const list = []; const list = [];
this.checkOptions.forEach((item) => {
this.basearr.forEach((item) => {
this.selectArrayValue.forEach((i) => { this.selectArrayValue.forEach((i) => {
if (i === item.value) { if (i === item.value) {
list.push(i); list.push(i);
@@ -917,7 +927,7 @@ export default {
}); });
} }
const list = []; const list = [];
this.checkOptions.forEach((item) => {
this.basearr.forEach((item) => {
this.selectArrayValue.forEach((i) => { this.selectArrayValue.forEach((i) => {
if (i === item.value) { if (i === item.value) {
const obj = {}; const obj = {};
@@ -1061,6 +1071,9 @@ export default {
this.showTable = false; this.showTable = false;
this.echartNoData = true; this.echartNoData = true;
} else { } else {
const echartLength = this.echart.brushData.length;
this.recordsNumber = echartLength;
this.showNumber = echartLength;
this.echart.showData = this.echart.brushData; this.echart.showData = this.echart.brushData;
this.initChart(); this.initChart();
this.pagination.currentPage = 1; this.pagination.currentPage = 1;
@@ -1431,6 +1444,7 @@ export default {
this.initOver = false; this.initOver = false;
this.echartNoData = false; this.echartNoData = false;
this.showEchartPic = true; this.showEchartPic = true;
this.selectCheckAll = true;
// checkOptions initializate to an empty array // checkOptions initializate to an empty array
this.checkOptions = []; this.checkOptions = [];
this.selectArrayValue = []; this.selectArrayValue = [];
@@ -1733,7 +1747,9 @@ export default {
const item = {}; const item = {};
item.key = k; item.key = k;
item.value = dataObj[key][k]; item.value = dataObj[key][k];
item.id = (index + 1) * 10 + 1 + j;
item.id =
`${new Date().getTime()}` + `${this.$store.state.tableId}`;
this.$store.commit('increaseTableId');
tempData.children.push(item); tempData.children.push(item);
}); });
} }
@@ -1775,14 +1791,15 @@ export default {
<style lang="scss"> <style lang="scss">
.label-text { .label-text {
line-height: 20px !important; line-height: 20px !important;
vertical-align: bottom;
padding-top: 20px;
display: block !important;
} }
.remark-tip { .remark-tip {
line-height: 14px !important;
line-height: 20px !important;
font-size: 12px; font-size: 12px;
white-space: pre-wrap !important; white-space: pre-wrap !important;
vertical-align: bottom;
color: gray; color: gray;
display: block !important;
} }
.el-color-dropdown__main-wrapper, .el-color-dropdown__main-wrapper,
.el-color-dropdown__value, .el-color-dropdown__value,
@@ -1841,6 +1858,13 @@ export default {
height: 100%; height: 100%;
overflow-y: auto; overflow-y: auto;
position: relative; position: relative;
.el-table th.is-leaf {
background: #f5f7fa;
}
.el-table td,
.el-table th.is-leaf {
border: 1px solid #ebeef5;
}
.inline-block-set { .inline-block-set {
display: inline-block; display: inline-block;
} }
@@ -1878,7 +1902,7 @@ export default {
.no-data-page { .no-data-page {
width: 100%; width: 100%;
height: 100%; height: 100%;
padding-top: 224px;
padding-top: 200px;
} }
.no-data-img { .no-data-img {
background: #fff; background: #fff;
@@ -1944,6 +1968,7 @@ export default {
.data-checkbox-area { .data-checkbox-area {
position: relative; position: relative;
margin: 24px 32px 12px; margin: 24px 32px 12px;
height: 46px;
.reset-btn { .reset-btn {
position: absolute; position: absolute;
right: 0px; right: 0px;
@@ -1951,12 +1976,12 @@ export default {
} }
} }
#data-echart { #data-echart {
height: 34%;
height: 32%;
width: 100%; width: 100%;
padding: 0 12px; padding: 0 12px;
} }
.echart-nodata-container { .echart-nodata-container {
height: 34%;
height: 32%;
width: 100%; width: 100%;
} }
.btn-container-margin { .btn-container-margin {
@@ -1975,8 +2000,8 @@ export default {


.table-container { .table-container {
background-color: white; background-color: white;
height: calc(60% - 90px);
margin: 6px 32px 0;
height: calc(68% - 130px);
padding: 6px 32px;
position: relative; position: relative;
.custom-label { .custom-label {
max-width: calc(100% - 25px); max-width: calc(100% - 25px);
@@ -1997,24 +2022,33 @@ export default {
.click-span { .click-span {
cursor: pointer; cursor: pointer;
} }
.clear {
clear: both;
}
.hide-count { .hide-count {
display: inline-block;
position: absolute;
right: 450px;
height: 32px; height: 32px;
line-height: 32px; line-height: 32px;
padding-top: 12px;
color: red; color: red;
float: right;
margin-right: 10px;
} }
.el-pagination { .el-pagination {
position: absolute;
right: 0px;
float: right;
margin-right: 32px;
bottom: 10px; bottom: 10px;
} }
.pagination-container {
height: 40px;
}
} }
} }


.details-data-list { .details-data-list {
.el-table td,
.el-table th.is-leaf {
border: none;
border-top: 1px solid #ebeef5;
}
.el-table { .el-table {
th { th {
padding: 10px 0; padding: 10px 0;


mindinsight/ui/src/views/train-manage/model-traceback.vue (+66, -38)

@@ -20,7 +20,7 @@ limitations under the License.
<div class="select-box" <div class="select-box"
v-if="!noData && v-if="!noData &&
(!summaryDirList || (summaryDirList && summaryDirList.length))"> (!summaryDirList || (summaryDirList && summaryDirList.length))">
<div v-show="showTable&&!noData"
<div v-show="showTable && !noData"
class="select-container"> class="select-container">
<!-- multiple collapse-tags --> <!-- multiple collapse-tags -->
<div class="display-column"> {{$t('modelTraceback.displayColumn')}}</div> <div class="display-column"> {{$t('modelTraceback.displayColumn')}}</div>
@@ -40,19 +40,18 @@ limitations under the License.
<button type="text" <button type="text"
@click="allSelect" @click="allSelect"
class="select-all-button" class="select-all-button"
:class="[selectCheckAll?
'checked-color':'button-text',
basearr.length>checkOptions.length?'btn-disabled':'']"
:disabled="basearr.length>checkOptions.length">
:class="[selectCheckAll ? 'checked-color' : 'button-text',
basearr.length > checkOptions.length ? 'btn-disabled' : '']"
:disabled="basearr.length > checkOptions.length">
{{$t('public.selectAll')}} {{$t('public.selectAll')}}
</button> </button>
<button type="text" <button type="text"
@click="deselectAll" @click="deselectAll"
class="deselect-all-button" class="deselect-all-button"
:class="[!selectCheckAll? :class="[!selectCheckAll?
'checked-color':'button-text',
basearr.length>checkOptions.length?'btn-disabled':'']"
:disabled="basearr.length>checkOptions.length">
'checked-color' : 'button-text',
basearr.length > checkOptions.length ? 'btn-disabled' : '']"
:disabled="basearr.length > checkOptions.length">
{{$t('public.deselectAll')}} {{$t('public.deselectAll')}}
</button> </button>
</div> </div>
@@ -64,7 +63,7 @@ limitations under the License.
:label="item.label" :label="item.label"
:value="item.value" :value="item.value"
:disabled="item.disabled" :disabled="item.disabled"
:title="item.disabled?$t('modelTraceback.mustExist'):''">
:title="item.disabled ? $t('modelTraceback.mustExist') : ''">
</el-option> </el-option>
</el-option-group> </el-option-group>
</el-select> </el-select>
@@ -82,19 +81,19 @@ limitations under the License.
type="primary" type="primary"
size="mini" size="mini"
plain plain
v-if="(!noData&&basearr.length) ||
v-if="(!noData && basearr.length) ||
(noData && summaryDirList && !summaryDirList.length)"> (noData && summaryDirList && !summaryDirList.length)">
{{ $t('modelTraceback.showAllData') }}</el-button> {{ $t('modelTraceback.showAllData') }}</el-button>
</div> </div>


</div> </div>
<div id="echart" <div id="echart"
v-show="!noData&&showEchartPic"></div>
v-show="!noData && showEchartPic"></div>
<div class="echart-no-data" <div class="echart-no-data"
v-show="!showEchartPic"> v-show="!showEchartPic">
</div> </div>
<div class="btns-container" <div class="btns-container"
v-show="showTable&&!noData">
v-show="showTable && !noData">
<el-button type="primary" <el-button type="primary"
size="mini" size="mini"
class="custom-btn" class="custom-btn"
@@ -118,7 +117,7 @@ limitations under the License.
<el-table-column type="selection" <el-table-column type="selection"
width="55" width="55"
:reserve-selection="true" :reserve-selection="true"
v-show="showTable&&!noData">
v-show="showTable && !noData">
</el-table-column> </el-table-column>


<!--metric table column--> <!--metric table column-->
@@ -188,7 +187,7 @@ limitations under the License.
</div> </div>
</template> </template>
<template slot-scope="scope"> <template slot-scope="scope">
<span>{{formatNumber(key,scope.row[key])}}</span>
<span>{{formatNumber(key, scope.row[key])}}</span>
</template> </template>
</el-table-column> </el-table-column>
</el-table-column> </el-table-column>
@@ -197,7 +196,7 @@ limitations under the License.
:key="key" :key="key"
:prop="key" :prop="key"
:label="table.columnOptions[key].label" :label="table.columnOptions[key].label"
:fixed="table.columnOptions[key].label===text?true:false"
:fixed="table.columnOptions[key].label === text ? true : false"
show-overflow-tooltip show-overflow-tooltip
min-width="150" min-width="150"
sortable="custom"> sortable="custom">
@@ -216,7 +215,7 @@ limitations under the License.
</el-table-column> </el-table-column>
<!-- remark column --> <!-- remark column -->
<el-table-column fixed="right" <el-table-column fixed="right"
width="310">
width="260">
<template slot="header"> <template slot="header">
<div> <div>
<div class="label-text">{{$t('public.remark')}}</div> <div class="label-text">{{$t('public.remark')}}</div>
@@ -271,7 +270,7 @@ limitations under the License.
<div> <div>
<div class="icon-image-container"> <div class="icon-image-container">
<div class="icon-image" <div class="icon-image"
:class="[item.number===scope.row.tag&&scope.row.showIcon ? 'icon-border':'']"
:class="[item.number === scope.row.tag && scope.row.showIcon ? 'icon-border' : '']"
v-for="item in imageList" v-for="item in imageList"
:key="item.number" :key="item.number"
@click="iconValueChange(scope.row,item.number,$event)"> @click="iconValueChange(scope.row,item.number,$event)">
@@ -300,17 +299,18 @@ limitations under the License.
</template> </template>
</el-table-column> </el-table-column>
</el-table> </el-table>
<div>
<div class="hide-count"
v-show="recordsNumber-showNumber">
{{$t('modelTraceback.totalHide').replace(`{n}`,(recordsNumber-showNumber))}}
</div>
<div class="pagination-container">
<el-pagination @current-change="pagination.pageChange" <el-pagination @current-change="pagination.pageChange"
:current-page="pagination.currentPage" :current-page="pagination.currentPage"
:page-size="pagination.pageSize" :page-size="pagination.pageSize"
:layout="pagination.layout" :layout="pagination.layout"
:total="pagination.total"> :total="pagination.total">
</el-pagination> </el-pagination>
<div class="hide-count"
v-show="recordsNumber-showNumber">
{{$t('modelTraceback.totalHide').replace(`{n}`, (recordsNumber-showNumber))}}
</div>
<div class="clear"></div>
</div> </div>


</div> </div>
@@ -425,7 +425,7 @@ export default {
obj.iconAdd = require('@/assets/images/icon' + obj.number + '.svg'); obj.iconAdd = require('@/assets/images/icon' + obj.number + '.svg');
this.imageList.push(obj); this.imageList.push(obj);
} }
document.title = this.$t('summaryManage.modelTraceback') + '-MindInsight';
document.title = `${this.$t('summaryManage.modelTraceback')}-MindInsight`;
document.addEventListener('click', this.blurFloat, true); document.addEventListener('click', this.blurFloat, true);
this.$store.commit('setSelectedBarList', []); this.$store.commit('setSelectedBarList', []);
this.getStoreList(); this.getStoreList();
@@ -466,8 +466,8 @@ export default {
return; return;
} }
row.showIcon = true; row.showIcon = true;
const e = window.event;
document.getElementById('icon-dialog').style.top = e.clientY + 'px';
document.getElementById('icon-dialog').style.top =
window.event.clientY + 'px';
}, },


/** /**
@@ -514,6 +514,13 @@ export default {
}, },
// clear icon // clear icon
clearIcon(row) { clearIcon(row) {
const classWrap = event.path.find((item) => {
return item.className === 'icon-dialog';
});
const classArr = classWrap.querySelectorAll('.icon-border');
classArr.forEach((item) => {
item.classList.remove('icon-border');
});
row.showIcon = false; row.showIcon = false;
this.iconValue = 0; this.iconValue = 0;
row.tag = 0; row.tag = 0;
@@ -1345,10 +1352,11 @@ export default {
this.echart.brushData = list; this.echart.brushData = list;
this.echart.showData = this.echart.brushData; this.echart.showData = this.echart.brushData;
this.initChart(); this.initChart();
this.table.data = list.slice(
const showList = list.slice(
(this.pagination.currentPage - 1) * this.pagination.pageSize, (this.pagination.currentPage - 1) * this.pagination.pageSize,
this.pagination.currentPage * this.pagination.pageSize, this.pagination.currentPage * this.pagination.pageSize,
); );
this.table.data = showList;
this.recordsNumber = this.table.data.length; this.recordsNumber = this.table.data.length;
this.showNumber = this.table.data.length; this.showNumber = this.table.data.length;
this.pagination.total = res.data.count || 0; this.pagination.total = res.data.count || 0;
@@ -1365,6 +1373,8 @@ export default {
sortChange(column) { sortChange(column) {
this.sortInfo.sorted_name = column.prop; this.sortInfo.sorted_name = column.prop;
this.sortInfo.sorted_type = column.order; this.sortInfo.sorted_type = column.order;
this.recordsNumber = 0;
this.showNumber = 0;
this.getStoreList(); this.getStoreList();
const tempParam = { const tempParam = {
limit: this.pagination.pageSize, limit: this.pagination.pageSize,
@@ -1384,9 +1394,21 @@ export default {
(res) => { (res) => {
if (res && res.data && res.data.object) { if (res && res.data && res.data.object) {
const list = this.setDataOfModel(res.data.object); const list = this.setDataOfModel(res.data.object);
this.table.data = list;
const tempList = list.slice(0, this.pagination.pageSize);
this.recordsNumber = tempList.length;
if (this.hidenDirChecked.length) {
this.hidenDirChecked.forEach((dir) => {
tempList.forEach((item, index) => {
if (item.summary_dir === dir) {
tempList.splice(index, 1);
}
});
});
}
this.showNumber = tempList.length;
this.table.data = tempList;
this.pagination.total = res.data.count || 0; this.pagination.total = res.data.count || 0;
this.pagination.currentPage = 0;
this.pagination.currentPage = 1;
} }
}, },
(error) => {}, (error) => {},
@@ -1741,6 +1763,7 @@ export default {
this.$store.commit('setSelectedBarList', []); this.$store.commit('setSelectedBarList', []);
this.noData = false; this.noData = false;
this.showTable = false; this.showTable = false;
this.selectCheckAll = true;
this.chartFilter = {}; this.chartFilter = {};
this.tableFilter.summary_dir = undefined; this.tableFilter.summary_dir = undefined;
this.sortInfo = {}; this.sortInfo = {};
@@ -1838,14 +1861,15 @@ export default {
<style lang="scss"> <style lang="scss">
.label-text { .label-text {
line-height: 20px !important; line-height: 20px !important;
vertical-align: bottom;
padding-top: 20px;
display: block !important;
} }
.remark-tip { .remark-tip {
line-height: 14px !important;
line-height: 20px !important;
font-size: 12px; font-size: 12px;
white-space: pre-wrap !important; white-space: pre-wrap !important;
vertical-align: bottom;
color: gray; color: gray;
display: block !important;
} }
.el-color-dropdown__main-wrapper, .el-color-dropdown__main-wrapper,
.el-color-dropdown__value, .el-color-dropdown__value,
@@ -1943,6 +1967,7 @@ export default {
.btns { .btns {
margin-left: 20px; margin-left: 20px;
padding-top: 12px; padding-top: 12px;
height: 46px;
} }
.btn-container-margin { .btn-container-margin {
margin: 0 55px 10px; margin: 0 55px 10px;
@@ -2048,7 +2073,7 @@ export default {
} }
.table-container { .table-container {
background-color: white; background-color: white;
height: calc(60% - 40px);
height: calc(68% - 130px);
padding: 6px 32px; padding: 6px 32px;
position: relative; position: relative;
.custom-label { .custom-label {
@@ -2059,21 +2084,24 @@ export default {
a { a {
cursor: pointer; cursor: pointer;
} }
.clear {
clear: both;
}
.hide-count { .hide-count {
display: inline-block;
position: absolute;
right: 450px;
height: 32px; height: 32px;
line-height: 32px; line-height: 32px;
padding-top: 4px;
color: red; color: red;
float: right;
margin-right: 10px;
} }
.el-pagination { .el-pagination {
float: right;
margin-right: 32px; margin-right: 32px;
position: absolute;
right: 0;
bottom: 10px; bottom: 10px;
} }
.pagination-container {
height: 40px;
}
} }
.no-data-page { .no-data-page {
width: 100%; width: 100%;


mindinsight/ui/src/views/train-manage/profiler.vue (+11, -17)

@@ -61,12 +61,10 @@
<div class="cl-search-box"> <div class="cl-search-box">
<el-input v-model="searchByTypeInput" <el-input v-model="searchByTypeInput"
v-if="statisticType === 0" v-if="statisticType === 0"
suffix-icon="el-icon-search"
:placeholder="$t('profiler.searchByType')" :placeholder="$t('profiler.searchByType')"
@keyup.enter.native="searchOpCoreList()"></el-input> @keyup.enter.native="searchOpCoreList()"></el-input>
<el-input v-model="searchByNameInput" <el-input v-model="searchByNameInput"
v-if="statisticType === 1" v-if="statisticType === 1"
suffix-icon="el-icon-search"
:placeholder="$t('profiler.searchByName')" :placeholder="$t('profiler.searchByName')"
@keyup.enter.native="searchOpCoreList()"></el-input> @keyup.enter.native="searchOpCoreList()"></el-input>
</div> </div>
@@ -90,6 +88,8 @@
:property="ele" :property="ele"
:key="key" :key="key"
:sortable="ele === 'op_info' ? false : 'custom'" :sortable="ele === 'op_info' ? false : 'custom'"
:width="(ele==='execution_time'|| ele==='subgraph' ||
ele==='op_name'|| ele==='op_type')?'220':''"
show-overflow-tooltip show-overflow-tooltip
:label="ele"> :label="ele">
</el-table-column> </el-table-column>
@@ -124,6 +124,8 @@
:key="$index" :key="$index"
:label="item" :label="item"
:sortable="item === 'op_info' ? false : 'custom'" :sortable="item === 'op_info' ? false : 'custom'"
:width="(item==='execution_time'|| item==='subgraph' ||
item==='op_name'|| item==='op_type')?'220':''"
show-overflow-tooltip> show-overflow-tooltip>
</el-table-column> </el-table-column>
</el-table> </el-table>
@@ -168,7 +170,6 @@
</span> </span>
<div class="cl-search-box"> <div class="cl-search-box">
<el-input v-model="searchByCPUNameInput" <el-input v-model="searchByCPUNameInput"
suffix-icon="el-icon-search"
:placeholder="$t('profiler.searchByName')" :placeholder="$t('profiler.searchByName')"
@keyup.enter.native="searchOpCpuList()"></el-input> @keyup.enter.native="searchOpCpuList()"></el-input>
</div> </div>
@@ -814,7 +815,8 @@ export default {
option.xAxis = { option.xAxis = {
type: 'category', type: 'category',
axisLabel: { axisLabel: {
interval: 1,
interval: 0,
rotate: -30,
}, },
data: [], data: [],
}; };
@@ -822,7 +824,7 @@ export default {
left: 50, left: 50,
top: 20, top: 20,
right: 0, right: 0,
bottom: 30,
bottom: 50,
}; };
option.yAxis = { option.yAxis = {
type: 'value', type: 'value',
@@ -925,7 +927,7 @@ export default {
const item = {}; const item = {};
item.key = k; item.key = k;
item.value = dataObj[key][k]; item.value = dataObj[key][k];
item.id = (index + 1) * 10 + 1 + j;
item.id = item.key + Math.random();
tempData.children.push(item); tempData.children.push(item);
}); });
} }
@@ -955,20 +957,12 @@ export default {
}, },
}, },
mounted() { mounted() {
if (
this.$route.query &&
this.$route.query.dir &&
this.$route.query.id
) {
if (this.$route.query && this.$route.query.dir && this.$route.query.id) {
this.profile_dir = this.$route.query.dir; this.profile_dir = this.$route.query.dir;
this.train_id = this.$route.query.id; this.train_id = this.$route.query.id;
document.title =
decodeURIComponent(this.train_id) +
'-' +
this.$t('profiler.titleText') +
'-MindInsight';
document.title = `${ decodeURIComponent(this.train_id)}-${this.$t('profiler.titleText')}-MindInsight`;
} else { } else {
document.title = this.$t('profiler.titleText') + '-MindInsight';
document.title = `${this.$t('profiler.titleText')}-MindInsight`;
} }
this.init(); this.init();
window.addEventListener('resize', this.resizeCallback, false); window.addEventListener('resize', this.resizeCallback, false);


tests/st/func/lineagemgr/api/test_model_api.py (+26, -24)

@@ -31,6 +31,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF
LineageSearchConditionParamError) LineageSearchConditionParamError)
from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
from .....ut.lineagemgr.querier import event_data from .....ut.lineagemgr.querier import event_data
from .....utils.tools import assert_equal_lineages


LINEAGE_INFO_RUN1 = { LINEAGE_INFO_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
@@ -39,7 +40,7 @@ LINEAGE_INFO_RUN1 = {
}, },
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14, 'epoch': 14,
'parallel_mode': 'stand_alone', 'parallel_mode': 'stand_alone',
@@ -73,11 +74,11 @@ LINEAGE_FILTRATION_EXCEPT_RUN = {
'user_defined': {}, 'user_defined': {},
'network': 'ResNet', 'network': 'ResNet',
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'epoch': 10, 'epoch': 10,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
'loss': 0.029999999329447746,
'loss': 0.03,
'model_size': 64, 'model_size': 64,
'metric': {}, 'metric': {},
'dataset_mark': 2 'dataset_mark': 2
@@ -92,10 +93,14 @@ LINEAGE_FILTRATION_RUN1 = {
'train_dataset_count': 1024, 'train_dataset_count': 1024,
'test_dataset_path': None, 'test_dataset_path': None,
'test_dataset_count': 1024, 'test_dataset_count': 1024,
'user_defined': {},
'user_defined': {
'info': 'info1',
'version': 'v1',
'eval_version': 'version2'
},
'network': 'ResNet', 'network': 'ResNet',
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'epoch': 14, 'epoch': 14,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
@@ -119,14 +124,14 @@ LINEAGE_FILTRATION_RUN2 = {
'user_defined': {}, 'user_defined': {},
'network': "ResNet", 'network': "ResNet",
'optimizer': "Momentum", 'optimizer': "Momentum",
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'epoch': 10, 'epoch': 10,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
'loss': 0.029999999329447746,
'loss': 0.03,
'model_size': 10, 'model_size': 10,
'metric': { 'metric': {
'accuracy': 2.7800000000000002
'accuracy': 2.78
}, },
'dataset_mark': 3 'dataset_mark': 3
}, },
@@ -173,7 +178,7 @@ class TestModelApi(TestCase):
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
'learning_rate': 0.12,
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14, 'epoch': 14,
'parallel_mode': 'stand_alone', 'parallel_mode': 'stand_alone',
@@ -190,9 +195,9 @@ class TestModelApi(TestCase):
'network': 'ResNet' 'network': 'ResNet'
} }
} }
assert expect_total_res == total_res
assert expect_partial_res1 == partial_res1
assert expect_partial_res2 == partial_res2
assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual)
assert_equal_lineages(expect_partial_res1, partial_res1, self.assertDictEqual)
assert_equal_lineages(expect_partial_res2, partial_res2, self.assertDictEqual)


# the lineage summary file is empty # the lineage summary file is empty
result = get_summary_lineage(self.dir_with_empty_lineage) result = get_summary_lineage(self.dir_with_empty_lineage)
@@ -329,7 +334,7 @@ class TestModelApi(TestCase):
def test_filter_summary_lineage(self): def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage.""" """Test the interface of filter_summary_lineage."""
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN1,
@@ -345,7 +350,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
assert_equal_lineages(expect_result, res, self.assertDictEqual)


expect_result = { expect_result = {
'customized': {}, 'customized': {},
@@ -356,7 +361,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
assert_equal_lineages(expect_result, res, self.assertDictEqual)


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@@ -383,7 +388,7 @@ class TestModelApi(TestCase):
'offset': 0 'offset': 0
} }
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_RUN2, LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1 LINEAGE_FILTRATION_RUN1
@@ -394,7 +399,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res
assert_equal_lineages(expect_result, partial_res, self.assertDictEqual)


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@@ -421,7 +426,7 @@ class TestModelApi(TestCase):
'offset': 0 'offset': 0
} }
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_RUN2, LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1 LINEAGE_FILTRATION_RUN1
@@ -432,7 +437,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res
assert_equal_lineages(expect_result, partial_res, self.assertDictEqual)


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@@ -449,7 +454,7 @@ class TestModelApi(TestCase):
'sorted_name': 'metric/accuracy', 'sorted_name': 'metric/accuracy',
} }
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN1,
@@ -461,7 +466,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res1.get('object')): for idx, res_object in enumerate(partial_res1.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res1
assert_equal_lineages(expect_result, partial_res1, self.assertDictEqual)


search_condition2 = { search_condition2 = {
'batch_size': { 'batch_size': {
@@ -477,9 +482,6 @@ class TestModelApi(TestCase):
'count': 0 'count': 0
} }
partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2) partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res2.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res2 assert expect_result == partial_res2


@pytest.mark.level0 @pytest.mark.level0
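The recurring change in this file swaps the long float literals (0.11999999731779099, 0.029999999329447746, 2.7800000000000002) for their rounded forms and replaces exact "assert expected == result" checks with assert_equal_lineages, which rounds floats on both sides before comparing. A minimal sketch of the idea, using a hypothetical round_floats helper (the real helper is assert_equal_lineages in tests/utils/tools.py, shown further below):

# Hypothetical helper illustrating the rounding-based comparison these tests adopt.
def round_floats(obj, ndigits=2):
    """Recursively round floats inside nested dicts/lists."""
    if isinstance(obj, float):
        return round(obj, ndigits)
    if isinstance(obj, dict):
        return {k: round_floats(v, ndigits) for k, v in obj.items()}
    if isinstance(obj, list):
        return [round_floats(v, ndigits) for v in obj]
    return obj

# Values that differ only by float representation noise compare as equal after rounding.
expected = {'hyper_parameters': {'learning_rate': 0.12}, 'metric': {'accuracy': 2.78}}
actual = {'hyper_parameters': {'learning_rate': 0.11999999731779099},
          'metric': {'accuracy': 2.7800000000000002}}
assert round_floats(expected) == round_floats(actual)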


tests/st/func/lineagemgr/cache/test_lineage_cache.py (+5, -6)

@@ -33,7 +33,7 @@ from ..api.test_model_api import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RU
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2
from ..conftest import BASE_SUMMARY_DIR from ..conftest import BASE_SUMMARY_DIR
from .....ut.lineagemgr.querier import event_data from .....ut.lineagemgr.querier import event_data
from .....utils.tools import check_loading_done
from .....utils.tools import check_loading_done, assert_equal_lineages




@pytest.mark.usefixtures("create_summary_dir") @pytest.mark.usefixtures("create_summary_dir")
@@ -58,8 +58,7 @@ class TestModelApi(TestCase):
"""Test the interface of get_summary_lineage.""" """Test the interface of get_summary_lineage."""
total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1") total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1")
expect_total_res = LINEAGE_INFO_RUN1 expect_total_res = LINEAGE_INFO_RUN1

assert expect_total_res == total_res
assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual)


@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@@ -70,7 +69,7 @@ class TestModelApi(TestCase):
def test_filter_summary_lineage(self): def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage.""" """Test the interface of filter_summary_lineage."""
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN1,
@@ -86,7 +85,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
assert_equal_lineages(expect_result, res, self.assertDictEqual)


expect_result = { expect_result = {
'customized': {}, 'customized': {},
@@ -100,4 +99,4 @@ class TestModelApi(TestCase):
} }
} }
res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition) res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition)
assert expect_result == res
assert_equal_lineages(expect_result, res, self.assertDictEqual)

tests/st/func/lineagemgr/collection/model/test_model_lineage.py (+16, -6)

@@ -28,7 +28,7 @@ from unittest import mock, TestCase
import numpy as np import numpy as np
import pytest import pytest


from mindinsight.lineagemgr import get_summary_lineage
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \
AnalyzeObject AnalyzeObject
from mindinsight.lineagemgr.common.utils import make_directory from mindinsight.lineagemgr.common.utils import make_directory
@@ -73,6 +73,10 @@ class TestModelLineage(TestCase):
TrainLineage(cls.summary_record) TrainLineage(cls.summary_record)
] ]
cls.run_context['list_callback'] = _ListCallback(callback) cls.run_context['list_callback'] = _ListCallback(callback)
cls.user_defined_info = {
"info": "info1",
"version": "v1"
}


@pytest.mark.scene_train(2) @pytest.mark.scene_train(2)
@pytest.mark.level0 @pytest.mark.level0
@@ -83,7 +87,7 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_train_begin(self): def test_train_begin(self):
"""Test the begin function in TrainLineage.""" """Test the begin function in TrainLineage."""
train_callback = TrainLineage(self.summary_record, True)
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback.begin(RunContext(self.run_context)) train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12 assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path lineage_log_path = train_callback.lineage_summary.lineage_log_path
@@ -98,7 +102,11 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_train_begin_with_user_defined_info(self): def test_train_begin_with_user_defined_info(self):
"""Test TrainLineage with nested user defined info.""" """Test TrainLineage with nested user defined info."""
user_defined_info = {"info": {"version": "v1"}}
user_defined_info = {
"info": "info1",
"version": "v1",
"network": "LeNet"
}
train_callback = TrainLineage( train_callback = TrainLineage(
self.summary_record, self.summary_record,
False, False,
@@ -108,6 +116,8 @@ class TestModelLineage(TestCase):
assert train_callback.initial_learning_rate == 0.12 assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True assert os.path.isfile(lineage_log_path) is True
res = filter_summary_lineage(os.path.dirname(lineage_log_path))
assert self.user_defined_info == res['object'][0]['model_lineage']['user_defined']


@pytest.mark.scene_train(2) @pytest.mark.scene_train(2)
@pytest.mark.level0 @pytest.mark.level0
@@ -138,7 +148,7 @@ class TestModelLineage(TestCase):
def test_training_end(self, *args): def test_training_end(self, *args):
"""Test the end function in TrainLineage.""" """Test the end function in TrainLineage."""
args[0].return_value = 64 args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True)
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback.initial_learning_rate = 0.12 train_callback.initial_learning_rate = 0.12
train_callback.end(RunContext(self.run_context)) train_callback.end(RunContext(self.run_context))
res = get_summary_lineage(SUMMARY_DIR) res = get_summary_lineage(SUMMARY_DIR)
@@ -158,7 +168,7 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_eval_end(self): def test_eval_end(self):
"""Test the end function in EvalLineage.""" """Test the end function in EvalLineage."""
eval_callback = EvalLineage(self.summary_record, True)
eval_callback = EvalLineage(self.summary_record, True, {'eval_version': 'version2'})
eval_run_context = self.run_context eval_run_context = self.run_context
eval_run_context['metrics'] = {'accuracy': 0.78} eval_run_context['metrics'] = {'accuracy': 0.78}
eval_run_context['valid_dataset'] = self.run_context['train_dataset'] eval_run_context['valid_dataset'] = self.run_context['train_dataset']
@@ -331,7 +341,7 @@ class TestModelLineage(TestCase):
def test_train_with_customized_network(self, *args): def test_train_with_customized_network(self, *args):
"""Test train with customized network.""" """Test train with customized network."""
args[0].return_value = 64 args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True)
train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
run_context_customized = self.run_context run_context_customized = self.run_context
del run_context_customized['optimizer'] del run_context_customized['optimizer']
del run_context_customized['net_outputs'] del run_context_customized['net_outputs']
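These tests now pass user-defined key/value pairs into the lineage callbacks (TrainLineage receives {"info": "info1", "version": "v1"}, EvalLineage receives {"eval_version": "version2"}) and verify that the items surface in the query results. A pure-Python sketch of the expected result shape only; the merge below is illustrative, not the library's internal implementation, and it matches the user_defined block of LINEAGE_FILTRATION_RUN1 in test_model_api.py above:

# Illustrative only: train-side and eval-side user-defined items end up merged under
# model_lineage['user_defined'] in each filtered record.
train_side = {"info": "info1", "version": "v1"}      # passed to TrainLineage above
eval_side = {"eval_version": "version2"}             # passed to EvalLineage above

record = {"model_lineage": {"user_defined": {**train_side, **eval_side}}}
assert record["model_lineage"]["user_defined"] == {
    "info": "info1",
    "version": "v1",
    "eval_version": "version2",
}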


tests/ut/lineagemgr/querier/event_data.py (+30, -23)

@@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum0', 'optimizer': 'ApplyMomentum0',
'learning_rate': 0.10000000149011612,
'learning_rate': 0.11,
'loss_function': '', 'loss_function': '',
'epoch': 1, 'epoch': 1,
'parallel_mode': 'stand_alone0', 'parallel_mode': 'stand_alone0',
@@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell0', 'network': 'TrainOneStepCell0',
'loss': 2.3025848865509033
'loss': 2.3025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '', 'train_dataset_path': '',
@@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum1', 'optimizer': 'ApplyMomentum1',
'learning_rate': 0.20000000298023224,
'learning_rate': 0.2100001,
'loss_function': 'loss_function1', 'loss_function': 'loss_function1',
'epoch': 1, 'epoch': 1,
'parallel_mode': 'stand_alone1', 'parallel_mode': 'stand_alone1',
@@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell1', 'network': 'TrainOneStepCell1',
'loss': 2.4025847911834717
'loss': 2.4025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset1', 'train_dataset_path': '/path/to/train_dataset1',
@@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum2', 'optimizer': 'ApplyMomentum2',
'learning_rate': 0.30000001192092896,
'learning_rate': 0.3100001,
'loss_function': 'loss_function2', 'loss_function': 'loss_function2',
'epoch': 2, 'epoch': 2,
'parallel_mode': 'stand_alone2', 'parallel_mode': 'stand_alone2',
@@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell2', 'network': 'TrainOneStepCell2',
'loss': 2.502584934234619
'loss': 2.5025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset2', 'train_dataset_path': '/path/to/train_dataset2',
@@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum3', 'optimizer': 'ApplyMomentum3',
'learning_rate': 0.4000000059604645,
'learning_rate': 0.4,
'loss_function': 'loss_function3', 'loss_function': 'loss_function3',
'epoch': 2, 'epoch': 2,
'parallel_mode': 'stand_alone3', 'parallel_mode': 'stand_alone3',
@@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell3', 'network': 'TrainOneStepCell3',
'loss': 2.6025848388671875
'loss': 2.6025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset3', 'train_dataset_path': '/path/to/train_dataset3',
@@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell4', 'network': 'TrainOneStepCell4',
'loss': 2.702584981918335
'loss': 2.7025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset4', 'train_dataset_path': '/path/to/train_dataset4',
@@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell5', 'network': 'TrainOneStepCell5',
'loss': 2.702584981918335
'loss': 2.7025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_size': 35 'train_dataset_size': 35
@@ -192,6 +192,13 @@ CUSTOMIZED__0 = {
'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'}, 'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'},
} }


CUSTOMIZED__1 = {
**CUSTOMIZED__0,
'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'},
'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'},
'user_defined/eval_version': {'label': 'user_defined/eval_version', 'required': False, 'type': 'str'}
}

CUSTOMIZED_0 = { CUSTOMIZED_0 = {
**CUSTOMIZED__0, **CUSTOMIZED__0,
'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'}, 'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'},
@@ -211,33 +218,33 @@ CUSTOMIZED_2 = {
} }


METRIC_1 = { METRIC_1 = {
'accuracy': 1.0000002,
'accuracy': 1.2000002,
'mae': 2.00000002, 'mae': 2.00000002,
'mse': 3.00000002 'mse': 3.00000002
} }


METRIC_2 = { METRIC_2 = {
'accuracy': 1.0000003,
'mae': 2.00000003,
'mse': 3.00000003
'accuracy': 1.3000003,
'mae': 2.30000003,
'mse': 3.30000003
} }


METRIC_3 = { METRIC_3 = {
'accuracy': 1.0000004,
'mae': 2.00000004,
'mse': 3.00000004
'accuracy': 1.4000004,
'mae': 2.40000004,
'mse': 3.40000004
} }


METRIC_4 = { METRIC_4 = {
'accuracy': 1.0000005,
'mae': 2.00000005,
'mse': 3.00000005
'accuracy': 1.5000005,
'mae': 2.50000005,
'mse': 3.50000005
} }


METRIC_5 = { METRIC_5 = {
'accuracy': 1.0000006,
'mae': 2.00000006,
'mse': 3.00000006
'accuracy': 1.7000006,
'mae': 2.60000006,
'mse': 3.60000006
} }


EVENT_EVAL_DICT_0 = { EVENT_EVAL_DICT_0 = {
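The METRIC_* fixtures are spread apart (accuracy 1.2, 1.3, 1.4, ...) instead of clustered around 1.0000002..1.0000006, presumably so that they remain distinguishable once the new comparison helper rounds floats to two decimals, and so the range filters in test_querier.py below still select a single record. A small standalone check of that arithmetic:

# With two-decimal rounding the old clustered values collapse to the same number,
# while the new spread-out values stay distinct (values copied from METRIC_3/METRIC_4).
old_values = [1.0000004, 1.0000005]
new_values = [1.4000004, 1.5000005]

assert len({round(v, 2) for v in old_values}) == 1   # both round to 1.0
assert len({round(v, 2) for v in new_values}) == 2   # 1.4 vs 1.5 remain distinct

# The updated bounds used in test_querier.py (gt 1.4000004, lt 1.6000006) still match
# exactly one accuracy value:
assert [v for v in new_values if 1.4000004 < v < 1.6000006] == [1.5000005]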


tests/ut/lineagemgr/querier/test_querier.py (+20, -20)

@@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo


from . import event_data from . import event_data
from ....utils.tools import assert_equal_lineages




def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict):
@@ -266,7 +267,6 @@ class TestQuerier(TestCase):
mock_file_handler = MagicMock() mock_file_handler = MagicMock()
mock_file_handler.size = 1 mock_file_handler.size = 1



args[2].return_value = [{'relative_path': './', 'update_time': 1}] args[2].return_value = [{'relative_path': './', 'update_time': 1}]
single_summary_path = '/path/to/summary0' single_summary_path = '/path/to/summary0'
lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs
@@ -286,13 +286,13 @@ class TestQuerier(TestCase):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_2(self): def test_get_summary_lineage_success_2(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_3(self): def test_get_summary_lineage_success_3(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -306,7 +306,7 @@ class TestQuerier(TestCase):
result = self.single_querier.get_summary_lineage( result = self.single_querier.get_summary_lineage(
filter_keys=['model', 'algorithm'] filter_keys=['model', 'algorithm']
) )
self.assertListEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_4(self): def test_get_summary_lineage_success_4(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -353,7 +353,7 @@ class TestQuerier(TestCase):
} }
] ]
result = self.multi_querier.get_summary_lineage() result = self.multi_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_5(self): def test_get_summary_lineage_success_5(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -361,7 +361,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary1' summary_dir='/path/to/summary1'
) )
self.assertListEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_success_6(self): def test_get_summary_lineage_success_6(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
@@ -380,7 +380,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary0', filter_keys=filter_keys summary_dir='/path/to/summary0', filter_keys=filter_keys
) )
self.assertListEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertListEqual)


def test_get_summary_lineage_fail(self): def test_get_summary_lineage_fail(self):
"""Test the function of get_summary_lineage with exception.""" """Test the function of get_summary_lineage with exception."""
@@ -423,7 +423,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_2(self): def test_filter_summary_lineage_success_2(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -448,7 +448,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_3(self): def test_filter_summary_lineage_success_3(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -465,7 +465,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_4(self): def test_filter_summary_lineage_success_4(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -483,7 +483,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage() result = self.multi_querier.filter_summary_lineage()
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_5(self): def test_filter_summary_lineage_success_5(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -498,7 +498,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_6(self): def test_filter_summary_lineage_success_6(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -520,7 +520,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_7(self): def test_filter_summary_lineage_success_7(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -542,14 +542,14 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_8(self): def test_filter_summary_lineage_success_8(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
condition = { condition = {
'metric/accuracy': { 'metric/accuracy': {
'lt': 1.0000006,
'gt': 1.0000004
'lt': 1.6000006,
'gt': 1.4000004
} }
} }
expected_result = { expected_result = {
@@ -558,7 +558,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_success_9(self): def test_filter_summary_lineage_success_9(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
@@ -572,14 +572,14 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_filter_summary_lineage_fail(self): def test_filter_summary_lineage_fail(self):
"""Test the function of filter_summary_lineage with exception.""" """Test the function of filter_summary_lineage with exception."""
condition = { condition = {
'xxx': { 'xxx': {
'lt': 1.0000006,
'gt': 1.0000004
'lt': 1.6000006,
'gt': 1.4000004
} }
} }
self.assertRaises( self.assertRaises(


tests/ut/lineagemgr/querier/test_query_model.py (+37, -23)

@@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj


from . import event_data from . import event_data
from .test_querier import create_filtration_result, create_lineage_info from .test_querier import create_filtration_result, create_lineage_info
from ....utils.tools import assert_equal_lineages




class TestLineageObj(TestCase): class TestLineageObj(TestCase):
@@ -53,49 +54,62 @@ class TestLineageObj(TestCase):
def test_property(self): def test_property(self):
"""Test the function of getting property.""" """Test the function of getting property."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj.algorithm
self.lineage_obj.algorithm,
self.assertDictEqual
) )
self.assertDictEqual(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj.model
self.lineage_obj.model,
self.assertDictEqual
) )
self.assertDictEqual(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj.train_dataset
self.lineage_obj.train_dataset,
self.assertDictEqual
) )
self.assertDictEqual(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj.hyper_parameters
self.lineage_obj.hyper_parameters,
self.assertDictEqual
) )
self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric)
self.assertDictEqual(
assert_equal_lineages(
event_data.METRIC_0,
self.lineage_obj.metric,
self.assertDictEqual
)
assert_equal_lineages(
event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
self.lineage_obj.valid_dataset
self.lineage_obj.valid_dataset,
self.assertDictEqual
) )


def test_property_eval_not_exist(self): def test_property_eval_not_exist(self):
"""Test the function of getting property with no evaluation event.""" """Test the function of getting property with no evaluation event."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj_no_eval.algorithm
self.lineage_obj_no_eval.algorithm,
self.assertDictEqual
) )
self.assertDictEqual(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj_no_eval.model
self.lineage_obj_no_eval.model,
self.assertDictEqual
) )
self.assertDictEqual(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj_no_eval.train_dataset
self.lineage_obj_no_eval.train_dataset,
self.assertDictEqual
) )
self.assertDictEqual(
assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj_no_eval.hyper_parameters
self.lineage_obj_no_eval.hyper_parameters,
self.assertDictEqual
) )
self.assertDictEqual({}, self.lineage_obj_no_eval.metric)
self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset)
assert_equal_lineages({}, self.lineage_obj_no_eval.metric, self.assertDictEqual)
assert_equal_lineages({}, self.lineage_obj_no_eval.valid_dataset, self.assertDictEqual)


def test_get_summary_info(self): def test_get_summary_info(self):
"""Test the function of get_summary_info.""" """Test the function of get_summary_info."""
@@ -106,7 +120,7 @@ class TestLineageObj(TestCase):
'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
} }
result = self.lineage_obj.get_summary_info(filter_keys) result = self.lineage_obj.get_summary_info(filter_keys)
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_to_model_lineage_dict(self): def test_to_model_lineage_dict(self):
"""Test the function of to_model_lineage_dict.""" """Test the function of to_model_lineage_dict."""
@@ -120,7 +134,7 @@ class TestLineageObj(TestCase):
expected_result['model_lineage']['dataset_mark'] = None expected_result['model_lineage']['dataset_mark'] = None
expected_result.pop('dataset_graph') expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict() result = self.lineage_obj.to_model_lineage_dict()
self.assertDictEqual(expected_result, result)
assert_equal_lineages(expected_result, result, self.assertDictEqual)


def test_to_dataset_lineage_dict(self): def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict.""" """Test the function of to_dataset_lineage_dict."""


tests/ut/profiler/analyser/test_analyser_aicore_detail.py (+26, -1)

@@ -267,7 +267,7 @@ class TestAicoreDetailAnalyser(TestCase):
result = self._analyser.query(condition) result = self._analyser.query(condition)
self.assertDictEqual(expect_result, result) self.assertDictEqual(expect_result, result)


def test_query_and_sort_by_op_type(self):
def test_query_and_sort_by_op_type_1(self):
"""Test the success of the querying and sorting function by operator type.""" """Test the success of the querying and sorting function by operator type."""
detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 5, 3, 4]) detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 5, 3, 4])
expect_result = { expect_result = {
@@ -289,6 +289,31 @@ class TestAicoreDetailAnalyser(TestCase):
) )
self.assertDictEqual(expect_result, result) self.assertDictEqual(expect_result, result)


def test_query_and_sort_by_op_type_2(self):
"""Test the success of the querying and sorting function by operator type."""
detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 3, 4, 8, 6])
expect_result = {
'col_name': AicoreDetailAnalyser.__col_names__[0:4],
'object': [item[0:4] for item in detail_infos]
}

filter_condition = {
'op_type': {},
'subgraph': {
'in': ['Default']
},
'is_display_detail': False,
'is_display_full_op_name': False
}
op_type_order = [
'MatMul', 'AtomicAddrClean', 'Cast', 'Conv2D', 'TransData'
]
result = self._analyser.query_and_sort_by_op_type(
filter_condition, op_type_order
)
print(result)
self.assertDictEqual(expect_result, result)

def test_col_names(self): def test_col_names(self):
"""Test the querying column names function.""" """Test the querying column names function."""
self.assertListEqual( self.assertListEqual(


tests/utils/tools.py (+74, -0)

@@ -81,3 +81,77 @@ def compare_result_with_file(result, expected_file_path):
with open(expected_file_path, 'r') as file: with open(expected_file_path, 'r') as file:
expected_results = json.load(file) expected_results = json.load(file)
assert result == expected_results assert result == expected_results


def deal_float_for_dict(res: dict, expected_res: dict, decimal_num=2):
"""
Deal float rounded to specified decimals in dict.

For example:
res:{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}
After:
res:{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}

Args:
res (dict): e.g.
{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res (dict):
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}
decimal_num (int): decimal rounded digits.

"""
for key in res:
value = res[key]
expected_value = expected_res[key]
if isinstance(value, dict):
deal_float_for_dict(value, expected_value)
elif isinstance(value, float):
res[key] = round(value, decimal_num)
expected_res[key] = round(expected_value, decimal_num)


def _deal_float_for_list(list1, list2, decimal_num):
"""Deal float for list1 and list2."""
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index], decimal_num)
index += 1


def assert_equal_lineages(lineages1, lineages2, assert_func, decimal_num=2):
"""Assert float almost equal for lineage data."""
if isinstance(lineages1, list) and isinstance(lineages2, list):
_deal_float_for_list(lineages1, lineages2, decimal_num)
elif lineages1.get('object') is not None and lineages2.get('object') is not None:
_deal_float_for_list(lineages1['object'], lineages2['object'], decimal_num)
else:
deal_float_for_dict(lineages1, lineages2, decimal_num)
assert_func(lineages1, lineages2)
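A short usage sketch for the helper above, assuming the repository's tests package is importable (the suite itself uses the relative form "from .....utils.tools import assert_equal_lineages"); the nested floats differ only by representation noise, so the assertion passes after rounding:

from unittest import TestCase

from tests.utils.tools import assert_equal_lineages  # assumed import path from the repo root

expected = {'object': [{'model_lineage': {'learning_rate': 0.12,
                                          'metric': {'accuracy': 2.78}}}]}
actual = {'object': [{'model_lineage': {'learning_rate': 0.11999999731779099,
                                        'metric': {'accuracy': 2.7800000000000002}}}]}

# Both sides are rounded to two decimals in place, then compared with assertDictEqual.
assert_equal_lineages(expected, actual, TestCase().assertDictEqual)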
