You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hyper_config.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Hyper config."""
  16. import json
  17. import os
  18. from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME, HYPER_CONFIG_LEN_LIMIT
  19. from mindinsight.optimizer.common.exceptions import HyperConfigEnvError, HyperConfigError
  class _MissingKeyError(AttributeError, KeyError):
      """
      Raised by AttributeDict when a requested key is missing.

      It is an ``AttributeError`` so ``getattr``/``hasattr`` and existing callers
      keep working, and a ``KeyError`` so mapping-protocol code (for example
      ``collections.ChainMap`` or ``str.format_map``) recognises a missing key.
      """

      # KeyError.__str__ would wrap the message in quotes; keep the plain text.
      __str__ = BaseException.__str__


  class AttributeDict(dict):
      """
      A dict whose items can also be read and written as attributes.

      A key of ``None`` is stored under the empty string ``""``.

      Examples:
          >>> d = AttributeDict({'learning_rate': 0.1})
          >>> d.learning_rate
          0.1
          >>> d.batch_size = 32
          >>> d['batch_size']
          32
      """

      def __init__(self, d=None):
          """
          Build the dict from an optional mapping.

          Args:
              d (Mapping, optional): Initial items. Each item goes through
                  ``__setitem__`` so the ``None`` -> ``""`` key mapping applies.
          """
          super().__init__()
          if d is not None:
              for k, v in d.items():
                  self[k] = v

      def __key(self, key):
          """Map a ``None`` key to ``""``; return any other key unchanged."""
          return "" if key is None else key

      def __setattr__(self, key, value):
          """Store the attribute as a dict item instead of an instance attribute."""
          self[self.__key(key)] = value

      def __getattr__(self, key):
          """
          Get attribute value according by attribute name.

          Python only calls this when normal attribute lookup fails, so methods
          and class attributes take precedence over items with the same name.

          Args:
              key (str): attribute name.

          Returns:
              Any, attribute value. A stored ``None`` is returned as ``None``.

          Raises:
              AttributeError: If the key does not exist (the exception is also
                  a ``KeyError``).
          """
          return self[key]

      def __getitem__(self, key):
          """
          Get item value by key, mapping a ``None`` key to ``""``.

          Raises:
              KeyError: If the key does not exist. The exception is also an
                  ``AttributeError``, which is what this method used to raise.
          """
          try:
              return super().__getitem__(self.__key(key))
          except KeyError:
              # `from None`: the bare dict KeyError adds nothing to the message.
              raise _MissingKeyError("The attribute %r is not exist." % key) from None

      def __setitem__(self, key, value):
          """Set item value, mapping a ``None`` key to ``""``."""
          return super().__setitem__(self.__key(key), value)
  class HyperConfig:
      """
      Hyper config.

      1. Init HyperConfig.
      2. Get suggested params and summary_dir.
      3. Record by SummaryCollector with summary_dir.

      The config is read from the environment variable named by
      ``HYPER_CONFIG_ENV_NAME`` and must be a JSON object containing at least
      ``summary_dir`` (str) and ``params`` (dict of int/float values).

      Examples:
          >>> hyper_config = HyperConfig()
          >>> params = hyper_config.params
          >>> learning_rate = params.learning_rate
          >>> batch_size = params.batch_size
          >>> summary_dir = hyper_config.summary_dir
          >>> summary_cb = SummaryCollector(summary_dir)
      """

      def __init__(self):
          """
          Load the hyper config from the environment and validate it.

          Raises:
              HyperConfigEnvError: If the environment variable is missing or too long.
              HyperConfigError: If the value cannot be decoded or fails validation.
          """
          self._init_validate_hyper_config()

      def _init_validate_hyper_config(self):
          """Init and validate hyper config, storing the result on ``self._hyper_config``."""
          hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
          if hyper_config is None:
              raise HyperConfigEnvError("Hyper config is not in system environment.")
          # Length is checked before decoding so an oversized value is never parsed.
          if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT:
              raise HyperConfigEnvError("Hyper config is too long. The length limit is %s, the length of "
                                        "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
          try:
              hyper_config = json.loads(hyper_config)
          except TypeError as exc:
              raise HyperConfigError("Hyper config type error. detail: %s." % str(exc)) from exc
          except Exception as exc:
              # Deliberately broad: any decode failure is reported as a HyperConfigError.
              raise HyperConfigError("Hyper config decode error. detail: %s." % str(exc)) from exc
          self._validate_hyper_config(hyper_config)
          self._hyper_config = hyper_config

      def _validate_hyper_config(self, hyper_config):
          """
          Validate hyper config in place.

          Args:
              hyper_config: The decoded JSON value. On success its 'summary_dir'
                  entry is replaced with ``os.path.realpath`` of itself.

          Raises:
              HyperConfigError: If the value is not a dict, a required key is
                  missing, or a value has the wrong type.
          """
          # json.loads can return any JSON value. Without this check a number
          # breaks the membership test below with TypeError, and a string passes
          # it by substring match and then fails on `.get`.
          if not isinstance(hyper_config, dict):
              raise HyperConfigError("Hyper config should be a JSON object.")
          for key in ['summary_dir', 'params']:
              if key not in hyper_config:
                  raise HyperConfigError("%r must exist in hyper_config." % key)

          # validate summary_dir
          summary_dir = hyper_config.get('summary_dir')
          if not isinstance(summary_dir, str):
              raise HyperConfigError("The 'summary_dir' should be string.")
          hyper_config['summary_dir'] = os.path.realpath(summary_dir)

          # validate params
          params = hyper_config.get('params')
          if not isinstance(params, dict):
              raise HyperConfigError("'params' is not a dict.")
          for key, value in params.items():
              # NOTE(review): bool is a subclass of int, so JSON true/false also
              # pass this check - confirm that is intended.
              if not isinstance(value, (int, float)):
                  raise HyperConfigError("The value of %r is not integer or float." % key)

      @property
      def params(self):
          """Get params as a new AttributeDict built from the validated 'params' dict on each access."""
          return AttributeDict(self._hyper_config.get('params'))

      @property
      def summary_dir(self):
          """Get train summary dir path (already resolved with ``os.path.realpath``)."""
          return self._hyper_config.get('summary_dir')

      @property
      def custom_lineage_data(self):
          """Get the 'custom_lineage_data' entry of the config, or None if it is absent."""
          return self._hyper_config.get('custom_lineage_data')