You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_hyper_config.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test tuner and hyper config."""
  16. import json
  17. import os
  18. import shutil
  19. import pytest
  20. from mindinsight.optimizer.tuner import Tuner
  21. from mindinsight.optimizer.hyper_config import HyperConfig
  22. from tests.utils.lineage_writer import LineageWriter
  23. from tests.utils.lineage_writer.base import Metadata
  24. from tests.utils.tools import convert_dict_to_yaml
  25. from tests.st.func.optimizer.conftest import SUMMARY_BASE_DIR
  26. def _create_summaries(summary_base_dir):
  27. """Create summaries."""
  28. learning_rate = [0.01, 0.001, 0.02, 0.04, 0.05]
  29. acc = [0.8, 0.9, 0.8, 0.7, 0.6]
  30. momentum = [0.8, 0.9, 0.8, 0.9, 0.7]
  31. train_ids = []
  32. train_id_prefix = 'train_'
  33. params = {}
  34. for i, lr in enumerate(learning_rate):
  35. train_id = f'./{train_id_prefix}{i + 1}'
  36. train_ids.append(train_id)
  37. params.update({
  38. train_id: {
  39. 'train': {
  40. Metadata.learning_rate: lr
  41. },
  42. 'eval': {
  43. Metadata.metrics: json.dumps({'acc': acc[i]}),
  44. 'user_defined_info': {'momentum': momentum[i]}
  45. }
  46. }
  47. })
  48. lineage_writer = LineageWriter(summary_base_dir)
  49. lineage_writer.create_summaries(train_id_prefix=train_id_prefix, train_job_num=5, params=params)
  50. return train_ids
  51. def _prepare_script_and_yaml(output_dir, script_name='test.py', yaml_name='config.yaml'):
  52. """Prepare script and yaml file."""
  53. script_path = os.path.join(output_dir, script_name)
  54. with open(script_path, 'w'):
  55. pass
  56. config_dict = {
  57. 'command': 'python %s' % script_path,
  58. 'summary_base_dir': SUMMARY_BASE_DIR,
  59. 'tuner': {
  60. 'name': 'gp',
  61. },
  62. 'target': {
  63. 'group': 'metric',
  64. 'name': 'acc',
  65. 'goal': 'maximize'
  66. },
  67. 'parameters': {
  68. 'learning_rate': {
  69. 'bounds': [0.0001, 0.01],
  70. 'type': 'float'
  71. },
  72. 'momentum': {
  73. 'choice': [0.8, 0.9]
  74. }
  75. }
  76. }
  77. convert_dict_to_yaml(config_dict, output_dir, yaml_name)
  78. return script_path, os.path.join(output_dir, yaml_name)
  79. class TestHyperConfig:
  80. """Test HyperConfig."""
  81. def setup_class(self):
  82. """Setup class."""
  83. self._generated_file_path = []
  84. self._train_ids = _create_summaries(SUMMARY_BASE_DIR)
  85. script_path, self._yaml_path = _prepare_script_and_yaml(SUMMARY_BASE_DIR)
  86. self._generated_file_path.append(script_path)
  87. self._generated_file_path.append(self._yaml_path)
  88. def teardown_class(self):
  89. """Delete the files generated before."""
  90. for train_id in self._train_ids:
  91. summary_dir = os.path.join(SUMMARY_BASE_DIR, train_id)
  92. if os.path.exists(summary_dir):
  93. shutil.rmtree(summary_dir)
  94. for file in self._generated_file_path:
  95. if os.path.exists(file):
  96. os.remove(file)
  97. @pytest.mark.level0
  98. @pytest.mark.env_single
  99. @pytest.mark.platform_x86_cpu
  100. @pytest.mark.platform_arm_ascend_training
  101. @pytest.mark.platform_x86_gpu_training
  102. @pytest.mark.platform_x86_ascend_training
  103. @pytest.mark.usefixtures("init_summary_logs")
  104. def test_tuner_success(self):
  105. """Test tuner successfully."""
  106. tuner = Tuner(self._yaml_path)
  107. tuner.optimize()
  108. hyper_config = HyperConfig()
  109. params = hyper_config.params
  110. assert list(params.keys()) == ['learning_rate', 'momentum']
  111. assert 0.0001 <= params.learning_rate < 0.01
  112. assert params.momentum in [0.8, 0.9]
  113. assert list(hyper_config.custom_lineage_data.keys()) == ['momentum']