diff --git a/mindinsight/backend/profiler/profile_api.py b/mindinsight/backend/profiler/profile_api.py index 6c6dc2d6..5e57008d 100644 --- a/mindinsight/backend/profiler/profile_api.py +++ b/mindinsight/backend/profiler/profile_api.py @@ -30,9 +30,9 @@ from mindinsight.datavisual.utils.tools import get_train_id, get_profiler_dir, t from mindinsight.datavisual.utils.tools import unquote_args from mindinsight.profiler.analyser.analyser_factory import AnalyserFactory from mindinsight.profiler.analyser.minddata_analyser import MinddataAnalyser -from mindinsight.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \ - ProfilerDirNotFoundException -from mindinsight.profiler.common.util import analyse_device_list_from_profiler_dir +from mindinsight.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException +from mindinsight.profiler.common.util import analyse_device_list_from_profiler_dir, \ + check_train_job_and_profiler_dir from mindinsight.profiler.common.validator.validate import validate_condition, validate_ui_proc from mindinsight.profiler.common.validator.validate import validate_minddata_pipeline_condition from mindinsight.profiler.common.validator.validate_path import \ @@ -80,6 +80,8 @@ def get_profile_op_info(): except ValidationError: raise ParamValueError("Invalid profiler dir") + check_train_job_and_profiler_dir(profiler_dir_abs) + op_type = search_condition.get("op_type") analyser = AnalyserFactory.instance().get_analyser( @@ -115,6 +117,8 @@ def get_profile_device_list(): except ValidationError: raise ParamValueError("Invalid profiler dir") + check_train_job_and_profiler_dir(profiler_dir_abs) + device_list, _ = analyse_device_list_from_profiler_dir(profiler_dir_abs) return jsonify(device_list) @@ -131,7 +135,9 @@ def get_training_trace_graph(): >>> GET http://xxxx/v1/mindinsight/profile/training-trace/graph """ summary_dir = request.args.get("dir") - profiler_dir = validate_and_normalize_profiler_path(summary_dir, settings.SUMMARY_BASE_DIR) + profiler_dir_abs = validate_and_normalize_profiler_path(summary_dir, settings.SUMMARY_BASE_DIR) + check_train_job_and_profiler_dir(profiler_dir_abs) + graph_type = request.args.get("type", default='0') graph_type = to_int(graph_type, 'graph_type') device_id = request.args.get("device_id", default='0') @@ -139,7 +145,7 @@ def get_training_trace_graph(): graph_info = {} try: analyser = AnalyserFactory.instance().get_analyser( - 'step_trace', profiler_dir, device_id) + 'step_trace', profiler_dir_abs, device_id) except ProfilerFileNotFoundException: return jsonify(graph_info) @@ -165,14 +171,16 @@ def get_target_time_info(): >>> GET http://xxxx/v1/mindinsight/profile/training-trace/target-time-info """ summary_dir = request.args.get("dir") - profiler_dir = validate_and_normalize_profiler_path(summary_dir, settings.SUMMARY_BASE_DIR) + profiler_dir_abs = validate_and_normalize_profiler_path(summary_dir, settings.SUMMARY_BASE_DIR) + check_train_job_and_profiler_dir(profiler_dir_abs) + proc_name = request.args.get("type") validate_ui_proc(proc_name) device_id = request.args.get("device_id", default='0') _ = to_int(device_id, 'device_id') analyser = AnalyserFactory.instance().get_analyser( - 'step_trace', profiler_dir, device_id) + 'step_trace', profiler_dir_abs, device_id) target_time_info = analyser.query({ 'filter_condition': { 'mode': 'proc', @@ -193,7 +201,8 @@ def get_queue_info(): Examples: >>> GET http://xxxx/v1/mindinsight/profile/queue_info """ - profile_dir = get_profiler_abs_dir(request) + profiler_dir_abs = get_profiler_abs_dir(request) + check_train_job_and_profiler_dir(profiler_dir_abs) device_id = unquote_args(request, "device_id") to_int(device_id, 'device_id') @@ -201,7 +210,7 @@ def get_queue_info(): queue_info = {} minddata_analyser = AnalyserFactory.instance().get_analyser( - 'minddata', profile_dir, device_id) + 'minddata', profiler_dir_abs, device_id) if queue_type == "get_next": queue_info, _ = minddata_analyser.analyse_get_next_info(info_type="queue") elif queue_type == "device_queue": @@ -221,7 +230,9 @@ def get_time_info(): Examples: >>> GET http://xxxx/v1/mindinsight/profile/minddata_op """ - profile_dir = get_profiler_abs_dir(request) + profiler_dir_abs = get_profiler_abs_dir(request) + check_train_job_and_profiler_dir(profiler_dir_abs) + device_id = unquote_args(request, "device_id") to_int(device_id, 'device_id') op_type = unquote_args(request, "type") @@ -233,7 +244,7 @@ def get_time_info(): "advise": {} } minddata_analyser = AnalyserFactory.instance().get_analyser( - 'minddata', profile_dir, device_id) + 'minddata', profiler_dir_abs, device_id) if op_type == "get_next": _, time_info = minddata_analyser.analyse_get_next_info(info_type="time") elif op_type == "device_queue": @@ -253,12 +264,14 @@ def get_process_summary(): Examples: >>> GET http://xxxx/v1/mindinsight/profile/process_summary """ - profile_dir = get_profiler_abs_dir(request) + profiler_dir_abs = get_profiler_abs_dir(request) + check_train_job_and_profiler_dir(profiler_dir_abs) + device_id = unquote_args(request, "device_id") to_int(device_id, 'device_id') minddata_analyser = AnalyserFactory.instance().get_analyser( - 'minddata', profile_dir, device_id) + 'minddata', profiler_dir_abs, device_id) get_next_queue_info, _ = minddata_analyser.analyse_get_next_info(info_type="queue") device_queue_info, _ = minddata_analyser.analyse_device_queue_info(info_type="queue") @@ -318,6 +331,8 @@ def get_profile_summary_proposal(): except ValidationError: raise ParamValueError("Invalid profiler dir") + check_train_job_and_profiler_dir(profiler_dir_abs) + step_trace_condition = {"filter_condition": {"mode": "proc", "proc_name": "iteration_interval", "step_id": 0}} @@ -359,6 +374,7 @@ def get_minddata_pipeline_op_queue_info(): except ValidationError: raise ParamValueError("Invalid profiler dir.") + check_train_job_and_profiler_dir(profiler_dir_abs) condition = request.stream.read() try: condition = json.loads(condition) if condition else {} @@ -404,6 +420,8 @@ def get_minddata_pipeline_queue_info(): except ValidationError: raise ParamValueError("Invalid profiler dir.") + check_train_job_and_profiler_dir(profiler_dir_abs) + device_id = request.args.get('device_id', default='0') to_int(device_id, 'device_id') op_id = request.args.get('op_id', type=int) @@ -429,9 +447,9 @@ def get_timeline_summary(): >>> GET http://xxxx/v1/mindinsight/profile/timeline-summary """ summary_dir = request.args.get("dir") - profiler_dir = validate_and_normalize_profiler_path(summary_dir, settings.SUMMARY_BASE_DIR) - if not os.path.exists(profiler_dir): - raise ProfilerDirNotFoundException(msg=summary_dir) + profiler_dir_abs = validate_and_normalize_profiler_path(summary_dir, settings.SUMMARY_BASE_DIR) + check_train_job_and_profiler_dir(profiler_dir_abs) + device_id = request.args.get("device_id", default='0') _ = to_int(device_id, 'device_id') device_type = request.args.get("device_type", default='ascend') @@ -440,7 +458,7 @@ def get_timeline_summary(): raise ParamValueError("Invalid device_type.") analyser = AnalyserFactory.instance().get_analyser( - 'timeline', profiler_dir, device_id) + 'timeline', profiler_dir_abs, device_id) summary = analyser.get_timeline_summary(device_type) return summary @@ -458,9 +476,9 @@ def get_timeline_detail(): >>> GET http://xxxx/v1/mindinsight/profile/timeline """ summary_dir = request.args.get("dir") - profiler_dir = validate_and_normalize_profiler_path(summary_dir, settings.SUMMARY_BASE_DIR) - if not os.path.exists(profiler_dir): - raise ProfilerDirNotFoundException(msg=summary_dir) + profiler_dir_abs = validate_and_normalize_profiler_path(summary_dir, settings.SUMMARY_BASE_DIR) + check_train_job_and_profiler_dir(profiler_dir_abs) + device_id = request.args.get("device_id", default='0') _ = to_int(device_id, 'device_id') device_type = request.args.get("device_type", default='ascend') @@ -469,7 +487,7 @@ def get_timeline_detail(): raise ParamValueError("Invalid device_type.") analyser = AnalyserFactory.instance().get_analyser( - 'timeline', profiler_dir, device_id) + 'timeline', profiler_dir_abs, device_id) timeline = analyser.get_display_timeline(device_type) return jsonify(timeline) diff --git a/mindinsight/profiler/common/exceptions/error_code.py b/mindinsight/profiler/common/exceptions/error_code.py index a592e192..3376d91b 100644 --- a/mindinsight/profiler/common/exceptions/error_code.py +++ b/mindinsight/profiler/common/exceptions/error_code.py @@ -52,8 +52,6 @@ class ProfilerErrors(ProfilerMgrErrors): PIPELINE_OP_NOT_EXIST_ERROR = 8 | _ANALYSER_MASK - - @unique class ProfilerErrorMsg(Enum): """Profiler error messages.""" diff --git a/mindinsight/profiler/common/util.py b/mindinsight/profiler/common/util.py index 4612280c..447bc06a 100644 --- a/mindinsight/profiler/common/util.py +++ b/mindinsight/profiler/common/util.py @@ -20,6 +20,8 @@ This module provides the utils. import os from mindinsight.datavisual.utils.tools import to_int +from mindinsight.profiler.common.exceptions.exceptions import ProfilerDirNotFoundException +from mindinsight.datavisual.common.exceptions import TrainJobNotExistError # one sys count takes 10 ns, 1 ms has 100000 system count PER_MS_SYSCNT = 100000 @@ -172,7 +174,18 @@ def get_field_value(row_info, field_name, header, time_type='realtime'): return value + def get_options(options): + """Get options.""" if options is None: options = {} return options + + +def check_train_job_and_profiler_dir(profiler_dir_abs): + """ check the existence of train_job and profiler dir """ + train_job_dir_abs = os.path.abspath(os.path.join(profiler_dir_abs, '..')) + if not os.path.exists(train_job_dir_abs): + raise TrainJobNotExistError(error_detail=train_job_dir_abs) + if not os.path.exists(profiler_dir_abs): + raise ProfilerDirNotFoundException(msg=profiler_dir_abs) diff --git a/tests/ut/backend/profiler/test_profiler_api_minddata_pipeline.py b/tests/ut/backend/profiler/test_profiler_api_minddata_pipeline.py index af720bfb..ba34119e 100644 --- a/tests/ut/backend/profiler/test_profiler_api_minddata_pipeline.py +++ b/tests/ut/backend/profiler/test_profiler_api_minddata_pipeline.py @@ -31,6 +31,7 @@ class TestMinddataPipelineApi: self._url_op_queue = '/v1/mindinsight/profile/minddata-pipeline/op-queue' self._url_queue = '/v1/mindinsight/profile/minddata-pipeline/queue' + @mock.patch('mindinsight.backend.profiler.profile_api.check_train_job_and_profiler_dir') @mock.patch('mindinsight.backend.profiler.profile_api.settings') @mock.patch('mindinsight.profiler.analyser.base_analyser.BaseAnalyser.query') def test_get_minddata_pipeline_op_queue_info_1(self, *args): @@ -83,6 +84,7 @@ class TestMinddataPipelineApi: assert response.status_code == 400 assert expect_result == response.get_json() + @mock.patch('mindinsight.backend.profiler.profile_api.check_train_job_and_profiler_dir') @mock.patch('mindinsight.backend.profiler.profile_api.settings') def test_get_minddata_pipeline_op_queue_info_4(self, *args): """Test the function of querying operator and queue information.""" @@ -98,6 +100,7 @@ class TestMinddataPipelineApi: assert response.status_code == 400 assert expect_result == response.get_json() + @mock.patch('mindinsight.backend.profiler.profile_api.check_train_job_and_profiler_dir') @mock.patch('mindinsight.backend.profiler.profile_api.settings') @mock.patch('mindinsight.profiler.analyser.minddata_pipeline_analyser.' 'MinddataPipelineAnalyser.get_op_and_parent_op_info') @@ -141,6 +144,7 @@ class TestMinddataPipelineApi: assert response.status_code == 400 assert expect_result == response.get_json() + @mock.patch('mindinsight.backend.profiler.profile_api.check_train_job_and_profiler_dir') @mock.patch('mindinsight.backend.profiler.profile_api.validate_and_normalize_path') @mock.patch('mindinsight.backend.profiler.profile_api.settings') def test_get_minddata_pipeline_queue_info_3(self, *args): @@ -158,6 +162,7 @@ class TestMinddataPipelineApi: assert response.status_code == 400 assert expect_result == response.get_json() + @mock.patch('mindinsight.backend.profiler.profile_api.check_train_job_and_profiler_dir') @mock.patch('mindinsight.backend.profiler.profile_api.settings') def test_get_minddata_pipeline_queue_info_4(self, *args): """Test the function of querying queue information.""" @@ -174,6 +179,7 @@ class TestMinddataPipelineApi: assert response.status_code == 400 assert expect_result == response.get_json() + @mock.patch('mindinsight.backend.profiler.profile_api.check_train_job_and_profiler_dir') @mock.patch('mindinsight.backend.profiler.profile_api.settings') def test_get_minddata_pipeline_queue_info_5(self, *args): """Test the function of querying queue information.""" diff --git a/tests/ut/backend/profiler/test_profiler_restful_api.py b/tests/ut/backend/profiler/test_profiler_restful_api.py index faa6c90a..022767fe 100644 --- a/tests/ut/backend/profiler/test_profiler_restful_api.py +++ b/tests/ut/backend/profiler/test_profiler_restful_api.py @@ -30,6 +30,7 @@ class TestProfilerRestfulApi(TestCase): self.app_client = APP.test_client() self.url = '/v1/mindinsight/profile/ops/search?train_id=run1&profile=profiler' + @mock.patch('mindinsight.backend.profiler.profile_api.check_train_job_and_profiler_dir') @mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings') @mock.patch('mindinsight.profiler.analyser.base_analyser.BaseAnalyser.query') def test_ops_search_success(self, *args):