| @@ -0,0 +1,14 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| @@ -0,0 +1,55 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Function: | |||
| Test histograms restful api. | |||
| Usage: | |||
| pytest tests/st/func/datavisual | |||
| """ | |||
| import pytest | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from .....utils.tools import get_url | |||
| from .. import globals as gbl | |||
| BASE_URL = '/v1/mindinsight/datavisual/histograms' | |||
| class TestHistograms: | |||
| """Test Histograms.""" | |||
| @pytest.mark.level0 | |||
| @pytest.mark.env_single | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_ascend_training | |||
| @pytest.mark.usefixtures("init_summary_logs") | |||
| def test_histograms(self, client): | |||
| """Test getting histogram data.""" | |||
| plugin_name = PluginNameEnum.HISTOGRAM.value | |||
| train_id = gbl.get_train_ids()[0] | |||
| tag_name = gbl.get_tags(train_id, plugin_name)[0] | |||
| expected_histograms = gbl.get_metadata(train_id, tag_name) | |||
| params = dict(train_id=train_id, tag=tag_name) | |||
| url = get_url(BASE_URL, params) | |||
| response = client.get(url) | |||
| histograms = response.get_json().get("histograms") | |||
| for histograms, expected_histograms in zip(histograms, expected_histograms): | |||
| assert histograms.get("wall_time") == expected_histograms.get("wall_time") | |||
| assert histograms.get("step") == expected_histograms.get("step") | |||
| @@ -54,5 +54,6 @@ TRAIN_ROUTES = dict( | |||
| graph_single_node='/v1/mindinsight/datavisual/graphs/single-node', | |||
| image_metadata='/v1/mindinsight/datavisual/image/metadata', | |||
| image_single_image='/v1/mindinsight/datavisual/image/single-image', | |||
| scalar_metadata='/v1/mindinsight/datavisual/scalar/metadata' | |||
| scalar_metadata='/v1/mindinsight/datavisual/scalar/metadata', | |||
| histograms='/v1/mindinsight/datavisual/histograms' | |||
| ) | |||
| @@ -26,6 +26,7 @@ from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||
| from mindinsight.datavisual.processors.graph_processor import GraphProcessor | |||
| from mindinsight.datavisual.processors.images_processor import ImageProcessor | |||
| from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor | |||
| from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor | |||
| from ....utils.tools import get_url | |||
| from .conftest import TRAIN_ROUTES | |||
| @@ -432,3 +433,42 @@ class TestTrainVisual: | |||
| assert response.status_code == 200 | |||
| results = response.get_json() | |||
| assert results == test_name | |||
| def test_histograms_with_params_miss(self, client): | |||
| """Parsing missing params to get histogram data.""" | |||
| params = dict() | |||
| url = get_url(TRAIN_ROUTES['histograms'], params) | |||
| response = client.get(url) | |||
| results = response.get_json() | |||
| assert response.status_code == 400 | |||
| assert results['error_code'] == '50540003' | |||
| assert results['error_msg'] == "Param missing. 'train_id' is required." | |||
| train_id = "aa" | |||
| params = dict(train_id=train_id) | |||
| url = get_url(TRAIN_ROUTES['histograms'], params) | |||
| response = client.get(url) | |||
| results = response.get_json() | |||
| assert response.status_code == 400 | |||
| assert results['error_code'] == '50540003' | |||
| assert results['error_msg'] == "Param missing. 'tag' is required." | |||
| @patch.object(HistogramProcessor, 'get_histograms') | |||
| def test_histograms_success(self, mock_histogram_processor, client): | |||
| """Parsing available params to get histogram data.""" | |||
| test_train_id = "aa" | |||
| test_tag = "bb" | |||
| expect_resp = { | |||
| 'histograms': [{'buckets': [[1, 2, 3]]}], | |||
| 'train_id': test_train_id, | |||
| 'tag': test_tag | |||
| } | |||
| get_histograms = Mock(return_value=expect_resp) | |||
| mock_histogram_processor.side_effect = get_histograms | |||
| params = dict(train_id=test_train_id, tag=test_tag) | |||
| url = get_url(TRAIN_ROUTES['histograms'], params) | |||
| response = client.get(url) | |||
| assert response.status_code == 200 | |||
| results = response.get_json() | |||
| assert results == expect_resp | |||
| @@ -0,0 +1,117 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Function: | |||
| Test histogram processor. | |||
| Usage: | |||
| pytest tests/ut/datavisual | |||
| """ | |||
| import tempfile | |||
| from unittest.mock import Mock | |||
| import pytest | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.datavisual.common.exceptions import HistogramNotExistError | |||
| from mindinsight.datavisual.data_transform import data_manager | |||
| from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator | |||
| from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor | |||
| from mindinsight.datavisual.utils import crc32 | |||
| from ....utils.log_operations import LogOperations | |||
| from ....utils.tools import check_loading_done, delete_files_or_dirs | |||
| from ..mock import MockLogger | |||
| class TestHistogramProcessor: | |||
| """Test histogram processor api.""" | |||
| _steps_list = [1, 3, 5] | |||
| _tag_name = 'tag_name' | |||
| _plugin_name = 'histogram' | |||
| _complete_tag_name = f'{_tag_name}/{_plugin_name}' | |||
| _temp_path = None | |||
| _histograms = None | |||
| _mock_data_manager = None | |||
| _train_id = None | |||
| _generated_path = [] | |||
| @classmethod | |||
| def setup_class(cls): | |||
| """Mock common environment for histograms unittest.""" | |||
| crc32.CheckValueAgainstData = Mock(return_value=True) | |||
| data_manager.logger = MockLogger | |||
| def teardown_class(self): | |||
| """Delete temp files.""" | |||
| delete_files_or_dirs(self._generated_path) | |||
| @pytest.fixture(scope='function') | |||
| def load_histogram_record(self): | |||
| """Load histogram record.""" | |||
| summary_base_dir = tempfile.mkdtemp() | |||
| log_dir = tempfile.mkdtemp(dir=summary_base_dir) | |||
| self._train_id = log_dir.replace(summary_base_dir, ".") | |||
| log_operation = LogOperations() | |||
| self._temp_path, self._histograms, _ = log_operation.generate_log( | |||
| PluginNameEnum.HISTOGRAM.value, log_dir, dict(step=self._steps_list, tag=self._tag_name)) | |||
| self._generated_path.append(summary_base_dir) | |||
| self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)]) | |||
| self._mock_data_manager.start_load_data(reload_interval=0) | |||
| # wait for loading done | |||
| check_loading_done(self._mock_data_manager, time_limit=5) | |||
| @pytest.mark.usefixtures('load_histogram_record') | |||
| def test_get_histograms_with_not_exist_id(self): | |||
| """Get histogram data with not exist id.""" | |||
| test_train_id = 'not_exist_id' | |||
| processor = HistogramProcessor(self._mock_data_manager) | |||
| with pytest.raises(TrainJobNotExistError) as exc_info: | |||
| processor.get_histograms(test_train_id, self._tag_name) | |||
| assert exc_info.value.error_code == '50545005' | |||
| assert exc_info.value.message == "Train job is not exist. Detail: Can not find the given train job in cache." | |||
| @pytest.mark.usefixtures('load_histogram_record') | |||
| def test_get_histograms_with_not_exist_tag(self): | |||
| """Get histogram data with not exist tag.""" | |||
| test_tag_name = 'not_exist_tag_name' | |||
| processor = HistogramProcessor(self._mock_data_manager) | |||
| with pytest.raises(HistogramNotExistError) as exc_info: | |||
| processor.get_histograms(self._train_id, test_tag_name) | |||
| assert exc_info.value.error_code == '5054500F' | |||
| assert "Can not find any data in this train job by given tag." in exc_info.value.message | |||
| @pytest.mark.usefixtures('load_histogram_record') | |||
| def test_get_histograms_success(self): | |||
| """Get histogram data success.""" | |||
| test_tag_name = self._complete_tag_name | |||
| processor = HistogramProcessor(self._mock_data_manager) | |||
| results = processor.get_histograms(self._train_id, test_tag_name) | |||
| recv_metadata = results.get('histograms') | |||
| for recv_values, expected_values in zip(recv_metadata, self._histograms): | |||
| assert recv_values.get('wall_time') == expected_values.get('wall_time') | |||
| assert recv_values.get('step') == expected_values.get('step') | |||