|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Hyper config."""
- import json
- import os
-
- from mindinsight.optimizer.common.constants import HYPER_CONFIG_ENV_NAME, HYPER_CONFIG_LEN_LIMIT
- from mindinsight.optimizer.common.exceptions import HyperConfigEnvError, HyperConfigError
-
-
- class AttributeDict(dict):
- """A dict can be accessed by attribute."""
- def __init__(self, d=None):
- super().__init__()
- if d is not None:
- for k, v in d.items():
- self[k] = v
-
- def __key(self, key):
- """Get key."""
- return "" if key is None else key
-
- def __setattr__(self, key, value):
- """Set attribute for object."""
- self[self.__key(key)] = value
-
- def __getattr__(self, key):
- """
- Get attribute value according by attribute name.
-
- Args:
- key (str): attribute name.
-
- Returns:
- Any, attribute value.
-
- Raises:
- AttributeError: If the key does not exists, will raise Exception.
-
- """
- value = self.get(self.__key(key))
- if value is None:
- raise AttributeError("The attribute %r is not exist." % key)
- return value
-
- def __getitem__(self, key):
- """Get attribute value according by attribute name."""
- value = super().get(self.__key(key))
- if value is None:
- raise AttributeError("The attribute %r is not exist." % key)
- return value
-
- def __setitem__(self, key, value):
- """Set attribute for object."""
- return super().__setitem__(self.__key(key), value)
-
-
- class HyperConfig:
- """
- Hyper config.
- 1. Init HyperConfig.
- 2. Get suggested params and summary_dir.
- 3. Record by SummaryCollector with summary_dir.
-
- Examples:
- >>> hyper_config = HyperConfig()
- >>> params = hyper_config.params
- >>> learning_rate = params.learning_rate
- >>> batch_size = params.batch_size
-
- >>> summary_dir = hyper_config.summary_dir
- >>> summary_cb = SummaryCollector(summary_dir)
- """
- def __init__(self):
- self._init_validate_hyper_config()
-
- def _init_validate_hyper_config(self):
- """Init and validate hyper config."""
- hyper_config = os.environ.get(HYPER_CONFIG_ENV_NAME)
- if hyper_config is None:
- raise HyperConfigEnvError("Hyper config is not in system environment.")
- if len(hyper_config) > HYPER_CONFIG_LEN_LIMIT:
- raise HyperConfigEnvError("Hyper config is too long. The length limit is %s, the length of "
- "hyper_config is %s." % (HYPER_CONFIG_LEN_LIMIT, len(hyper_config)))
-
- try:
- hyper_config = json.loads(hyper_config)
- except TypeError as exc:
- raise HyperConfigError("Hyper config type error. detail: %s." % str(exc))
- except Exception as exc:
- raise HyperConfigError("Hyper config decode error. detail: %s." % str(exc))
-
- self._validate_hyper_config(hyper_config)
- self._hyper_config = hyper_config
-
- def _validate_hyper_config(self, hyper_config):
- """Validate hyper config."""
- for key in ['summary_dir', 'params']:
- if key not in hyper_config:
- raise HyperConfigError("%r must exist in hyper_config." % key)
-
- # validate summary_dir
- summary_dir = hyper_config.get('summary_dir')
- if not isinstance(summary_dir, str):
- raise HyperConfigError("The 'summary_dir' should be string.")
- hyper_config['summary_dir'] = os.path.realpath(summary_dir)
-
- # validate params
- params = hyper_config.get('params')
- if not isinstance(params, dict):
- raise HyperConfigError("'params' is not a dict.")
- for key, value in params.items():
- if not isinstance(value, (int, float)):
- raise HyperConfigError("The value of %r is not integer or float." % key)
-
- @property
- def params(self):
- """Get params."""
- return AttributeDict(self._hyper_config.get('params'))
-
- @property
- def summary_dir(self):
- """Get train summary dir path."""
- return self._hyper_config.get('summary_dir')
-
- @property
- def custom_lineage_data(self):
- return self._hyper_config.get('custom_lineage_data')
|