|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Utils for optimizer."""
- import string
-
- import numpy as np
-
_DEFAULT_HISTOGRAM_BINS = 5


def calc_histogram(np_value: np.ndarray, bins=_DEFAULT_HISTOGRAM_BINS):
    """
    Calculate a histogram over the finite values of an array.

    This is a simple wrapper around the error-prone np.histogram() to improve robustness:
    NaN and infinite entries are left out when the value range is determined, and a
    degenerate range (no finite value at all, or minimum not below maximum) is widened
    by 0.5 on each side before np.histogram() is called.

    Args:
        np_value (np.ndarray): Values to put into bins.
        bins: Bin specification, forwarded unchanged to np.histogram().
            Default: _DEFAULT_HISTOGRAM_BINS (5).

    Returns:
        list[list[float]], one `[left_edge, width, count]` entry per bin.
    """
    finite_values = np.ma.masked_invalid(np_value)

    if not finite_values.count():
        # Nothing finite to inspect: start from a zero-width range, widened below.
        range_left = 0
        range_right = 0
    elif np.issubdtype(np_value.dtype, np.floating):
        # Note that max of a masked array with dtype np.float16 returns inf (numpy issue#15077),
        # hence the explicit fill values. -np.inf / np.inf are spelled out because the
        # np.NINF / np.PINF aliases were removed in NumPy 2.0.
        range_right = finite_values.max(fill_value=-np.inf)
        range_left = finite_values.min(fill_value=np.inf)
    else:
        range_right = finite_values.max()
        range_left = finite_values.min()

    default_half_range = 0.5
    if range_left >= range_right:
        # Give np.histogram() a range with non-zero width.
        range_left = range_left - default_half_range
        range_right = range_right + default_half_range

    with np.errstate(invalid='ignore'):
        # Without ignoring the 'invalid' state, a np.nan in the input triggers
        # "RuntimeWarning: invalid value encountered in less_equal".
        counts, edges = np.histogram(np_value, bins=bins, range=(range_left, range_right))

    return [
        [float(edges[ind]), float(edges[ind + 1] - edges[ind]), float(count)]
        for ind, count in enumerate(counts)
    ]
-
-
def is_simple_numpy_number(dtype):
    """
    Tell whether a dtype is a plain integer or floating-point numpy type.

    Args:
        dtype: The dtype (or numpy scalar type) to inspect.

    Returns:
        bool, True if `dtype` is a sub-dtype of np.integer or np.floating.
    """
    simple_kinds = (np.integer, np.floating)
    return any(np.issubdtype(dtype, kind) for kind in simple_kinds)
-
-
def get_nested_message(info: dict, out_err_msg=""):
    """
    Get error message from the error dict generated by schema validation.

    The dict is walked downwards, following only the first key found at each level.
    Every string key other than '_schema' is appended to a dotted path. The walk
    stops at the first value that is not a dict; if that value is a non-empty list,
    its first element is reported.

    Args:
        info (dict): Nested error information. A non-dict value is reported directly.
        out_err_msg (str): Dotted path accumulated so far. Default: "".

    Returns:
        str, message of the form 'Error in <path>: <detail>', or None when an
        empty dict is reached.
    """
    while isinstance(info, dict):
        if not info:
            # No key to follow, so there is nothing to report.
            return None
        key = next(iter(info))
        if isinstance(key, str) and key != '_schema':
            out_err_msg = f'{out_err_msg}.{key}' if out_err_msg else key
        info = info[key]

    # Guard against an empty list: indexing it would raise IndexError.
    if isinstance(info, list) and info:
        info = info[0]
    return f'Error in {out_err_msg}: {info}'
-
-
def is_number(uchar):
    """
    Check whether the given character is a single ASCII digit.

    Args:
        uchar (str): The character to check.

    Returns:
        bool, True if `uchar` is exactly one character out of string.digits.
    """
    # `uchar in string.digits` alone is a substring test, which also accepts ''
    # and runs such as '12'; the length check limits it to one character.
    return len(uchar) == 1 and uchar in string.digits
-
-
def is_alphabet(uchar):
    """
    Check whether the given character is a single ASCII letter.

    Args:
        uchar (str): The character to check.

    Returns:
        bool, True if `uchar` is exactly one character out of string.ascii_letters.
    """
    # `uchar in string.ascii_letters` alone is a substring test, which also accepts ''
    # and runs such as 'ab'; the length check limits it to one character.
    return len(uchar) == 1 and uchar in string.ascii_letters
-
-
def is_allowed_symbols(uchar):
    """
    Check whether the given character is one of the allowed symbols.

    The only allowed symbol is the underscore.

    Args:
        uchar (str): The character to check.

    Returns:
        bool, True if `uchar` is an allowed symbol.
    """
    allowed_symbols = ('_',)
    return uchar in allowed_symbols
-
-
def is_param_name_valid(param_name: str):
    """
    Check that a parameter name is made up of permitted characters only.

    A character is permitted when is_number(), is_alphabet() or
    is_allowed_symbols() accepts it.

    Args:
        param_name (str): The parameter name to check.

    Returns:
        bool, True if no character is rejected (an empty name yields True).
    """
    return all(
        is_number(char) or is_alphabet(char) or is_allowed_symbols(char)
        for char in param_name
    )
|