| @@ -1,3 +1,3 @@ | |||
| [submodule "third_party/securec"] | |||
| path = third_party/securec | |||
| url = https://gitee.com/openeuler/bounds_checking_function.git | |||
| url = https://gitee.com/openeuler/libboundscheck.git | |||
| @@ -1,5 +1,34 @@ | |||
| ## MindInsight | |||
| # Release 0.3.0-alpha | |||
| ## Major Features and Improvements | |||
| * Profiling | |||
| * Provide easy to use apis for profiling start/stop and profiling data analyse (on Ascend only). | |||
| * Provide operators performance display and analysis on MindInsight UI. | |||
| * Large scale network computation graph visualization. | |||
| * Optimize summary record implementation and improve its performance. | |||
| * Improve lineage usability | |||
| * Optimize lineage display and enrich tabular operation. | |||
| * Decouple lineage callback from `SummaryRecord`. | |||
| * Support scalar compare of multiple runs. | |||
| * Scripts conversion from other frameworks | |||
| * Support for converting PyTorch scripts within TorchVision to MindSpore scripts automatically. | |||
| ## Bugfixes | |||
| * Fix pb files loaded problem when files are modified at the same time ([!53](https://gitee.com/mindspore/mindinsight/pulls/53)). | |||
| * Fix load data thread stuck in `LineageCacheItemUpdater` ([!114](https://gitee.com/mindspore/mindinsight/pulls/114)). | |||
| * Fix samples from previous steps erased due to tags size too large problem ([!86](https://gitee.com/mindspore/mindinsight/pulls/86)). | |||
| * Fix image and histogram event package error ([!1143](https://gitee.com/mindspore/mindspore/pulls/1143)). | |||
| * Equally distribute histogram ignoring actual step number to avoid large white space ([!66](https://gitee.com/mindspore/mindinsight/pulls/66)). | |||
| ## Thanks to our Contributors | |||
| Thanks goes to these wonderful people: | |||
| Chao Chen, Congli Gao, Ye Huang, Weifeng Huang, Zhenzhong Kou, Hongzhang Li, Longfei Li, Yongxiong Liang, Pengting Luo, Yanming Miao, Gongchang Ou, Yongxiu Qu, Hui Pan, Luyu Qiu, Junyan Qin, Kai Wen, Weining Wang, Yue Wang, Zhuanke Wu, Yifan Xia, Weibiao Yu, Ximiao Yu, Ting Zhao, Jianfeng Zhu. | |||
| Contributions of any kind are welcome! | |||
| # Release 0.2.0-alpha | |||
| ## Major Features and Improvements | |||
| @@ -14,7 +43,7 @@ Now you can use [`HistogramSummary`](https://www.mindspore.cn/api/zh-CN/master/a | |||
| * Fix unsafe functions and duplication files and redundant codes ([!14](https://gitee.com/mindspore/mindinsight/pulls/14)). | |||
| * Fix sha256 checksum missing bug ([!24](https://gitee.com/mindspore/mindinsight/pulls/24)). | |||
| * Fix graph bug when node name is empty ([!34](https://gitee.com/mindspore/mindinsight/pulls/34)). | |||
| * Fix start/stop command exit-code incorrect ([!44](https://gitee.com/mindspore/mindinsight/pulls/44)). | |||
| * Fix start/stop command error code incorrect ([!44](https://gitee.com/mindspore/mindinsight/pulls/44)). | |||
| ## Thanks to our Contributors | |||
| Thanks goes to these wonderful people: | |||
| @@ -14,4 +14,4 @@ | |||
| # ============================================================================ | |||
| """Mindinsight version module.""" | |||
| VERSION = '0.2.0' | |||
| VERSION = '0.3.0' | |||
| @@ -257,12 +257,11 @@ class SummaryWatcher: | |||
| 'mtime': mtime, | |||
| } | |||
| if relative_path not in summary_dict: | |||
| summary_dict[relative_path] = { | |||
| 'ctime': ctime, | |||
| 'mtime': mtime, | |||
| 'profiler': profiler, | |||
| } | |||
| summary_dict[relative_path] = { | |||
| 'ctime': ctime, | |||
| 'mtime': mtime, | |||
| 'profiler': profiler, | |||
| } | |||
| def is_summary_directory(self, summary_base_dir, relative_path): | |||
| """ | |||
| @@ -16,8 +16,10 @@ | |||
| from urllib.parse import unquote | |||
| from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError | |||
| from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.utils.tools import if_nan_inf_to_none | |||
| from mindinsight.datavisual.common.exceptions import ScalarNotExistError | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.datavisual.common.validation import Validation | |||
| from mindinsight.datavisual.processors.base_processor import BaseProcessor | |||
| @@ -71,25 +73,44 @@ class ScalarsProcessor(BaseProcessor): | |||
| scalars = [] | |||
| for train_id in train_ids: | |||
| for tag in tags: | |||
| try: | |||
| tensors = self._data_manager.list_tensors(train_id, tag) | |||
| except ParamValueError: | |||
| continue | |||
| scalar = { | |||
| 'train_id': train_id, | |||
| 'tag': tag, | |||
| 'values': [], | |||
| } | |||
| for tensor in tensors: | |||
| scalar['values'].append({ | |||
| 'wall_time': tensor.wall_time, | |||
| 'step': tensor.step, | |||
| 'value': if_nan_inf_to_none('scalar_value', tensor.value), | |||
| }) | |||
| scalars.append(scalar) | |||
| scalars += self._get_train_scalars(train_id, tags) | |||
| return scalars | |||
| def _get_train_scalars(self, train_id, tags): | |||
| """ | |||
| Get scalar data for given train_id and tags. | |||
| Args: | |||
| train_id (str): Specify train job ID. | |||
| tags (list): Specify list of tags. | |||
| Returns: | |||
| list[dict], a list of dictionaries containing the `wall_time`, `step`, `value` for each scalar. | |||
| """ | |||
| scalars = [] | |||
| for tag in tags: | |||
| try: | |||
| tensors = self._data_manager.list_tensors(train_id, tag) | |||
| except ParamValueError: | |||
| continue | |||
| except TrainJobNotExistError: | |||
| logger.warning('Can not find the given train job in cache.') | |||
| return [] | |||
| scalar = { | |||
| 'train_id': train_id, | |||
| 'tag': tag, | |||
| 'values': [], | |||
| } | |||
| for tensor in tensors: | |||
| scalar['values'].append({ | |||
| 'wall_time': tensor.wall_time, | |||
| 'step': tensor.step, | |||
| 'value': if_nan_inf_to_none('scalar_value', tensor.value), | |||
| }) | |||
| scalars.append(scalar) | |||
| return scalars | |||
| @@ -14,6 +14,7 @@ | |||
| # ============================================================================ | |||
| """Train task manager.""" | |||
| from mindinsight.utils.exceptions import ParamTypeError | |||
| from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| @@ -141,9 +142,20 @@ class TrainTaskManager(BaseProcessor): | |||
| Returns: | |||
| dict, indicates train job ID and its current cache status. | |||
| Raises: | |||
| ParamTypeError, if the given train_ids parameter is not in valid type. | |||
| """ | |||
| if not isinstance(train_ids, list): | |||
| logger.error("train_ids must be list.") | |||
| raise ParamTypeError('train_ids', list) | |||
| cache_result = [] | |||
| for train_id in train_ids: | |||
| if not isinstance(train_id, str): | |||
| logger.error("train_id must be str.") | |||
| raise ParamTypeError('train_id', str) | |||
| try: | |||
| train_job = self._data_manager.get_train_job(train_id) | |||
| except exceptions.TrainJobNotExistError: | |||
| @@ -236,6 +236,7 @@ class EvalLineage(Callback): | |||
| """ | |||
| Collect lineage of an evaluation job. | |||
| Args: | |||
| summary_record (Union[SummaryRecord, str]): The `SummaryRecord` object which | |||
| is used to record the summary value(see mindspore.train.summary.SummaryRecord), | |||
| or a log dir(as a `str`) to be passed to `LineageSummary` to create | |||
| @@ -284,8 +285,8 @@ class EvalLineage(Callback): | |||
| self.lineage_summary = LineageSummary(self.lineage_log_dir) | |||
| self.user_defined_info = user_defined_info | |||
| if user_defined_info: | |||
| validate_user_defined_info(user_defined_info) | |||
| if self.user_defined_info: | |||
| validate_user_defined_info(self.user_defined_info) | |||
| except MindInsightException as err: | |||
| log.error(err) | |||
| @@ -410,7 +410,7 @@ def validate_path(summary_path): | |||
| def validate_user_defined_info(user_defined_info): | |||
| """ | |||
| Validate user defined info. | |||
| Validate user defined info, delete the item if its key is in lineage. | |||
| Args: | |||
| user_defined_info (dict): The user defined info. | |||
| @@ -437,10 +437,13 @@ def validate_user_defined_info(user_defined_info): | |||
| field_map = set(FIELD_MAPPING.keys()) | |||
| user_defined_keys = set(user_defined_info.keys()) | |||
| all_keys = field_map | user_defined_keys | |||
| insertion = list(field_map & user_defined_keys) | |||
| if len(field_map) + len(user_defined_keys) != len(all_keys): | |||
| raise LineageParamValueError("There are some keys have defined in lineage.") | |||
| if insertion: | |||
| for key in insertion: | |||
| user_defined_info.pop(key) | |||
| raise LineageParamValueError("There are some keys have defined in lineage. " | |||
| "Duplicated key(s): %s. " % insertion) | |||
| def validate_train_id(relative_path): | |||
| @@ -2,13 +2,6 @@ | |||
| MindConverter is a tool that converting PyTorch scripts to MindSpore scripts. With minial manual editing and the guidance from conversion reports, users may easily migrate their model from PyTorch framework to MindSpore. | |||
| ### System Requirements | |||
| * PyTorch v1.5.0 | |||
| * MindSpore v0.2.0 | |||
| ### Installation | |||
| This tool is part of MindInsight and accessible to users after installing MindInsight, no extra installation is needed. | |||
| @@ -24,8 +17,6 @@ mindconverter commandline usage: | |||
| mindconverter [-h] [--version] --in_file IN_FILE [--output OUTPUT] | |||
| [--report REPORT] | |||
| MindConverter CLI entry point (version: 0.2.0) | |||
| optional arguments: | |||
| -h, --help show this help message and exit | |||
| --version show program's version number and exit | |||
| @@ -36,13 +27,31 @@ optional arguments: | |||
| directorys | |||
| ``` | |||
| Usage example: | |||
| #### Use example: | |||
| We have a collection of PyTorch model scripts | |||
| ```buildoutcfg | |||
| ~$ ls | |||
| models | |||
| ~$ ls models | |||
| alexnet.py resnet.py vgg.py | |||
| ``` | |||
| Then we set the PYTHONPATH environment variable and convert alexnet.py | |||
| ```buildoutcfg | |||
| ~$ export PYTHONPATH=~/models | |||
| ~$ mindconverter --in_file models/alexnet.py | |||
| ``` | |||
| Then we will see a conversion report and the output MindSpore script | |||
| ```buildoutcfg | |||
| export PYTHONPATH=~/my_pt_proj/models | |||
| mindconverter --in_file lenet.py | |||
| ~$ ls | |||
| alexnet_report.txt models output | |||
| ~$ ls output | |||
| alexent.py | |||
| ``` | |||
| Since the conversion is not 100% flawless, we encourage users to checkout the reports when fixing issues of the converted scripts. | |||
| Since the conversion is not 100% flawless, we encourage users to checkout the report when fixing issues of the converted script. | |||
| ### Unsupported Situation #1 | |||
| @@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED | |||
| from mindinsight.mindconverter.common.log import logger | |||
| from mindinsight.mindconverter.forward_call import ForwardCall | |||
| LINE_NO_INDEX_DIFF = 1 | |||
| class Converter: | |||
| """Convert class""" | |||
| @@ -197,6 +199,7 @@ class Converter: | |||
| raise ValueError('"(" not found, {} should work with "("'.format(call_name)) | |||
| right = self.find_right_parentheses(code, left) | |||
| end = right | |||
| expr = code[start:end + 1] | |||
| args_str = code[left:right + 1] | |||
| @@ -336,6 +339,96 @@ class Converter: | |||
| mapping.update(convert_fun(*args)) | |||
| return mapping | |||
| @staticmethod | |||
| def get_code_start_line_num(source_lines): | |||
| """ | |||
| Get the start code line number exclude comments. | |||
| Args: | |||
| source_lines (list[str]): Split results of original code. | |||
| Returns: | |||
| int, the start line number. | |||
| """ | |||
| stack = [] | |||
| index = 0 | |||
| for i, line in enumerate(source_lines): | |||
| if line.strip().startswith('#'): | |||
| continue | |||
| if line.strip().startswith('"""'): | |||
| if not line.endswith('"""\n'): | |||
| stack.append('"""') | |||
| continue | |||
| if line.strip().startswith("'''"): | |||
| if not line.endswith("'''\n"): | |||
| stack.append("'''") | |||
| continue | |||
| if line.endswith('"""\n') or line.endswith("'''\n"): | |||
| stack.pop() | |||
| continue | |||
| if line.strip() != '' and not stack: | |||
| index = i | |||
| break | |||
| return index | |||
| def update_code_and_convert_info(self, code, mapping): | |||
| """ | |||
| Replace code according to mapping, and update convert info. | |||
| Args: | |||
| code (str): The code to replace. | |||
| mapping (dict): Mapping for original code and the replaced code. | |||
| Returns: | |||
| str, the replaced code. | |||
| """ | |||
| for key, value in mapping.items(): | |||
| code = code.replace(key, value) | |||
| source_lines = code.splitlines(keepends=True) | |||
| start_line_number = self.get_code_start_line_num(source_lines) | |||
| add_import_infos = ['import mindspore\n', | |||
| 'import mindspore.nn as nn\n', | |||
| 'import mindspore.ops.operations as P\n'] | |||
| for i, add_import_info in enumerate(add_import_infos): | |||
| source_lines.insert(start_line_number + i, add_import_info) | |||
| self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip()) | |||
| insert_count = len(add_import_infos) | |||
| line_diff = insert_count - LINE_NO_INDEX_DIFF | |||
| for i in range(start_line_number + insert_count, len(source_lines)): | |||
| line = source_lines[i] | |||
| if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'): | |||
| new_line = '# ' + line | |||
| source_lines[i] = new_line | |||
| self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip()) | |||
| if line.strip().startswith('class') and '(nn.Module)' in line: | |||
| new_line = line.replace('nn.Module', 'nn.Cell') | |||
| source_lines[i] = new_line | |||
| self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff) | |||
| if line.strip().startswith('def forward('): | |||
| new_line = line.replace('forward', 'construct') | |||
| source_lines[i] = new_line | |||
| self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff) | |||
| if 'nn.Linear' in line: | |||
| new_line = line.replace('nn.Linear', 'nn.Dense') | |||
| source_lines[i] = new_line | |||
| self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff) | |||
| if '(nn.Sequential)' in line: | |||
| new_line = line.replace('nn.Sequential', 'nn.SequentialCell') | |||
| source_lines[i] = new_line | |||
| self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff) | |||
| if 'nn.init.' in line: | |||
| new_line = line.replace('nn.init', 'pass # nn.init') | |||
| source_lines[i] = new_line | |||
| self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init') | |||
| code = ''.join(source_lines) | |||
| return code | |||
| def convert(self, import_name, output_dir, report_dir): | |||
| """ | |||
| Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. | |||
| @@ -346,10 +439,10 @@ class Converter: | |||
| report_dir (str): The path to save report file. | |||
| """ | |||
| logger.info("Start converting %s", import_name) | |||
| self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name) | |||
| start_info = '[Start Convert]\n' | |||
| module_info = 'The module is {}.\n'.format(import_name) | |||
| import_mod = importlib.import_module(import_name) | |||
| srcfile = inspect.getsourcefile(import_mod) | |||
| logger.info("Script file is %s", srcfile) | |||
| @@ -358,40 +451,14 @@ class Converter: | |||
| # replace python function under nn.Module | |||
| mapping = self.get_mapping(import_mod, forward_list) | |||
| code = inspect.getsource(import_mod) | |||
| for key, value in mapping.items(): | |||
| code = code.replace(key, value) | |||
| code = 'import mindspore.ops.operations as P\n' + code | |||
| code = 'import mindspore.nn as nn\n' + code | |||
| code = 'import mindspore\n' + code | |||
| self.convert_info += '||[Import Add] Add follow import sentences:\n' | |||
| self.convert_info += 'import mindspore.ops.operations as P\n' | |||
| self.convert_info += 'import mindspore.nn as nn\n' | |||
| self.convert_info += 'import mindspore\n\n' | |||
| code = code.replace('import torch', '# import torch') | |||
| code = code.replace('from torch', '# from torch') | |||
| code = code.replace('(nn.Module):', '(nn.Cell):') | |||
| code = code.replace('forward(', 'construct(') | |||
| code = code.replace('nn.Linear', 'nn.Dense') | |||
| code = code.replace('(nn.Sequential)', '(nn.SequentialCell)') | |||
| code = code.replace('nn.init.', 'pass # nn.init.') | |||
| self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n' | |||
| self.convert_info += 'import sentence on torch as follows are annotated:\n' | |||
| self.convert_info += 'import torch\n' | |||
| self.convert_info += 'from torch ...\n' | |||
| self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n' | |||
| self.convert_info += '[nn.Module] is converted to [nn.Cell]\n' | |||
| self.convert_info += '[forward] is converted to [construct]\n' | |||
| self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n' | |||
| self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n' | |||
| self.convert_info += '[nn.init] is not converted and annotated\n' | |||
| self.convert_info += '[Convert over]' | |||
| code = self.update_code_and_convert_info(code, mapping) | |||
| convert_info_split = self.convert_info.splitlines(keepends=True) | |||
| convert_info_split = sorted(convert_info_split) | |||
| convert_info_split.insert(0, start_info) | |||
| convert_info_split.insert(1, module_info) | |||
| convert_info_split.append('[Convert Over]') | |||
| self.convert_info = ''.join(convert_info_split) | |||
| dest_file = os.path.join(output_dir, os.path.basename(srcfile)) | |||
| with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: | |||
| @@ -428,7 +495,6 @@ def _path_split(file): | |||
| Returns: | |||
| list[str], list of file tail | |||
| """ | |||
| file_dir, name = os.path.split(file) | |||
| if file_dir: | |||
| @@ -456,6 +522,6 @@ def main(files_config): | |||
| module_name = '.'.join(in_file_split) | |||
| convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) | |||
| in_module = files_config['in_module'] | |||
| in_module = files_config.get('in_module') | |||
| if in_module: | |||
| convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) | |||
| @@ -12,16 +12,18 @@ The Profiler enables users to: | |||
| To enable profiling on MindSpore, the MindInsight Profiler apis should be added to the script: | |||
| 1. Import MindInsight Profiler | |||
| ``` | |||
| from mindinsight.profiler import Profiler | |||
| 2. Initialize the Profiler before training | |||
| ``` | |||
| 2. Initialize the Profiler after set context, and before the network initialization. | |||
| Example: | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=int(os.environ["DEVICE_ID"])) | |||
| profiler = Profiler(output_path="./data", is_detail=True, is_show_op_path=False, subgraph='All') | |||
| Parameters including: | |||
| net = Net() | |||
| Parameters of Profiler including: | |||
| subgraph (str): Defines which subgraph to monitor and analyse, can be 'all', 'Default', 'Gradients'. | |||
| is_detail (bool): Whether to show profiling data for op_instance level, only show optype level if False. | |||
| @@ -31,9 +33,9 @@ To enable profiling on MindSpore, the MindInsight Profiler apis should be added | |||
| will deal with all op if null. | |||
| optypes_not_deal (list): Op type names, the data of which optype will not be collected and analysed. | |||
| 3. Call Profiler.analyse() at the end of the program | |||
| 3. Call ```Profiler.analyse()``` at the end of the program | |||
| Profiler.analyse() will collect profiling data and generate the analysis results. | |||
| ```Profiler.analyse()``` will collect profiling data and generate the analysis results. | |||
| After training, we can open MindInsight UI to analyse the performance. | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Profiler Module Introduction | |||
| Profiler Module Introduction. | |||
| This module provides Python APIs to enable the profiling of MindSpore neural networks. | |||
| Users can import the mindinsight.profiler.Profiler, initialize the Profiler object to start profiling, | |||
| @@ -124,6 +124,8 @@ class AicoreDetailAnalyser(BaseAnalyser): | |||
| result = [] | |||
| for op_type in op_type_order: | |||
| detail_infos = type_detail_cache.get(op_type) | |||
| if detail_infos is None: | |||
| continue | |||
| detail_infos.sort(key=lambda item: item[2], reverse=True) | |||
| result.extend(detail_infos) | |||
| @@ -15,6 +15,7 @@ | |||
| """ | |||
| The parser for AI CPU preprocess data. | |||
| """ | |||
| import os | |||
| from tabulate import tabulate | |||
| @@ -50,7 +51,7 @@ class DataPreProcessParser: | |||
| def execute(self): | |||
| """Execute the parser, get result data, and write it to the output file.""" | |||
| if self._source_file_name is None: | |||
| if not os.path.exists(self._source_file_name): | |||
| logger.info("Did not find the aicpu profiling source file") | |||
| return | |||
| @@ -37,19 +37,21 @@ class Profiler: | |||
| """ | |||
| Performance profiling API. | |||
| Enable MindSpore users to profile the neural network. | |||
| Enable MindSpore users to profile the performance of neural network. | |||
| Args: | |||
| subgraph (str): Defines which subgraph to monitor and analyse, can be 'all', 'Default', 'Gradients'. | |||
| subgraph (str): Define which subgraph to monitor and analyse, can be 'all', 'Default', 'Gradients'. | |||
| is_detail (bool): Whether to show profiling data for op_instance level, only show optype level if False. | |||
| is_show_op_path (bool): Whether to save the full path for each op instance. | |||
| output_path (str): Output data path. | |||
| optypes_to_deal (list): Op type names, the data of which optype should be collected and analysed, | |||
| will deal with all op if null. | |||
| optypes_not_deal (list): Op type names, the data of which optype will not be collected and analysed. | |||
| optypes_to_deal (list[str]): Op type names, the data of which optype should be collected and analysed, | |||
| will deal with all op if null. | |||
| optypes_not_deal (list[str]): Op type names, the data of which optype will not be collected and analysed. | |||
| Examples: | |||
| >>> from mindinsight.profiler import Profiler | |||
| >>> context.set_context(mode=context.GRAPH_MODE, device_target=“Ascend”, | |||
| >>> device_id=int(os.environ["DEVICE_ID"])) | |||
| >>> profiler = Profiler(subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data') | |||
| >>> model = Model(train_network) | |||
| >>> dataset = get_dataset() | |||
| @@ -64,10 +66,30 @@ class Profiler: | |||
| def __init__(self, subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data', | |||
| optypes_to_deal='', optypes_not_deal='Variable', job_id=""): | |||
| dev_id = os.getenv('DEVICE_ID') | |||
| # get device_id and device_target | |||
| device_target = "" | |||
| try: | |||
| import mindspore.context as context | |||
| dev_id = str(context.get_context("device_id")) | |||
| device_target = context.get_context("device_target") | |||
| except ImportError: | |||
| logger.error("Profiling: fail to import context from mindspore.") | |||
| except ValueError as err: | |||
| logger.error("Profiling: fail to get context %s", err.message) | |||
| if not dev_id: | |||
| dev_id = str(os.getenv('DEVICE_ID')) | |||
| if not dev_id: | |||
| dev_id = "0" | |||
| logger.error("Fail to get DEVICE_ID, use 0 instead.") | |||
| if device_target and device_target != "Davinci" \ | |||
| and device_target != "Ascend": | |||
| msg = ("Profiling: unsupport backend: %s" \ | |||
| % device_target) | |||
| raise RuntimeError(msg) | |||
| self._dev_id = dev_id | |||
| self._container_path = os.path.join(self._base_profiling_container_path, dev_id) | |||
| data_path = os.path.join(self._container_path, "data") | |||
| @@ -88,7 +110,7 @@ class Profiler: | |||
| except ImportError: | |||
| logger.error("Profiling: fail to import context from mindspore.") | |||
| except ValueError as err: | |||
| logger.err("Profiling: fail to set context", err.message) | |||
| logger.error("Profiling: fail to set context, %s", err.message) | |||
| os.environ['AICPU_PROFILING_MODE'] = 'true' | |||
| os.environ['PROFILING_DIR'] = str(self._container_path) | |||
| @@ -107,6 +129,8 @@ class Profiler: | |||
| Examples: | |||
| >>> from mindinsight.profiler import Profiler | |||
| >>> context.set_context(mode=context.GRAPH_MODE, device_target=“Ascend”, | |||
| >>> device_id=int(os.environ["DEVICE_ID"])) | |||
| >>> profiler = Profiler(subgraph='all', is_detail=True, is_show_op_path=False, output_path='./data') | |||
| >>> model = Model(train_network) | |||
| >>> dataset = get_dataset() | |||
| @@ -18,7 +18,7 @@ module.exports = { | |||
| [ | |||
| '@vue/app', | |||
| { | |||
| polyfills: ['es6.promise', 'es6.symbol'], | |||
| polyfills: ['es.promise', 'es.symbol'], | |||
| }, | |||
| ], | |||
| ], | |||
| @@ -24,7 +24,6 @@ | |||
| "@intlify/vue-i18n-loader": "0.6.1", | |||
| "@vue/cli-service": "4.1.0", | |||
| "@vue/cli-plugin-babel": "4.1.0", | |||
| "babel-core": "6.26.0", | |||
| "babel-eslint": "10.0.3", | |||
| "eslint": "6.6.0", | |||
| "eslint-config-google": "0.13.0", | |||
| @@ -162,6 +162,7 @@ export default { | |||
| listSelectAll() { | |||
| this.operateSelectAll = !this.operateSelectAll; | |||
| this.multiSelectedItemNames = {}; | |||
| this.selectedNumber = 0; | |||
| // Setting the status of list items | |||
| if (this.operateSelectAll) { | |||
| if (this.isLimit) { | |||
| @@ -171,7 +172,7 @@ export default { | |||
| break; | |||
| } | |||
| const listItem = this.checkListArr[i]; | |||
| if (listItem.show) { | |||
| if ((listItem.show && !listItem.checked) || listItem.checked) { | |||
| listItem.checked = true; | |||
| this.multiSelectedItemNames[listItem.label] = true; | |||
| this.selectedNumber++; | |||
| @@ -216,14 +217,17 @@ export default { | |||
| } | |||
| this.valiableSearchInput = this.searchInput; | |||
| this.multiSelectedItemNames = {}; | |||
| this.selectedNumber = 0; | |||
| let itemSelectAll = true; | |||
| // Filter the tags that do not meet the conditions in the operation bar and hide them | |||
| this.checkListArr.forEach((listItem) => { | |||
| if (listItem.checked) { | |||
| this.multiSelectedItemNames[listItem.label] = true; | |||
| this.selectedNumber++; | |||
| } | |||
| if (reg.test(listItem.label)) { | |||
| listItem.show = true; | |||
| if (listItem.checked) { | |||
| this.multiSelectedItemNames[listItem.label] = true; | |||
| } else { | |||
| if (!listItem.checked) { | |||
| itemSelectAll = false; | |||
| } | |||
| } else { | |||
| @@ -232,7 +236,7 @@ export default { | |||
| }); | |||
| // Update the selected status of the Select All button | |||
| if (this.isLimit && !itemSelectAll) { | |||
| itemSelectAll = this.selectedNumber >= this.limitNum; | |||
| itemSelectAll = this.selectedNumber >= this.limitNum || this.selectedNumber >= this.checkListArr.length; | |||
| } | |||
| this.operateSelectAll = itemSelectAll; | |||
| this.$emit('selectedChange', this.multiSelectedItemNames); | |||
| @@ -271,7 +275,7 @@ export default { | |||
| } | |||
| }); | |||
| if (this.isLimit && !itemSelectAll) { | |||
| itemSelectAll = this.selectedNumber >= this.limitNum; | |||
| itemSelectAll = this.selectedNumber >= this.limitNum || this.selectedNumber >= this.checkListArr.length; | |||
| } | |||
| this.operateSelectAll = itemSelectAll; | |||
| // Return a dictionary containing selected items. | |||
| @@ -309,23 +313,24 @@ export default { | |||
| const loopCount = this.checkListArr.length; | |||
| for (let i = 0; i < loopCount; i++) { | |||
| const listItem = this.checkListArr[i]; | |||
| if (reg.test(listItem.label)) { | |||
| listItem.show = true; | |||
| if (listItem.checked) { | |||
| if (this.selectedNumber >= this.limitNum) { | |||
| listItem.checked = false; | |||
| itemSelectAll = false; | |||
| } else if (listItem.checked) { | |||
| } else { | |||
| this.multiSelectedItemNames[listItem.label] = true; | |||
| this.selectedNumber++; | |||
| } else { | |||
| itemSelectAll = false; | |||
| } | |||
| } | |||
| if (reg.test(listItem.label)) { | |||
| listItem.show = true; | |||
| } else { | |||
| listItem.show = false; | |||
| } | |||
| } | |||
| if (!itemSelectAll && this.selectedNumber >= this.limitNum) { | |||
| if (this.selectedNumber >= this.limitNum || this.selectedNumber >= this.checkListArr.length) { | |||
| itemSelectAll = true; | |||
| } else { | |||
| itemSelectAll = false; | |||
| } | |||
| } else { | |||
| this.checkListArr.forEach((listItem) => { | |||
| @@ -78,7 +78,7 @@ | |||
| "userDefinedLabel": "User Defined", | |||
| "hyperLabel": "Hyper", | |||
| "otherLabel": "其他", | |||
| "remarkTips": "提示:终止服务后备注及tag信息将被清除" | |||
| "remarkTips": "提示:终止服务后备注及tag将被清除" | |||
| }, | |||
| "dataTraceback": { | |||
| "details": "详情", | |||
| @@ -32,6 +32,7 @@ export default new Vuex.Store({ | |||
| : 3, | |||
| // multiSelevtGroup component count | |||
| multiSelectedGroupCount: 0, | |||
| tableId: 0, | |||
| }, | |||
| mutations: { | |||
| // set cancelTokenArr | |||
| @@ -72,6 +73,9 @@ export default new Vuex.Store({ | |||
| multiSelectedGroupComponentNum(state) { | |||
| state.multiSelectedGroupCount++; | |||
| }, | |||
| increaseTableId(state) { | |||
| state.tableId++; | |||
| }, | |||
| }, | |||
| actions: {}, | |||
| }); | |||
| @@ -24,11 +24,12 @@ limitations under the License. | |||
| type="primary" | |||
| size="mini" | |||
| plain | |||
| v-show="(summaryDirList&&!summaryDirList.length)||(totalSeries&&totalSeries.length)"> | |||
| v-show="(summaryDirList && !summaryDirList.length)||(totalSeries && totalSeries.length)"> | |||
| {{ $t('modelTraceback.showAllData') }} | |||
| </el-button> | |||
| <div class="select-container" | |||
| v-show="totalSeries&&totalSeries.length&&(!summaryDirList||(summaryDirList&&summaryDirList.length))"> | |||
| v-show="totalSeries && totalSeries.length && | |||
| (!summaryDirList || (summaryDirList && summaryDirList.length))"> | |||
| <div class="display-column"> | |||
| {{$t('modelTraceback.displayColumn')}} | |||
| </div> | |||
| @@ -50,17 +51,17 @@ limitations under the License. | |||
| <button type="text" | |||
| @click="allSelect" | |||
| class="select-all-button" | |||
| :class="[selectCheckAll?'checked-color':'button-text', | |||
| basearr.length>checkOptions.length ? 'btn-disabled' : '']" | |||
| :disabled="basearr.length>checkOptions.length"> | |||
| :class="[selectCheckAll ? 'checked-color' : 'button-text', | |||
| basearr.length > checkOptions.length ? 'btn-disabled' : '']" | |||
| :disabled="basearr.length > checkOptions.length"> | |||
| {{ $t('public.selectAll')}} | |||
| </button> | |||
| <button type="text" | |||
| @click="deselectAll" | |||
| class="deselect-all-button" | |||
| :class="[!selectCheckAll?'checked-color':'button-text', | |||
| basearr.length>checkOptions.length ? 'btn-disabled' : '']" | |||
| :disabled="basearr.length>checkOptions.length"> | |||
| :class="[!selectCheckAll ? 'checked-color' : 'button-text', | |||
| basearr.length > checkOptions.length ? 'btn-disabled' : '']" | |||
| :disabled="basearr.length > checkOptions.length"> | |||
| {{ $t('public.deselectAll')}} | |||
| </button> | |||
| </div> | |||
| @@ -69,7 +70,7 @@ limitations under the License. | |||
| :label="item.label" | |||
| :value="item.value" | |||
| :disabled="item.disabled" | |||
| :title="item.disabled?$t('modelTraceback.mustExist'):''"> | |||
| :title="item.disabled ? $t('modelTraceback.mustExist') : ''"> | |||
| </el-option> | |||
| </el-select> | |||
| </div> | |||
| @@ -79,7 +80,8 @@ limitations under the License. | |||
| <div id="data-echart" | |||
| v-show="showEchartPic && !echartNoData"></div> | |||
| <div class="echart-nodata-container" | |||
| v-show="!showEchartPic && showTable"></div> | |||
| v-show="!showEchartPic && showTable && !(summaryDirList && !summaryDirList.length)"> | |||
| </div> | |||
| <div class="btns-container" | |||
| v-show="!echartNoData && showTable"> | |||
| <el-button type="primary" | |||
| @@ -103,7 +105,7 @@ limitations under the License. | |||
| <el-table ref="table" | |||
| :data="table.data" | |||
| tooltip-effect="light" | |||
| height="calc(100% - 54px)" | |||
| height="calc(100% - 40px)" | |||
| row-key="summary_dir" | |||
| @selection-change="handleSelectionChange" | |||
| @sort-change="tableSortChange"> | |||
| @@ -116,8 +118,8 @@ limitations under the License. | |||
| :key="key" | |||
| :prop="key" | |||
| :label="table.columnOptions[key].label" | |||
| :sortable="sortArray.includes(table.columnOptions[key].label)?'custom':false" | |||
| :fixed="table.columnOptions[key].label===text?true:false" | |||
| :sortable="sortArray.includes(table.columnOptions[key].label) ? 'custom' : false" | |||
| :fixed="table.columnOptions[key].label === text?true:false" | |||
| min-width="200" | |||
| show-overflow-tooltip> | |||
| <template slot="header" | |||
| @@ -151,7 +153,7 @@ limitations under the License. | |||
| </el-table-column> | |||
| <!-- remark column --> | |||
| <el-table-column fixed="right" | |||
| width="310"> | |||
| width="260"> | |||
| <template slot="header"> | |||
| <div> | |||
| <div class="label-text">{{$t('public.remark')}}</div> | |||
| @@ -208,10 +210,10 @@ limitations under the License. | |||
| <div> | |||
| <div class="icon-image-container"> | |||
| <div class="icon-image" | |||
| :class="[item.number===scope.row.tag && scope.row.showIcon?'icon-border':'']" | |||
| :class="[item.number === scope.row.tag && scope.row.showIcon ? 'icon-border' : '']" | |||
| v-for="item in imageList" | |||
| :key="item.number" | |||
| @click="iconValueChange(scope.row,item.number,$event)"> | |||
| @click="iconValueChange(scope.row, item.number, $event)"> | |||
| <img :src="item.iconAdd"> | |||
| </div> | |||
| </div> | |||
| @@ -243,33 +245,34 @@ limitations under the License. | |||
| </template> | |||
| </el-table-column> | |||
| </el-table> | |||
| <div> | |||
| <div class="hide-count" | |||
| v-show="recordsNumber-showNumber"> | |||
| {{ $t('modelTraceback.totalHide').replace(`{n}`,(recordsNumber-showNumber))}} | |||
| </div> | |||
| <div class="pagination-container"> | |||
| <el-pagination @current-change="handleCurrentChange" | |||
| :current-page="pagination.currentPage" | |||
| :page-size="pagination.pageSize" | |||
| :layout="pagination.layout" | |||
| :total="pagination.total"> | |||
| </el-pagination> | |||
| <div class="hide-count" | |||
| v-show="recordsNumber-showNumber"> | |||
| {{ $t('modelTraceback.totalHide').replace(`{n}`, (recordsNumber-showNumber))}} | |||
| </div> | |||
| <div class="clear"></div> | |||
| </div> | |||
| </div> | |||
| <div v-show="((!lineagedata.serData || !lineagedata.serData.length) && initOver) | |||
| ||(echartNoData&&(lineagedata.serData&&!!lineagedata.serData.length))" | |||
| ||(echartNoData && (lineagedata.serData && !!lineagedata.serData.length))" | |||
| class="no-data-page"> | |||
| <div class="no-data-img"> | |||
| <img :src="require('@/assets/images/nodata.png')" | |||
| alt="" /> | |||
| <p class="no-data-text" | |||
| v-show="!summaryDirList||(summaryDirList&&summaryDirList.length)&&!lineagedata.serData"> | |||
| v-show="!summaryDirList || (summaryDirList && summaryDirList.length) && !lineagedata.serData"> | |||
| {{ $t('public.noData') }} | |||
| </p> | |||
| <div v-show="echartNoData&&(lineagedata.serData&&!!lineagedata.serData.length)"> | |||
| <div v-show="echartNoData && (lineagedata.serData && !!lineagedata.serData.length)"> | |||
| <p class="no-data-text">{{ $t('dataTraceback.noDataFound') }}</p> | |||
| </div> | |||
| <div v-show="summaryDirList&&!summaryDirList.length"> | |||
| <div v-show="summaryDirList && !summaryDirList.length"> | |||
| <p class="no-data-text">{{ $t('dataTraceback.noDataFound') }}</p> | |||
| <p class="no-data-text"> | |||
| {{ $t('dataTraceback.click') }} | |||
| @@ -494,7 +497,7 @@ export default { | |||
| obj.iconAdd = require('@/assets/images/icon' + obj.number + '.svg'); | |||
| this.imageList.push(obj); | |||
| } | |||
| document.title = this.$t('summaryManage.dataTraceback') + '-MindInsight'; | |||
| document.title = `${this.$t('summaryManage.dataTraceback')}-MindInsight`; | |||
| document.addEventListener('click', this.blurFloat, true); | |||
| this.$nextTick(() => { | |||
| this.init(); | |||
| @@ -527,8 +530,8 @@ export default { | |||
| return; | |||
| } | |||
| row.showIcon = true; | |||
| const e = window.event; | |||
| document.getElementById('icon-dialog').style.top = e.clientY + 'px'; | |||
| document.getElementById('icon-dialog').style.top = | |||
| window.event.clientY + 'px'; | |||
| }, | |||
| iconValueChange(row, num, event) { | |||
| @@ -575,6 +578,13 @@ export default { | |||
| */ | |||
| clearIcon(row) { | |||
| const classWrap = event.path.find((item) => { | |||
| return item.className === 'icon-dialog'; | |||
| }); | |||
| const classArr = classWrap.querySelectorAll('.icon-border'); | |||
| classArr.forEach((item) => { | |||
| item.classList.remove('icon-border'); | |||
| }); | |||
| row.showIcon = false; | |||
| this.iconValue = 0; | |||
| row.tag = 0; | |||
| @@ -848,7 +858,7 @@ export default { | |||
| } | |||
| this.initChart(); | |||
| const list = []; | |||
| this.checkOptions.forEach((item) => { | |||
| this.basearr.forEach((item) => { | |||
| this.selectArrayValue.forEach((i) => { | |||
| if (i === item.value) { | |||
| list.push(i); | |||
| @@ -917,7 +927,7 @@ export default { | |||
| }); | |||
| } | |||
| const list = []; | |||
| this.checkOptions.forEach((item) => { | |||
| this.basearr.forEach((item) => { | |||
| this.selectArrayValue.forEach((i) => { | |||
| if (i === item.value) { | |||
| const obj = {}; | |||
| @@ -1061,6 +1071,9 @@ export default { | |||
| this.showTable = false; | |||
| this.echartNoData = true; | |||
| } else { | |||
| const echartLength = this.echart.brushData.length; | |||
| this.recordsNumber = echartLength; | |||
| this.showNumber = echartLength; | |||
| this.echart.showData = this.echart.brushData; | |||
| this.initChart(); | |||
| this.pagination.currentPage = 1; | |||
| @@ -1431,6 +1444,7 @@ export default { | |||
| this.initOver = false; | |||
| this.echartNoData = false; | |||
| this.showEchartPic = true; | |||
| this.selectCheckAll = true; | |||
| // checkOptions initializate to an empty array | |||
| this.checkOptions = []; | |||
| this.selectArrayValue = []; | |||
| @@ -1733,7 +1747,9 @@ export default { | |||
| const item = {}; | |||
| item.key = k; | |||
| item.value = dataObj[key][k]; | |||
| item.id = (index + 1) * 10 + 1 + j; | |||
| item.id = | |||
| `${new Date().getTime()}` + `${this.$store.state.tableId}`; | |||
| this.$store.commit('increaseTableId'); | |||
| tempData.children.push(item); | |||
| }); | |||
| } | |||
| @@ -1775,14 +1791,15 @@ export default { | |||
| <style lang="scss"> | |||
| .label-text { | |||
| line-height: 20px !important; | |||
| vertical-align: bottom; | |||
| padding-top: 20px; | |||
| display: block !important; | |||
| } | |||
| .remark-tip { | |||
| line-height: 14px !important; | |||
| line-height: 20px !important; | |||
| font-size: 12px; | |||
| white-space: pre-wrap !important; | |||
| vertical-align: bottom; | |||
| color: gray; | |||
| display: block !important; | |||
| } | |||
| .el-color-dropdown__main-wrapper, | |||
| .el-color-dropdown__value, | |||
| @@ -1841,6 +1858,13 @@ export default { | |||
| height: 100%; | |||
| overflow-y: auto; | |||
| position: relative; | |||
| .el-table th.is-leaf { | |||
| background: #f5f7fa; | |||
| } | |||
| .el-table td, | |||
| .el-table th.is-leaf { | |||
| border: 1px solid #ebeef5; | |||
| } | |||
| .inline-block-set { | |||
| display: inline-block; | |||
| } | |||
| @@ -1878,7 +1902,7 @@ export default { | |||
| .no-data-page { | |||
| width: 100%; | |||
| height: 100%; | |||
| padding-top: 224px; | |||
| padding-top: 200px; | |||
| } | |||
| .no-data-img { | |||
| background: #fff; | |||
| @@ -1944,6 +1968,7 @@ export default { | |||
| .data-checkbox-area { | |||
| position: relative; | |||
| margin: 24px 32px 12px; | |||
| height: 46px; | |||
| .reset-btn { | |||
| position: absolute; | |||
| right: 0px; | |||
| @@ -1951,12 +1976,12 @@ export default { | |||
| } | |||
| } | |||
| #data-echart { | |||
| height: 34%; | |||
| height: 32%; | |||
| width: 100%; | |||
| padding: 0 12px; | |||
| } | |||
| .echart-nodata-container { | |||
| height: 34%; | |||
| height: 32%; | |||
| width: 100%; | |||
| } | |||
| .btn-container-margin { | |||
| @@ -1975,8 +2000,8 @@ export default { | |||
| .table-container { | |||
| background-color: white; | |||
| height: calc(60% - 90px); | |||
| margin: 6px 32px 0; | |||
| height: calc(68% - 130px); | |||
| padding: 6px 32px; | |||
| position: relative; | |||
| .custom-label { | |||
| max-width: calc(100% - 25px); | |||
| @@ -1997,24 +2022,33 @@ export default { | |||
| .click-span { | |||
| cursor: pointer; | |||
| } | |||
| .clear { | |||
| clear: both; | |||
| } | |||
| .hide-count { | |||
| display: inline-block; | |||
| position: absolute; | |||
| right: 450px; | |||
| height: 32px; | |||
| line-height: 32px; | |||
| padding-top: 12px; | |||
| color: red; | |||
| float: right; | |||
| margin-right: 10px; | |||
| } | |||
| .el-pagination { | |||
| position: absolute; | |||
| right: 0px; | |||
| float: right; | |||
| margin-right: 32px; | |||
| bottom: 10px; | |||
| } | |||
| .pagination-container { | |||
| height: 40px; | |||
| } | |||
| } | |||
| } | |||
| .details-data-list { | |||
| .el-table td, | |||
| .el-table th.is-leaf { | |||
| border: none; | |||
| border-top: 1px solid #ebeef5; | |||
| } | |||
| .el-table { | |||
| th { | |||
| padding: 10px 0; | |||
| @@ -20,7 +20,7 @@ limitations under the License. | |||
| <div class="select-box" | |||
| v-if="!noData && | |||
| (!summaryDirList || (summaryDirList && summaryDirList.length))"> | |||
| <div v-show="showTable&&!noData" | |||
| <div v-show="showTable && !noData" | |||
| class="select-container"> | |||
| <!-- multiple collapse-tags --> | |||
| <div class="display-column"> {{$t('modelTraceback.displayColumn')}}</div> | |||
| @@ -40,19 +40,18 @@ limitations under the License. | |||
| <button type="text" | |||
| @click="allSelect" | |||
| class="select-all-button" | |||
| :class="[selectCheckAll? | |||
| 'checked-color':'button-text', | |||
| basearr.length>checkOptions.length?'btn-disabled':'']" | |||
| :disabled="basearr.length>checkOptions.length"> | |||
| :class="[selectCheckAll ? 'checked-color' : 'button-text', | |||
| basearr.length > checkOptions.length ? 'btn-disabled' : '']" | |||
| :disabled="basearr.length > checkOptions.length"> | |||
| {{$t('public.selectAll')}} | |||
| </button> | |||
| <button type="text" | |||
| @click="deselectAll" | |||
| class="deselect-all-button" | |||
| :class="[!selectCheckAll? | |||
| 'checked-color':'button-text', | |||
| basearr.length>checkOptions.length?'btn-disabled':'']" | |||
| :disabled="basearr.length>checkOptions.length"> | |||
| 'checked-color' : 'button-text', | |||
| basearr.length > checkOptions.length ? 'btn-disabled' : '']" | |||
| :disabled="basearr.length > checkOptions.length"> | |||
| {{$t('public.deselectAll')}} | |||
| </button> | |||
| </div> | |||
| @@ -64,7 +63,7 @@ limitations under the License. | |||
| :label="item.label" | |||
| :value="item.value" | |||
| :disabled="item.disabled" | |||
| :title="item.disabled?$t('modelTraceback.mustExist'):''"> | |||
| :title="item.disabled ? $t('modelTraceback.mustExist') : ''"> | |||
| </el-option> | |||
| </el-option-group> | |||
| </el-select> | |||
| @@ -82,19 +81,19 @@ limitations under the License. | |||
| type="primary" | |||
| size="mini" | |||
| plain | |||
| v-if="(!noData&&basearr.length) || | |||
| v-if="(!noData && basearr.length) || | |||
| (noData && summaryDirList && !summaryDirList.length)"> | |||
| {{ $t('modelTraceback.showAllData') }}</el-button> | |||
| </div> | |||
| </div> | |||
| <div id="echart" | |||
| v-show="!noData&&showEchartPic"></div> | |||
| v-show="!noData && showEchartPic"></div> | |||
| <div class="echart-no-data" | |||
| v-show="!showEchartPic"> | |||
| </div> | |||
| <div class="btns-container" | |||
| v-show="showTable&&!noData"> | |||
| v-show="showTable && !noData"> | |||
| <el-button type="primary" | |||
| size="mini" | |||
| class="custom-btn" | |||
| @@ -118,7 +117,7 @@ limitations under the License. | |||
| <el-table-column type="selection" | |||
| width="55" | |||
| :reserve-selection="true" | |||
| v-show="showTable&&!noData"> | |||
| v-show="showTable && !noData"> | |||
| </el-table-column> | |||
| <!--metric table column--> | |||
| @@ -188,7 +187,7 @@ limitations under the License. | |||
| </div> | |||
| </template> | |||
| <template slot-scope="scope"> | |||
| <span>{{formatNumber(key,scope.row[key])}}</span> | |||
| <span>{{formatNumber(key, scope.row[key])}}</span> | |||
| </template> | |||
| </el-table-column> | |||
| </el-table-column> | |||
| @@ -197,7 +196,7 @@ limitations under the License. | |||
| :key="key" | |||
| :prop="key" | |||
| :label="table.columnOptions[key].label" | |||
| :fixed="table.columnOptions[key].label===text?true:false" | |||
| :fixed="table.columnOptions[key].label === text ? true : false" | |||
| show-overflow-tooltip | |||
| min-width="150" | |||
| sortable="custom"> | |||
| @@ -216,7 +215,7 @@ limitations under the License. | |||
| </el-table-column> | |||
| <!-- remark column --> | |||
| <el-table-column fixed="right" | |||
| width="310"> | |||
| width="260"> | |||
| <template slot="header"> | |||
| <div> | |||
| <div class="label-text">{{$t('public.remark')}}</div> | |||
| @@ -271,7 +270,7 @@ limitations under the License. | |||
| <div> | |||
| <div class="icon-image-container"> | |||
| <div class="icon-image" | |||
| :class="[item.number===scope.row.tag&&scope.row.showIcon ? 'icon-border':'']" | |||
| :class="[item.number === scope.row.tag && scope.row.showIcon ? 'icon-border' : '']" | |||
| v-for="item in imageList" | |||
| :key="item.number" | |||
| @click="iconValueChange(scope.row,item.number,$event)"> | |||
| @@ -300,17 +299,18 @@ limitations under the License. | |||
| </template> | |||
| </el-table-column> | |||
| </el-table> | |||
| <div> | |||
| <div class="hide-count" | |||
| v-show="recordsNumber-showNumber"> | |||
| {{$t('modelTraceback.totalHide').replace(`{n}`,(recordsNumber-showNumber))}} | |||
| </div> | |||
| <div class="pagination-container"> | |||
| <el-pagination @current-change="pagination.pageChange" | |||
| :current-page="pagination.currentPage" | |||
| :page-size="pagination.pageSize" | |||
| :layout="pagination.layout" | |||
| :total="pagination.total"> | |||
| </el-pagination> | |||
| <div class="hide-count" | |||
| v-show="recordsNumber-showNumber"> | |||
| {{$t('modelTraceback.totalHide').replace(`{n}`, (recordsNumber-showNumber))}} | |||
| </div> | |||
| <div class="clear"></div> | |||
| </div> | |||
| </div> | |||
| @@ -425,7 +425,7 @@ export default { | |||
| obj.iconAdd = require('@/assets/images/icon' + obj.number + '.svg'); | |||
| this.imageList.push(obj); | |||
| } | |||
| document.title = this.$t('summaryManage.modelTraceback') + '-MindInsight'; | |||
| document.title = `${this.$t('summaryManage.modelTraceback')}-MindInsight`; | |||
| document.addEventListener('click', this.blurFloat, true); | |||
| this.$store.commit('setSelectedBarList', []); | |||
| this.getStoreList(); | |||
| @@ -466,8 +466,8 @@ export default { | |||
| return; | |||
| } | |||
| row.showIcon = true; | |||
| const e = window.event; | |||
| document.getElementById('icon-dialog').style.top = e.clientY + 'px'; | |||
| document.getElementById('icon-dialog').style.top = | |||
| window.event.clientY + 'px'; | |||
| }, | |||
| /** | |||
| @@ -514,6 +514,13 @@ export default { | |||
| }, | |||
| // clear icon | |||
| clearIcon(row) { | |||
| const classWrap = event.path.find((item) => { | |||
| return item.className === 'icon-dialog'; | |||
| }); | |||
| const classArr = classWrap.querySelectorAll('.icon-border'); | |||
| classArr.forEach((item) => { | |||
| item.classList.remove('icon-border'); | |||
| }); | |||
| row.showIcon = false; | |||
| this.iconValue = 0; | |||
| row.tag = 0; | |||
| @@ -1345,10 +1352,11 @@ export default { | |||
| this.echart.brushData = list; | |||
| this.echart.showData = this.echart.brushData; | |||
| this.initChart(); | |||
| this.table.data = list.slice( | |||
| const showList = list.slice( | |||
| (this.pagination.currentPage - 1) * this.pagination.pageSize, | |||
| this.pagination.currentPage * this.pagination.pageSize, | |||
| ); | |||
| this.table.data = showList; | |||
| this.recordsNumber = this.table.data.length; | |||
| this.showNumber = this.table.data.length; | |||
| this.pagination.total = res.data.count || 0; | |||
| @@ -1365,6 +1373,8 @@ export default { | |||
| sortChange(column) { | |||
| this.sortInfo.sorted_name = column.prop; | |||
| this.sortInfo.sorted_type = column.order; | |||
| this.recordsNumber = 0; | |||
| this.showNumber = 0; | |||
| this.getStoreList(); | |||
| const tempParam = { | |||
| limit: this.pagination.pageSize, | |||
| @@ -1384,9 +1394,21 @@ export default { | |||
| (res) => { | |||
| if (res && res.data && res.data.object) { | |||
| const list = this.setDataOfModel(res.data.object); | |||
| this.table.data = list; | |||
| const tempList = list.slice(0, this.pagination.pageSize); | |||
| this.recordsNumber = tempList.length; | |||
| if (this.hidenDirChecked.length) { | |||
| this.hidenDirChecked.forEach((dir) => { | |||
| tempList.forEach((item, index) => { | |||
| if (item.summary_dir === dir) { | |||
| tempList.splice(index, 1); | |||
| } | |||
| }); | |||
| }); | |||
| } | |||
| this.showNumber = tempList.length; | |||
| this.table.data = tempList; | |||
| this.pagination.total = res.data.count || 0; | |||
| this.pagination.currentPage = 0; | |||
| this.pagination.currentPage = 1; | |||
| } | |||
| }, | |||
| (error) => {}, | |||
| @@ -1741,6 +1763,7 @@ export default { | |||
| this.$store.commit('setSelectedBarList', []); | |||
| this.noData = false; | |||
| this.showTable = false; | |||
| this.selectCheckAll = true; | |||
| this.chartFilter = {}; | |||
| this.tableFilter.summary_dir = undefined; | |||
| this.sortInfo = {}; | |||
| @@ -1838,14 +1861,15 @@ export default { | |||
| <style lang="scss"> | |||
| .label-text { | |||
| line-height: 20px !important; | |||
| vertical-align: bottom; | |||
| padding-top: 20px; | |||
| display: block !important; | |||
| } | |||
| .remark-tip { | |||
| line-height: 14px !important; | |||
| line-height: 20px !important; | |||
| font-size: 12px; | |||
| white-space: pre-wrap !important; | |||
| vertical-align: bottom; | |||
| color: gray; | |||
| display: block !important; | |||
| } | |||
| .el-color-dropdown__main-wrapper, | |||
| .el-color-dropdown__value, | |||
| @@ -1943,6 +1967,7 @@ export default { | |||
| .btns { | |||
| margin-left: 20px; | |||
| padding-top: 12px; | |||
| height: 46px; | |||
| } | |||
| .btn-container-margin { | |||
| margin: 0 55px 10px; | |||
| @@ -2048,7 +2073,7 @@ export default { | |||
| } | |||
| .table-container { | |||
| background-color: white; | |||
| height: calc(60% - 40px); | |||
| height: calc(68% - 130px); | |||
| padding: 6px 32px; | |||
| position: relative; | |||
| .custom-label { | |||
| @@ -2059,21 +2084,24 @@ export default { | |||
| a { | |||
| cursor: pointer; | |||
| } | |||
| .clear { | |||
| clear: both; | |||
| } | |||
| .hide-count { | |||
| display: inline-block; | |||
| position: absolute; | |||
| right: 450px; | |||
| height: 32px; | |||
| line-height: 32px; | |||
| padding-top: 4px; | |||
| color: red; | |||
| float: right; | |||
| margin-right: 10px; | |||
| } | |||
| .el-pagination { | |||
| float: right; | |||
| margin-right: 32px; | |||
| position: absolute; | |||
| right: 0; | |||
| bottom: 10px; | |||
| } | |||
| .pagination-container { | |||
| height: 40px; | |||
| } | |||
| } | |||
| .no-data-page { | |||
| width: 100%; | |||
| @@ -61,12 +61,10 @@ | |||
| <div class="cl-search-box"> | |||
| <el-input v-model="searchByTypeInput" | |||
| v-if="statisticType === 0" | |||
| suffix-icon="el-icon-search" | |||
| :placeholder="$t('profiler.searchByType')" | |||
| @keyup.enter.native="searchOpCoreList()"></el-input> | |||
| <el-input v-model="searchByNameInput" | |||
| v-if="statisticType === 1" | |||
| suffix-icon="el-icon-search" | |||
| :placeholder="$t('profiler.searchByName')" | |||
| @keyup.enter.native="searchOpCoreList()"></el-input> | |||
| </div> | |||
| @@ -90,6 +88,8 @@ | |||
| :property="ele" | |||
| :key="key" | |||
| :sortable="ele === 'op_info' ? false : 'custom'" | |||
| :width="(ele==='execution_time'|| ele==='subgraph' || | |||
| ele==='op_name'|| ele==='op_type')?'220':''" | |||
| show-overflow-tooltip | |||
| :label="ele"> | |||
| </el-table-column> | |||
| @@ -124,6 +124,8 @@ | |||
| :key="$index" | |||
| :label="item" | |||
| :sortable="item === 'op_info' ? false : 'custom'" | |||
| :width="(item==='execution_time'|| item==='subgraph' || | |||
| item==='op_name'|| item==='op_type')?'220':''" | |||
| show-overflow-tooltip> | |||
| </el-table-column> | |||
| </el-table> | |||
| @@ -168,7 +170,6 @@ | |||
| </span> | |||
| <div class="cl-search-box"> | |||
| <el-input v-model="searchByCPUNameInput" | |||
| suffix-icon="el-icon-search" | |||
| :placeholder="$t('profiler.searchByName')" | |||
| @keyup.enter.native="searchOpCpuList()"></el-input> | |||
| </div> | |||
| @@ -814,7 +815,8 @@ export default { | |||
| option.xAxis = { | |||
| type: 'category', | |||
| axisLabel: { | |||
| interval: 1, | |||
| interval: 0, | |||
| rotate: -30, | |||
| }, | |||
| data: [], | |||
| }; | |||
| @@ -822,7 +824,7 @@ export default { | |||
| left: 50, | |||
| top: 20, | |||
| right: 0, | |||
| bottom: 30, | |||
| bottom: 50, | |||
| }; | |||
| option.yAxis = { | |||
| type: 'value', | |||
| @@ -925,7 +927,7 @@ export default { | |||
| const item = {}; | |||
| item.key = k; | |||
| item.value = dataObj[key][k]; | |||
| item.id = (index + 1) * 10 + 1 + j; | |||
| item.id = item.key + Math.random(); | |||
| tempData.children.push(item); | |||
| }); | |||
| } | |||
| @@ -955,20 +957,12 @@ export default { | |||
| }, | |||
| }, | |||
| mounted() { | |||
| if ( | |||
| this.$route.query && | |||
| this.$route.query.dir && | |||
| this.$route.query.id | |||
| ) { | |||
| if (this.$route.query && this.$route.query.dir && this.$route.query.id) { | |||
| this.profile_dir = this.$route.query.dir; | |||
| this.train_id = this.$route.query.id; | |||
| document.title = | |||
| decodeURIComponent(this.train_id) + | |||
| '-' + | |||
| this.$t('profiler.titleText') + | |||
| '-MindInsight'; | |||
| document.title = `${ decodeURIComponent(this.train_id)}-${this.$t('profiler.titleText')}-MindInsight`; | |||
| } else { | |||
| document.title = this.$t('profiler.titleText') + '-MindInsight'; | |||
| document.title = `${this.$t('profiler.titleText')}-MindInsight`; | |||
| } | |||
| this.init(); | |||
| window.addEventListener('resize', this.resizeCallback, false); | |||
| @@ -31,6 +31,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF | |||
| LineageSearchConditionParamError) | |||
| from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 | |||
| from .....ut.lineagemgr.querier import event_data | |||
| from .....utils.tools import assert_equal_lineages | |||
| LINEAGE_INFO_RUN1 = { | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| @@ -39,7 +40,7 @@ LINEAGE_INFO_RUN1 = { | |||
| }, | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'learning_rate': 0.12, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'epoch': 14, | |||
| 'parallel_mode': 'stand_alone', | |||
| @@ -73,11 +74,11 @@ LINEAGE_FILTRATION_EXCEPT_RUN = { | |||
| 'user_defined': {}, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'learning_rate': 0.12, | |||
| 'epoch': 10, | |||
| 'batch_size': 32, | |||
| 'device_num': 2, | |||
| 'loss': 0.029999999329447746, | |||
| 'loss': 0.03, | |||
| 'model_size': 64, | |||
| 'metric': {}, | |||
| 'dataset_mark': 2 | |||
| @@ -92,10 +93,14 @@ LINEAGE_FILTRATION_RUN1 = { | |||
| 'train_dataset_count': 1024, | |||
| 'test_dataset_path': None, | |||
| 'test_dataset_count': 1024, | |||
| 'user_defined': {}, | |||
| 'user_defined': { | |||
| 'info': 'info1', | |||
| 'version': 'v1', | |||
| 'eval_version': 'version2' | |||
| }, | |||
| 'network': 'ResNet', | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'learning_rate': 0.12, | |||
| 'epoch': 14, | |||
| 'batch_size': 32, | |||
| 'device_num': 2, | |||
| @@ -119,14 +124,14 @@ LINEAGE_FILTRATION_RUN2 = { | |||
| 'user_defined': {}, | |||
| 'network': "ResNet", | |||
| 'optimizer': "Momentum", | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'learning_rate': 0.12, | |||
| 'epoch': 10, | |||
| 'batch_size': 32, | |||
| 'device_num': 2, | |||
| 'loss': 0.029999999329447746, | |||
| 'loss': 0.03, | |||
| 'model_size': 10, | |||
| 'metric': { | |||
| 'accuracy': 2.7800000000000002 | |||
| 'accuracy': 2.78 | |||
| }, | |||
| 'dataset_mark': 3 | |||
| }, | |||
| @@ -173,7 +178,7 @@ class TestModelApi(TestCase): | |||
| 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'Momentum', | |||
| 'learning_rate': 0.11999999731779099, | |||
| 'learning_rate': 0.12, | |||
| 'loss_function': 'SoftmaxCrossEntropyWithLogits', | |||
| 'epoch': 14, | |||
| 'parallel_mode': 'stand_alone', | |||
| @@ -190,9 +195,9 @@ class TestModelApi(TestCase): | |||
| 'network': 'ResNet' | |||
| } | |||
| } | |||
| assert expect_total_res == total_res | |||
| assert expect_partial_res1 == partial_res1 | |||
| assert expect_partial_res2 == partial_res2 | |||
| assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual) | |||
| assert_equal_lineages(expect_partial_res1, partial_res1, self.assertDictEqual) | |||
| assert_equal_lineages(expect_partial_res2, partial_res2, self.assertDictEqual) | |||
| # the lineage summary file is empty | |||
| result = get_summary_lineage(self.dir_with_empty_lineage) | |||
| @@ -329,7 +334,7 @@ class TestModelApi(TestCase): | |||
| def test_filter_summary_lineage(self): | |||
| """Test the interface of filter_summary_lineage.""" | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||
| LINEAGE_FILTRATION_RUN1, | |||
| @@ -345,7 +350,7 @@ class TestModelApi(TestCase): | |||
| expect_objects = expect_result.get('object') | |||
| for idx, res_object in enumerate(res.get('object')): | |||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | |||
| assert expect_result == res | |||
| assert_equal_lineages(expect_result, res, self.assertDictEqual) | |||
| expect_result = { | |||
| 'customized': {}, | |||
| @@ -356,7 +361,7 @@ class TestModelApi(TestCase): | |||
| expect_objects = expect_result.get('object') | |||
| for idx, res_object in enumerate(res.get('object')): | |||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | |||
| assert expect_result == res | |||
| assert_equal_lineages(expect_result, res, self.assertDictEqual) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @@ -383,7 +388,7 @@ class TestModelApi(TestCase): | |||
| 'offset': 0 | |||
| } | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_RUN2, | |||
| LINEAGE_FILTRATION_RUN1 | |||
| @@ -394,7 +399,7 @@ class TestModelApi(TestCase): | |||
| expect_objects = expect_result.get('object') | |||
| for idx, res_object in enumerate(partial_res.get('object')): | |||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | |||
| assert expect_result == partial_res | |||
| assert_equal_lineages(expect_result, partial_res, self.assertDictEqual) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @@ -421,7 +426,7 @@ class TestModelApi(TestCase): | |||
| 'offset': 0 | |||
| } | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_RUN2, | |||
| LINEAGE_FILTRATION_RUN1 | |||
| @@ -432,7 +437,7 @@ class TestModelApi(TestCase): | |||
| expect_objects = expect_result.get('object') | |||
| for idx, res_object in enumerate(partial_res.get('object')): | |||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | |||
| assert expect_result == partial_res | |||
| assert_equal_lineages(expect_result, partial_res, self.assertDictEqual) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @@ -449,7 +454,7 @@ class TestModelApi(TestCase): | |||
| 'sorted_name': 'metric/accuracy', | |||
| } | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||
| LINEAGE_FILTRATION_RUN1, | |||
| @@ -461,7 +466,7 @@ class TestModelApi(TestCase): | |||
| expect_objects = expect_result.get('object') | |||
| for idx, res_object in enumerate(partial_res1.get('object')): | |||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | |||
| assert expect_result == partial_res1 | |||
| assert_equal_lineages(expect_result, partial_res1, self.assertDictEqual) | |||
| search_condition2 = { | |||
| 'batch_size': { | |||
| @@ -477,9 +482,6 @@ class TestModelApi(TestCase): | |||
| 'count': 0 | |||
| } | |||
| partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2) | |||
| expect_objects = expect_result.get('object') | |||
| for idx, res_object in enumerate(partial_res2.get('object')): | |||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | |||
| assert expect_result == partial_res2 | |||
| @pytest.mark.level0 | |||
| @@ -33,7 +33,7 @@ from ..api.test_model_api import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RU | |||
| LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 | |||
| from ..conftest import BASE_SUMMARY_DIR | |||
| from .....ut.lineagemgr.querier import event_data | |||
| from .....utils.tools import check_loading_done | |||
| from .....utils.tools import check_loading_done, assert_equal_lineages | |||
| @pytest.mark.usefixtures("create_summary_dir") | |||
| @@ -58,8 +58,7 @@ class TestModelApi(TestCase): | |||
| """Test the interface of get_summary_lineage.""" | |||
| total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1") | |||
| expect_total_res = LINEAGE_INFO_RUN1 | |||
| assert expect_total_res == total_res | |||
| assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @@ -70,7 +69,7 @@ class TestModelApi(TestCase): | |||
| def test_filter_summary_lineage(self): | |||
| """Test the interface of filter_summary_lineage.""" | |||
| expect_result = { | |||
| 'customized': event_data.CUSTOMIZED__0, | |||
| 'customized': event_data.CUSTOMIZED__1, | |||
| 'object': [ | |||
| LINEAGE_FILTRATION_EXCEPT_RUN, | |||
| LINEAGE_FILTRATION_RUN1, | |||
| @@ -86,7 +85,7 @@ class TestModelApi(TestCase): | |||
| expect_objects = expect_result.get('object') | |||
| for idx, res_object in enumerate(res.get('object')): | |||
| expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') | |||
| assert expect_result == res | |||
| assert_equal_lineages(expect_result, res, self.assertDictEqual) | |||
| expect_result = { | |||
| 'customized': {}, | |||
| @@ -100,4 +99,4 @@ class TestModelApi(TestCase): | |||
| } | |||
| } | |||
| res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition) | |||
| assert expect_result == res | |||
| assert_equal_lineages(expect_result, res, self.assertDictEqual) | |||
| @@ -28,7 +28,7 @@ from unittest import mock, TestCase | |||
| import numpy as np | |||
| import pytest | |||
| from mindinsight.lineagemgr import get_summary_lineage | |||
| from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage | |||
| from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ | |||
| AnalyzeObject | |||
| from mindinsight.lineagemgr.common.utils import make_directory | |||
| @@ -73,6 +73,10 @@ class TestModelLineage(TestCase): | |||
| TrainLineage(cls.summary_record) | |||
| ] | |||
| cls.run_context['list_callback'] = _ListCallback(callback) | |||
| cls.user_defined_info = { | |||
| "info": "info1", | |||
| "version": "v1" | |||
| } | |||
| @pytest.mark.scene_train(2) | |||
| @pytest.mark.level0 | |||
| @@ -83,7 +87,7 @@ class TestModelLineage(TestCase): | |||
| @pytest.mark.env_single | |||
| def test_train_begin(self): | |||
| """Test the begin function in TrainLineage.""" | |||
| train_callback = TrainLineage(self.summary_record, True) | |||
| train_callback = TrainLineage(self.summary_record, True, self.user_defined_info) | |||
| train_callback.begin(RunContext(self.run_context)) | |||
| assert train_callback.initial_learning_rate == 0.12 | |||
| lineage_log_path = train_callback.lineage_summary.lineage_log_path | |||
| @@ -98,7 +102,11 @@ class TestModelLineage(TestCase): | |||
| @pytest.mark.env_single | |||
| def test_train_begin_with_user_defined_info(self): | |||
| """Test TrainLineage with nested user defined info.""" | |||
| user_defined_info = {"info": {"version": "v1"}} | |||
| user_defined_info = { | |||
| "info": "info1", | |||
| "version": "v1", | |||
| "network": "LeNet" | |||
| } | |||
| train_callback = TrainLineage( | |||
| self.summary_record, | |||
| False, | |||
| @@ -108,6 +116,8 @@ class TestModelLineage(TestCase): | |||
| assert train_callback.initial_learning_rate == 0.12 | |||
| lineage_log_path = train_callback.lineage_summary.lineage_log_path | |||
| assert os.path.isfile(lineage_log_path) is True | |||
| res = filter_summary_lineage(os.path.dirname(lineage_log_path)) | |||
| assert self.user_defined_info == res['object'][0]['model_lineage']['user_defined'] | |||
| @pytest.mark.scene_train(2) | |||
| @pytest.mark.level0 | |||
| @@ -138,7 +148,7 @@ class TestModelLineage(TestCase): | |||
| def test_training_end(self, *args): | |||
| """Test the end function in TrainLineage.""" | |||
| args[0].return_value = 64 | |||
| train_callback = TrainLineage(self.summary_record, True) | |||
| train_callback = TrainLineage(self.summary_record, True, self.user_defined_info) | |||
| train_callback.initial_learning_rate = 0.12 | |||
| train_callback.end(RunContext(self.run_context)) | |||
| res = get_summary_lineage(SUMMARY_DIR) | |||
| @@ -158,7 +168,7 @@ class TestModelLineage(TestCase): | |||
| @pytest.mark.env_single | |||
| def test_eval_end(self): | |||
| """Test the end function in EvalLineage.""" | |||
| eval_callback = EvalLineage(self.summary_record, True) | |||
| eval_callback = EvalLineage(self.summary_record, True, {'eval_version': 'version2'}) | |||
| eval_run_context = self.run_context | |||
| eval_run_context['metrics'] = {'accuracy': 0.78} | |||
| eval_run_context['valid_dataset'] = self.run_context['train_dataset'] | |||
| @@ -331,7 +341,7 @@ class TestModelLineage(TestCase): | |||
| def test_train_with_customized_network(self, *args): | |||
| """Test train with customized network.""" | |||
| args[0].return_value = 64 | |||
| train_callback = TrainLineage(self.summary_record, True) | |||
| train_callback = TrainLineage(self.summary_record, True, self.user_defined_info) | |||
| run_context_customized = self.run_context | |||
| del run_context_customized['optimizer'] | |||
| del run_context_customized['net_outputs'] | |||
| @@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = { | |||
| 'train_lineage': { | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'ApplyMomentum0', | |||
| 'learning_rate': 0.10000000149011612, | |||
| 'learning_rate': 0.11, | |||
| 'loss_function': '', | |||
| 'epoch': 1, | |||
| 'parallel_mode': 'stand_alone0', | |||
| @@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell0', | |||
| 'loss': 2.3025848865509033 | |||
| 'loss': 2.3025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '', | |||
| @@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = { | |||
| 'train_lineage': { | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'ApplyMomentum1', | |||
| 'learning_rate': 0.20000000298023224, | |||
| 'learning_rate': 0.2100001, | |||
| 'loss_function': 'loss_function1', | |||
| 'epoch': 1, | |||
| 'parallel_mode': 'stand_alone1', | |||
| @@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell1', | |||
| 'loss': 2.4025847911834717 | |||
| 'loss': 2.4025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '/path/to/train_dataset1', | |||
| @@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = { | |||
| 'train_lineage': { | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'ApplyMomentum2', | |||
| 'learning_rate': 0.30000001192092896, | |||
| 'learning_rate': 0.3100001, | |||
| 'loss_function': 'loss_function2', | |||
| 'epoch': 2, | |||
| 'parallel_mode': 'stand_alone2', | |||
| @@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell2', | |||
| 'loss': 2.502584934234619 | |||
| 'loss': 2.5025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '/path/to/train_dataset2', | |||
| @@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = { | |||
| 'train_lineage': { | |||
| 'hyper_parameters': { | |||
| 'optimizer': 'ApplyMomentum3', | |||
| 'learning_rate': 0.4000000059604645, | |||
| 'learning_rate': 0.4, | |||
| 'loss_function': 'loss_function3', | |||
| 'epoch': 2, | |||
| 'parallel_mode': 'stand_alone3', | |||
| @@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell3', | |||
| 'loss': 2.6025848388671875 | |||
| 'loss': 2.6025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '/path/to/train_dataset3', | |||
| @@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell4', | |||
| 'loss': 2.702584981918335 | |||
| 'loss': 2.7025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_path': '/path/to/train_dataset4', | |||
| @@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = { | |||
| }, | |||
| 'algorithm': { | |||
| 'network': 'TrainOneStepCell5', | |||
| 'loss': 2.702584981918335 | |||
| 'loss': 2.7025841 | |||
| }, | |||
| 'train_dataset': { | |||
| 'train_dataset_size': 35 | |||
| @@ -192,6 +192,13 @@ CUSTOMIZED__0 = { | |||
| 'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'}, | |||
| } | |||
| CUSTOMIZED__1 = { | |||
| **CUSTOMIZED__0, | |||
| 'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'}, | |||
| 'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'}, | |||
| 'user_defined/eval_version': {'label': 'user_defined/eval_version', 'required': False, 'type': 'str'} | |||
| } | |||
| CUSTOMIZED_0 = { | |||
| **CUSTOMIZED__0, | |||
| 'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'}, | |||
| @@ -211,33 +218,33 @@ CUSTOMIZED_2 = { | |||
| } | |||
| METRIC_1 = { | |||
| 'accuracy': 1.0000002, | |||
| 'accuracy': 1.2000002, | |||
| 'mae': 2.00000002, | |||
| 'mse': 3.00000002 | |||
| } | |||
| METRIC_2 = { | |||
| 'accuracy': 1.0000003, | |||
| 'mae': 2.00000003, | |||
| 'mse': 3.00000003 | |||
| 'accuracy': 1.3000003, | |||
| 'mae': 2.30000003, | |||
| 'mse': 3.30000003 | |||
| } | |||
| METRIC_3 = { | |||
| 'accuracy': 1.0000004, | |||
| 'mae': 2.00000004, | |||
| 'mse': 3.00000004 | |||
| 'accuracy': 1.4000004, | |||
| 'mae': 2.40000004, | |||
| 'mse': 3.40000004 | |||
| } | |||
| METRIC_4 = { | |||
| 'accuracy': 1.0000005, | |||
| 'mae': 2.00000005, | |||
| 'mse': 3.00000005 | |||
| 'accuracy': 1.5000005, | |||
| 'mae': 2.50000005, | |||
| 'mse': 3.50000005 | |||
| } | |||
| METRIC_5 = { | |||
| 'accuracy': 1.0000006, | |||
| 'mae': 2.00000006, | |||
| 'mse': 3.00000006 | |||
| 'accuracy': 1.7000006, | |||
| 'mae': 2.60000006, | |||
| 'mse': 3.60000006 | |||
| } | |||
| EVENT_EVAL_DICT_0 = { | |||
| @@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier | |||
| from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo | |||
| from . import event_data | |||
| from ....utils.tools import assert_equal_lineages | |||
| def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): | |||
| @@ -266,7 +267,6 @@ class TestQuerier(TestCase): | |||
| mock_file_handler = MagicMock() | |||
| mock_file_handler.size = 1 | |||
| args[2].return_value = [{'relative_path': './', 'update_time': 1}] | |||
| single_summary_path = '/path/to/summary0' | |||
| lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs | |||
| @@ -286,13 +286,13 @@ class TestQuerier(TestCase): | |||
| """Test the success of get_summary_lineage.""" | |||
| expected_result = [LINEAGE_INFO_0] | |||
| result = self.single_querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertListEqual) | |||
| def test_get_summary_lineage_success_2(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| expected_result = [LINEAGE_INFO_0] | |||
| result = self.single_querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertListEqual) | |||
| def test_get_summary_lineage_success_3(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| @@ -306,7 +306,7 @@ class TestQuerier(TestCase): | |||
| result = self.single_querier.get_summary_lineage( | |||
| filter_keys=['model', 'algorithm'] | |||
| ) | |||
| self.assertListEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertListEqual) | |||
| def test_get_summary_lineage_success_4(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| @@ -353,7 +353,7 @@ class TestQuerier(TestCase): | |||
| } | |||
| ] | |||
| result = self.multi_querier.get_summary_lineage() | |||
| self.assertListEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertListEqual) | |||
| def test_get_summary_lineage_success_5(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| @@ -361,7 +361,7 @@ class TestQuerier(TestCase): | |||
| result = self.multi_querier.get_summary_lineage( | |||
| summary_dir='/path/to/summary1' | |||
| ) | |||
| self.assertListEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertListEqual) | |||
| def test_get_summary_lineage_success_6(self): | |||
| """Test the success of get_summary_lineage.""" | |||
| @@ -380,7 +380,7 @@ class TestQuerier(TestCase): | |||
| result = self.multi_querier.get_summary_lineage( | |||
| summary_dir='/path/to/summary0', filter_keys=filter_keys | |||
| ) | |||
| self.assertListEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertListEqual) | |||
| def test_get_summary_lineage_fail(self): | |||
| """Test the function of get_summary_lineage with exception.""" | |||
| @@ -423,7 +423,7 @@ class TestQuerier(TestCase): | |||
| 'count': 2, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_success_2(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -448,7 +448,7 @@ class TestQuerier(TestCase): | |||
| 'count': 2, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_success_3(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -465,7 +465,7 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_success_4(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -483,7 +483,7 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage() | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_success_5(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -498,7 +498,7 @@ class TestQuerier(TestCase): | |||
| 'count': 1, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_success_6(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -520,7 +520,7 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_success_7(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -542,14 +542,14 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_success_8(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| condition = { | |||
| 'metric/accuracy': { | |||
| 'lt': 1.0000006, | |||
| 'gt': 1.0000004 | |||
| 'lt': 1.6000006, | |||
| 'gt': 1.4000004 | |||
| } | |||
| } | |||
| expected_result = { | |||
| @@ -558,7 +558,7 @@ class TestQuerier(TestCase): | |||
| 'count': 1, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_success_9(self): | |||
| """Test the success of filter_summary_lineage.""" | |||
| @@ -572,14 +572,14 @@ class TestQuerier(TestCase): | |||
| 'count': 7, | |||
| } | |||
| result = self.multi_querier.filter_summary_lineage(condition=condition) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_filter_summary_lineage_fail(self): | |||
| """Test the function of filter_summary_lineage with exception.""" | |||
| condition = { | |||
| 'xxx': { | |||
| 'lt': 1.0000006, | |||
| 'gt': 1.0000004 | |||
| 'lt': 1.6000006, | |||
| 'gt': 1.4000004 | |||
| } | |||
| } | |||
| self.assertRaises( | |||
| @@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj | |||
| from . import event_data | |||
| from .test_querier import create_filtration_result, create_lineage_info | |||
| from ....utils.tools import assert_equal_lineages | |||
| class TestLineageObj(TestCase): | |||
| @@ -53,49 +54,62 @@ class TestLineageObj(TestCase): | |||
| def test_property(self): | |||
| """Test the function of getting property.""" | |||
| self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], | |||
| self.lineage_obj.algorithm | |||
| self.lineage_obj.algorithm, | |||
| self.assertDictEqual | |||
| ) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], | |||
| self.lineage_obj.model | |||
| self.lineage_obj.model, | |||
| self.assertDictEqual | |||
| ) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], | |||
| self.lineage_obj.train_dataset | |||
| self.lineage_obj.train_dataset, | |||
| self.assertDictEqual | |||
| ) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], | |||
| self.lineage_obj.hyper_parameters | |||
| self.lineage_obj.hyper_parameters, | |||
| self.assertDictEqual | |||
| ) | |||
| self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.METRIC_0, | |||
| self.lineage_obj.metric, | |||
| self.assertDictEqual | |||
| ) | |||
| assert_equal_lineages( | |||
| event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], | |||
| self.lineage_obj.valid_dataset | |||
| self.lineage_obj.valid_dataset, | |||
| self.assertDictEqual | |||
| ) | |||
| def test_property_eval_not_exist(self): | |||
| """Test the function of getting property with no evaluation event.""" | |||
| self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], | |||
| self.lineage_obj_no_eval.algorithm | |||
| self.lineage_obj_no_eval.algorithm, | |||
| self.assertDictEqual | |||
| ) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], | |||
| self.lineage_obj_no_eval.model | |||
| self.lineage_obj_no_eval.model, | |||
| self.assertDictEqual | |||
| ) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], | |||
| self.lineage_obj_no_eval.train_dataset | |||
| self.lineage_obj_no_eval.train_dataset, | |||
| self.assertDictEqual | |||
| ) | |||
| self.assertDictEqual( | |||
| assert_equal_lineages( | |||
| event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], | |||
| self.lineage_obj_no_eval.hyper_parameters | |||
| self.lineage_obj_no_eval.hyper_parameters, | |||
| self.assertDictEqual | |||
| ) | |||
| self.assertDictEqual({}, self.lineage_obj_no_eval.metric) | |||
| self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset) | |||
| assert_equal_lineages({}, self.lineage_obj_no_eval.metric, self.assertDictEqual) | |||
| assert_equal_lineages({}, self.lineage_obj_no_eval.valid_dataset, self.assertDictEqual) | |||
| def test_get_summary_info(self): | |||
| """Test the function of get_summary_info.""" | |||
| @@ -106,7 +120,7 @@ class TestLineageObj(TestCase): | |||
| 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] | |||
| } | |||
| result = self.lineage_obj.get_summary_info(filter_keys) | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_to_model_lineage_dict(self): | |||
| """Test the function of to_model_lineage_dict.""" | |||
| @@ -120,7 +134,7 @@ class TestLineageObj(TestCase): | |||
| expected_result['model_lineage']['dataset_mark'] = None | |||
| expected_result.pop('dataset_graph') | |||
| result = self.lineage_obj.to_model_lineage_dict() | |||
| self.assertDictEqual(expected_result, result) | |||
| assert_equal_lineages(expected_result, result, self.assertDictEqual) | |||
| def test_to_dataset_lineage_dict(self): | |||
| """Test the function of to_dataset_lineage_dict.""" | |||
| @@ -267,7 +267,7 @@ class TestAicoreDetailAnalyser(TestCase): | |||
| result = self._analyser.query(condition) | |||
| self.assertDictEqual(expect_result, result) | |||
| def test_query_and_sort_by_op_type(self): | |||
| def test_query_and_sort_by_op_type_1(self): | |||
| """Test the success of the querying and sorting function by operator type.""" | |||
| detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 5, 3, 4]) | |||
| expect_result = { | |||
| @@ -289,6 +289,31 @@ class TestAicoreDetailAnalyser(TestCase): | |||
| ) | |||
| self.assertDictEqual(expect_result, result) | |||
| def test_query_and_sort_by_op_type_2(self): | |||
| """Test the success of the querying and sorting function by operator type.""" | |||
| detail_infos = get_detail_infos(indexes=[9, 0, 2, 1, 3, 4, 8, 6]) | |||
| expect_result = { | |||
| 'col_name': AicoreDetailAnalyser.__col_names__[0:4], | |||
| 'object': [item[0:4] for item in detail_infos] | |||
| } | |||
| filter_condition = { | |||
| 'op_type': {}, | |||
| 'subgraph': { | |||
| 'in': ['Default'] | |||
| }, | |||
| 'is_display_detail': False, | |||
| 'is_display_full_op_name': False | |||
| } | |||
| op_type_order = [ | |||
| 'MatMul', 'AtomicAddrClean', 'Cast', 'Conv2D', 'TransData' | |||
| ] | |||
| result = self._analyser.query_and_sort_by_op_type( | |||
| filter_condition, op_type_order | |||
| ) | |||
| print(result) | |||
| self.assertDictEqual(expect_result, result) | |||
| def test_col_names(self): | |||
| """Test the querying column names function.""" | |||
| self.assertListEqual( | |||
| @@ -81,3 +81,77 @@ def compare_result_with_file(result, expected_file_path): | |||
| with open(expected_file_path, 'r') as file: | |||
| expected_results = json.load(file) | |||
| assert result == expected_results | |||
| def deal_float_for_dict(res: dict, expected_res: dict, decimal_num=2): | |||
| """ | |||
| Deal float rounded to specified decimals in dict. | |||
| For example: | |||
| res:{ | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.1234561} | |||
| } | |||
| } | |||
| expected_res: | |||
| { | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.1234562} | |||
| } | |||
| } | |||
| After: | |||
| res:{ | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.12346} | |||
| } | |||
| } | |||
| expected_res: | |||
| { | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.12346} | |||
| } | |||
| } | |||
| Args: | |||
| res (dict): e.g. | |||
| { | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.1234561} | |||
| } | |||
| } | |||
| expected_res (dict): | |||
| { | |||
| "model_lineages": { | |||
| "metric": {"acc": 0.1234562} | |||
| } | |||
| } | |||
| decimal_num (int): decimal rounded digits. | |||
| """ | |||
| for key in res: | |||
| value = res[key] | |||
| expected_value = expected_res[key] | |||
| if isinstance(value, dict): | |||
| deal_float_for_dict(value, expected_value) | |||
| elif isinstance(value, float): | |||
| res[key] = round(value, decimal_num) | |||
| expected_res[key] = round(expected_value, decimal_num) | |||
| def _deal_float_for_list(list1, list2, decimal_num): | |||
| """Deal float for list1 and list2.""" | |||
| index = 0 | |||
| for _ in list1: | |||
| deal_float_for_dict(list1[index], list2[index], decimal_num) | |||
| index += 1 | |||
| def assert_equal_lineages(lineages1, lineages2, assert_func, decimal_num=2): | |||
| """Assert float almost equal for lineage data.""" | |||
| if isinstance(lineages1, list) and isinstance(lineages2, list): | |||
| _deal_float_for_list(lineages1, lineages2, decimal_num) | |||
| elif lineages1.get('object') is not None and lineages2.get('object') is not None: | |||
| _deal_float_for_list(lineages1['object'], lineages2['object'], decimal_num) | |||
| else: | |||
| deal_float_for_dict(lineages1, lineages2, decimal_num) | |||
| assert_func(lineages1, lineages2) | |||