You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimizer_api.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Optimizer API module."""
  16. import json
  17. from flask import Blueprint, jsonify, request
  18. from mindinsight.conf import settings
  19. from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
  20. from mindinsight.lineagemgr.model import get_lineage_table
  21. from mindinsight.optimizer.common.enums import ReasonCode
  22. from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
  23. from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
  24. from mindinsight.optimizer.utils.utils import calc_histogram
  25. from mindinsight.utils.exceptions import ParamValueError
  # Blueprint for the optimizer endpoints; its URL prefix is the configured
  # URL path prefix followed by the API prefix.
  BLUEPRINT = Blueprint(
      "optimizer",
      __name__,
      url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX,
  )
  @BLUEPRINT.route("/optimizer/targets/search", methods=["POST"])
  def get_optimize_targets():
      """
      Get optimize targets.

      Reads the raw request body as a JSON search condition (an empty body is
      treated as "{}"), computes the optimize targets for it and returns them
      as a JSON response.

      Returns:
          The response built by `jsonify` from the computed targets.

      Raises:
          ParamValueError: If the request body cannot be parsed as JSON.
      """
      raw_condition = request.stream.read()
      try:
          search_condition = json.loads(raw_condition if raw_condition else "{}")
      except Exception as err:
          # Kept broad on purpose: any failure to decode the body is reported
          # as a parameter error. The original error is chained so the root
          # cause stays visible in tracebacks.
          raise ParamValueError("Json data parse failed.") from err

      response = _get_optimize_targets(DATA_MANAGER, search_condition)
      return jsonify(response)
  def _get_optimize_targets(data_manager, search_condition):
      """
      Build the optimize-targets result from the lineage table.

      Args:
          data_manager: Passed through to `get_lineage_table`.
          search_condition: Passed through to `get_lineage_table`.

      Returns:
          dict, with two entries:
              - "metadata": train ids, the column values of every hyper
                parameter, and the table's `drop_column_info`.
              - "targets": one summary per target (name, histogram buckets,
                hyper parameters ordered by descending importance, and the
                target's column values), ordered by target name.
      """
      lineage_table = get_lineage_table(data_manager, search_condition)

      summaries = []
      for target_name in lineage_table.target_names:
          param_infos = []
          for param_name in lineage_table.hyper_param_names:
              info = {"name": param_name}
              try:
                  info["importance"] = calc_hyper_param_importance(lineage_table.df, param_name, target_name)
              except SamplesNotEnoughError:
                  # Importance falls back to 0 and the reason is reported alongside it.
                  info["importance"] = 0
                  info["reason_code"] = ReasonCode.SAMPLES_NOT_ENOUGH.value
              except CorrelationNanError:
                  info["importance"] = 0
                  info["reason_code"] = ReasonCode.CORRELATION_NAN.value
              param_infos.append(info)

          # Most important hyper parameters first.
          ranked_params = sorted(param_infos, key=lambda item: item.get("importance"), reverse=True)

          summaries.append({
              "name": target_name,
              "buckets": calc_histogram(lineage_table.get_column(target_name)),
              "hyper_parameters": ranked_params,
              "data": lineage_table.get_column_values(target_name)
          })

      summaries = sorted(summaries, key=lambda item: item.get("name"))

      possible_hyper_parameters = []
      for param_name in lineage_table.hyper_param_names:
          possible_hyper_parameters.append({
              "name": param_name,
              "data": lineage_table.get_column_values(param_name)
          })

      metadata = {
          "train_ids": lineage_table.train_ids,
          "possible_hyper_parameters": possible_hyper_parameters,
          "unrecognized_params": lineage_table.drop_column_info
      }
      return {"metadata": metadata, "targets": summaries}
  def init_module(app):
      """
      Register the optimizer blueprint on the given application.

      Args:
          app: The application object; its `register_blueprint` method is
              called with the module-level `BLUEPRINT`.
      """
      app.register_blueprint(BLUEPRINT)