You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimizer_api.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Optimizer API module."""
  16. import json
  17. import pandas as pd
  18. from flask import Blueprint, jsonify, request
  19. from mindinsight.conf import settings
  20. from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
  21. from mindinsight.lineagemgr.model import get_flattened_lineage, LineageTable
  22. from mindinsight.optimizer.common.enums import ReasonCode
  23. from mindinsight.optimizer.common.exceptions import SamplesNotEnoughError, CorrelationNanError
  24. from mindinsight.optimizer.utils.importances import calc_hyper_param_importance
  25. from mindinsight.optimizer.utils.utils import calc_histogram
  26. from mindinsight.utils.exceptions import ParamValueError
  27. BLUEPRINT = Blueprint("optimizer", __name__, url_prefix=settings.URL_PATH_PREFIX+settings.API_PREFIX)
  28. @BLUEPRINT.route("/optimizer/targets/search", methods=["POST"])
  29. def get_optimize_targets():
  30. """Get optimize targets."""
  31. search_condition = request.stream.read()
  32. try:
  33. search_condition = json.loads(search_condition if search_condition else "{}")
  34. except Exception:
  35. raise ParamValueError("Json data parse failed.")
  36. response = _get_optimize_targets(DATA_MANAGER, search_condition)
  37. return jsonify(response)
  38. def _get_optimize_targets(data_manager, search_condition=None):
  39. """Get optimize targets."""
  40. flatten_lineage = get_flattened_lineage(data_manager, search_condition)
  41. table = LineageTable(pd.DataFrame(flatten_lineage))
  42. target_summaries = []
  43. for target in table.target_names:
  44. hyper_parameters = []
  45. for hyper_param in table.hyper_param_names:
  46. param_info = {"name": hyper_param}
  47. try:
  48. importance = calc_hyper_param_importance(table.dataframe_data, hyper_param, target)
  49. param_info.update({"importance": importance})
  50. except SamplesNotEnoughError:
  51. param_info.update({"importance": 0})
  52. param_info.update({"reason_code": ReasonCode.SAMPLES_NOT_ENOUGH.value})
  53. except CorrelationNanError:
  54. param_info.update({"importance": 0})
  55. param_info.update({"reason_code": ReasonCode.CORRELATION_NAN.value})
  56. hyper_parameters.append(param_info)
  57. # Sort `hyper_parameters` in descending order of `importance` and ascending order of `name`.
  58. # If the automatically collected parameters and user-defined parameters have the same importance,
  59. # the user-defined parameters will be ranked behind.
  60. hyper_parameters.sort(key=lambda hyper_param: (-hyper_param.get("importance"),
  61. hyper_param.get("name").startswith('['),
  62. hyper_param.get("name")))
  63. target_summary = {
  64. "name": target,
  65. "buckets": calc_histogram(table.get_column(target)),
  66. "hyper_parameters": hyper_parameters,
  67. "data": table.get_column_values(target)
  68. }
  69. target_summaries.append(target_summary)
  70. target_summaries.sort(key=lambda summary: summary.get("name"))
  71. hyper_params_metadata = [{
  72. "name": hyper_param,
  73. "data": table.get_column_values(hyper_param)
  74. } for hyper_param in table.hyper_param_names]
  75. result = {
  76. "metadata": {
  77. "train_ids": table.train_ids,
  78. "possible_hyper_parameters": hyper_params_metadata,
  79. "unrecognized_params": table.drop_column_info
  80. },
  81. "targets": target_summaries
  82. }
  83. return result
  84. def init_module(app):
  85. """
  86. Init module entry.
  87. Args:
  88. app: the application obj.
  89. """
  90. app.register_blueprint(BLUEPRINT)