You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_parameter.py 9.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define schema of model lineage input parameters."""
  16. from marshmallow import Schema, fields, ValidationError, pre_load, validates
  17. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, LineageErrors
  18. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageParamTypeError, LineageParamValueError
  19. from mindinsight.lineagemgr.common.utils import enum_to_list
  20. from mindinsight.lineagemgr.querier.querier import LineageType
  21. from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
  22. from mindinsight.utils.exceptions import MindInsightException
  23. class SearchModelConditionParameter(Schema):
  24. """Define the search model condition parameter schema."""
  25. summary_dir = fields.Dict()
  26. loss_function = fields.Dict()
  27. train_dataset_path = fields.Dict()
  28. train_dataset_count = fields.Dict()
  29. test_dataset_path = fields.Dict()
  30. test_dataset_count = fields.Dict()
  31. network = fields.Dict()
  32. optimizer = fields.Dict()
  33. learning_rate = fields.Dict()
  34. epoch = fields.Dict()
  35. batch_size = fields.Dict()
  36. device_num = fields.Dict()
  37. loss = fields.Dict()
  38. model_size = fields.Dict()
  39. limit = fields.Int(validate=lambda n: 0 < n <= 100)
  40. offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
  41. sorted_name = fields.Str()
  42. sorted_type = fields.Str(allow_none=True)
  43. dataset_mark = fields.Dict()
  44. lineage_type = fields.Dict()
  45. @staticmethod
  46. def check_dict_value_type(data, value_type):
  47. """Check dict value type and int scope."""
  48. for key, value in data.items():
  49. if key in ["in", "not_in"]:
  50. if not isinstance(value, (list, tuple)):
  51. raise ValidationError("The value of `in` operation must be list or tuple.")
  52. else:
  53. if not isinstance(value, value_type):
  54. raise ValidationError("Wrong value type.")
  55. if value_type is int:
  56. if value < 0 or value > pow(2, 63) - 1:
  57. raise ValidationError("Int value should <= pow(2, 63) - 1.")
  58. if isinstance(value, bool):
  59. raise ValidationError("Wrong value type.")
  60. @staticmethod
  61. def check_param_value_type(data):
  62. """Check input param's value type."""
  63. for key, value in data.items():
  64. if key == "in":
  65. if not isinstance(value, (list, tuple)):
  66. raise ValidationError("The value of `in` operation must be list or tuple.")
  67. else:
  68. if isinstance(value, bool) or \
  69. (not isinstance(value, float) and not isinstance(value, int)):
  70. raise ValidationError("Wrong value type.")
  71. @staticmethod
  72. def check_operation(data):
  73. """Check input param's compare operation."""
  74. if not set(data.keys()).issubset(['in', 'eq', 'not_in']):
  75. raise ValidationError("Its operation should be `eq`, `in` or `not_in`.")
  76. @validates("loss")
  77. def check_loss(self, data):
  78. """Check loss."""
  79. SearchModelConditionParameter.check_param_value_type(data)
  80. @validates("learning_rate")
  81. def check_learning_rate(self, data):
  82. """Check learning_rate."""
  83. SearchModelConditionParameter.check_param_value_type(data)
  84. @validates("loss_function")
  85. def check_loss_function(self, data):
  86. """Check loss function."""
  87. SearchModelConditionParameter.check_operation(data)
  88. SearchModelConditionParameter.check_dict_value_type(data, str)
  89. @validates("train_dataset_path")
  90. def check_train_dataset_path(self, data):
  91. """Check train dataset path."""
  92. SearchModelConditionParameter.check_operation(data)
  93. SearchModelConditionParameter.check_dict_value_type(data, str)
  94. @validates("train_dataset_count")
  95. def check_train_dataset_count(self, data):
  96. """Check train dataset count."""
  97. SearchModelConditionParameter.check_dict_value_type(data, int)
  98. @validates("test_dataset_path")
  99. def check_test_dataset_path(self, data):
  100. """Check test dataset path."""
  101. SearchModelConditionParameter.check_operation(data)
  102. SearchModelConditionParameter.check_dict_value_type(data, str)
  103. @validates("test_dataset_count")
  104. def check_test_dataset_count(self, data):
  105. """Check test dataset count."""
  106. SearchModelConditionParameter.check_dict_value_type(data, int)
  107. @validates("network")
  108. def check_network(self, data):
  109. """Check network."""
  110. SearchModelConditionParameter.check_operation(data)
  111. SearchModelConditionParameter.check_dict_value_type(data, str)
  112. @validates("optimizer")
  113. def check_optimizer(self, data):
  114. """Check optimizer."""
  115. SearchModelConditionParameter.check_operation(data)
  116. SearchModelConditionParameter.check_dict_value_type(data, str)
  117. @validates("epoch")
  118. def check_epoch(self, data):
  119. """Check epoch."""
  120. SearchModelConditionParameter.check_dict_value_type(data, int)
  121. @validates("batch_size")
  122. def check_batch_size(self, data):
  123. """Check batch size."""
  124. SearchModelConditionParameter.check_dict_value_type(data, int)
  125. @validates("device_num")
  126. def check_device_num(self, data):
  127. """Check device num."""
  128. SearchModelConditionParameter.check_dict_value_type(data, int)
  129. @validates("model_size")
  130. def check_model_size(self, data):
  131. """Check model size."""
  132. SearchModelConditionParameter.check_dict_value_type(data, int)
  133. @validates("summary_dir")
  134. def check_summary_dir(self, data):
  135. """Check summary dir."""
  136. SearchModelConditionParameter.check_operation(data)
  137. SearchModelConditionParameter.check_dict_value_type(data, str)
  138. @validates("dataset_mark")
  139. def check_dataset_mark(self, data):
  140. """Check dataset mark."""
  141. SearchModelConditionParameter.check_operation(data)
  142. SearchModelConditionParameter.check_dict_value_type(data, str)
  143. @validates("lineage_type")
  144. def check_lineage_type(self, data):
  145. """Check lineage type."""
  146. SearchModelConditionParameter.check_operation(data)
  147. SearchModelConditionParameter.check_dict_value_type(data, str)
  148. recv_types = []
  149. for key, value in data.items():
  150. if key == "in":
  151. recv_types = value
  152. else:
  153. recv_types.append(value)
  154. lineage_types = enum_to_list(LineageType)
  155. if not set(recv_types).issubset(lineage_types):
  156. raise ValidationError("Given lineage type should be one of %s." % lineage_types)
  157. @pre_load
  158. def check_comparison(self, data, **kwargs):
  159. """Check comparison for all parameters in schema."""
  160. for attr, condition in data.items():
  161. if attr in ["limit", "offset", "sorted_name", "sorted_type", 'lineage_type']:
  162. continue
  163. if not isinstance(attr, str):
  164. raise LineageParamValueError('The search attribute not supported.')
  165. if attr not in FIELD_MAPPING and not attr.startswith(('metric/', 'user_defined/')):
  166. raise LineageParamValueError('The search attribute not supported.')
  167. if not isinstance(condition, dict):
  168. raise LineageParamTypeError("The search_condition element {} should be dict."
  169. .format(attr))
  170. for key in condition.keys():
  171. if key not in ["eq", "lt", "gt", "le", "ge", "in", "not_in"]:
  172. raise LineageParamValueError("The compare condition should be in "
  173. "('eq', 'lt', 'gt', 'le', 'ge', 'in', 'not_in').")
  174. if attr.startswith('metric/'):
  175. if len(attr) == 7:
  176. raise LineageParamValueError(
  177. 'The search attribute not supported.'
  178. )
  179. try:
  180. SearchModelConditionParameter.check_param_value_type(condition)
  181. except ValidationError:
  182. raise MindInsightException(
  183. error=LineageErrors.LINEAGE_PARAM_METRIC_ERROR,
  184. message=LineageErrorMsg.LINEAGE_METRIC_ERROR.value.format(attr)
  185. )
  186. return data