You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_parameter.py 10 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define schema of model lineage input parameters."""
  16. from marshmallow import Schema, fields, ValidationError, pre_load, validates
  17. from marshmallow.validate import Range, OneOf
  18. from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \
  19. LineageErrors
  20. from mindinsight.lineagemgr.common.exceptions.exceptions import \
  21. LineageParamTypeError, LineageParamValueError
  22. from mindinsight.lineagemgr.common.log import logger
  23. from mindinsight.lineagemgr.common.utils import enum_to_list
  24. from mindinsight.lineagemgr.querier.querier import LineageType
  25. from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
  26. from mindinsight.utils.exceptions import MindInsightException
  27. try:
  28. from mindspore.dataset.engine import Dataset
  29. from mindspore.nn import Cell, Optimizer
  30. from mindspore.common.tensor import Tensor
  31. from mindspore.train.callback import _ListCallback
  32. except (ImportError, ModuleNotFoundError):
  33. logger.error('MindSpore Not Found!')
  34. class RunContextArgs(Schema):
  35. """Define the parameter schema for RunContext."""
  36. optimizer = fields.Function(allow_none=True)
  37. loss_fn = fields.Function(allow_none=True)
  38. net_outputs = fields.Function(allow_none=True)
  39. train_network = fields.Function(allow_none=True)
  40. train_dataset = fields.Function(allow_none=True)
  41. epoch_num = fields.Int(allow_none=True, validate=Range(min=1))
  42. batch_num = fields.Int(allow_none=True, validate=Range(min=0))
  43. cur_step_num = fields.Int(allow_none=True, validate=Range(min=0))
  44. parallel_mode = fields.Str(allow_none=True)
  45. device_number = fields.Int(allow_none=True, validate=Range(min=1))
  46. list_callback = fields.Function(allow_none=True)
  47. @pre_load
  48. def check_optimizer(self, data, **kwargs):
  49. optimizer = data.get("optimizer")
  50. if optimizer and not isinstance(optimizer, Optimizer):
  51. raise ValidationError({'optimizer': [
  52. "Parameter optimizer must be an instance of mindspore.nn.optim.Optimizer."
  53. ]})
  54. return data
  55. @pre_load
  56. def check_train_network(self, data, **kwargs):
  57. train_network = data.get("train_network")
  58. if train_network and not isinstance(train_network, Cell):
  59. raise ValidationError({'train_network': [
  60. "Parameter train_network must be an instance of mindspore.nn.Cell."]})
  61. return data
  62. @pre_load
  63. def check_train_dataset(self, data, **kwargs):
  64. train_dataset = data.get("train_dataset")
  65. if train_dataset and not isinstance(train_dataset, Dataset):
  66. raise ValidationError({'train_dataset': [
  67. "Parameter train_dataset must be an instance of "
  68. "mindspore.dataengine.datasets.Dataset"]})
  69. return data
  70. @pre_load
  71. def check_loss(self, data, **kwargs):
  72. net_outputs = data.get("net_outputs")
  73. if net_outputs and not isinstance(net_outputs, Tensor):
  74. raise ValidationError({'net_outpus': [
  75. "The parameter net_outputs is invalid. It should be a Tensor."
  76. ]})
  77. return data
  78. @pre_load
  79. def check_list_callback(self, data, **kwargs):
  80. list_callback = data.get("list_callback")
  81. if list_callback and not isinstance(list_callback, _ListCallback):
  82. raise ValidationError({'list_callback': [
  83. "Parameter list_callback must be an instance of "
  84. "mindspore.train.callback._ListCallback."
  85. ]})
  86. return data
  87. class EvalParameter(Schema):
  88. """Define the parameter schema for Evaluation job."""
  89. valid_dataset = fields.Function(allow_none=True)
  90. metrics = fields.Dict(allow_none=True)
  91. @pre_load
  92. def check_valid_dataset(self, data, **kwargs):
  93. valid_dataset = data.get("valid_dataset")
  94. if valid_dataset and not isinstance(valid_dataset, Dataset):
  95. raise ValidationError({'valid_dataset': [
  96. "Parameter valid_dataset must be an instance of "
  97. "mindspore.dataengine.datasets.Dataset"]})
  98. return data
  99. class SearchModelConditionParameter(Schema):
  100. """Define the search model condition parameter schema."""
  101. summary_dir = fields.Dict()
  102. loss_function = fields.Dict()
  103. train_dataset_path = fields.Dict()
  104. train_dataset_count = fields.Dict()
  105. test_dataset_path = fields.Dict()
  106. test_dataset_count = fields.Dict()
  107. network = fields.Dict()
  108. optimizer = fields.Dict()
  109. learning_rate = fields.Dict()
  110. epoch = fields.Dict()
  111. batch_size = fields.Dict()
  112. loss = fields.Dict()
  113. model_size = fields.Dict()
  114. limit = fields.Int(validate=lambda n: 0 < n <= 100)
  115. offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
  116. sorted_name = fields.Str()
  117. sorted_type = fields.Str(allow_none=True)
  118. lineage_type = fields.Str(
  119. validate=OneOf(enum_to_list(LineageType)),
  120. allow_none=True
  121. )
  122. @staticmethod
  123. def check_dict_value_type(data, value_type):
  124. """Check dict value type and int scope."""
  125. for key, value in data.items():
  126. if key == "in":
  127. if not isinstance(value, (list, tuple)):
  128. raise ValidationError("In operation's value must be list or tuple.")
  129. else:
  130. if not isinstance(value, value_type):
  131. raise ValidationError("Wrong value type.")
  132. if value_type is int:
  133. if value < 0 or value > pow(2, 63) - 1:
  134. raise ValidationError("Int value should <= pow(2, 63) - 1.")
  135. if isinstance(value, bool):
  136. raise ValidationError("Wrong value type.")
  137. @staticmethod
  138. def check_param_value_type(data):
  139. """Check input param's value type."""
  140. for key, value in data.items():
  141. if key == "in":
  142. if not isinstance(value, (list, tuple)):
  143. raise ValidationError("In operation's value must be list or tuple.")
  144. else:
  145. if isinstance(value, bool) or \
  146. (not isinstance(value, float) and not isinstance(value, int)):
  147. raise ValidationError("Wrong value type.")
  148. @validates("loss")
  149. def check_loss(self, data):
  150. """Check loss."""
  151. SearchModelConditionParameter.check_param_value_type(data)
  152. @validates("learning_rate")
  153. def check_learning_rate(self, data):
  154. """Check learning_rate."""
  155. SearchModelConditionParameter.check_param_value_type(data)
  156. @validates("loss_function")
  157. def check_loss_function(self, data):
  158. SearchModelConditionParameter.check_dict_value_type(data, str)
  159. @validates("train_dataset_path")
  160. def check_train_dataset_path(self, data):
  161. SearchModelConditionParameter.check_dict_value_type(data, str)
  162. @validates("train_dataset_count")
  163. def check_train_dataset_count(self, data):
  164. SearchModelConditionParameter.check_dict_value_type(data, int)
  165. @validates("test_dataset_path")
  166. def check_test_dataset_path(self, data):
  167. SearchModelConditionParameter.check_dict_value_type(data, str)
  168. @validates("test_dataset_count")
  169. def check_test_dataset_count(self, data):
  170. SearchModelConditionParameter.check_dict_value_type(data, int)
  171. @validates("network")
  172. def check_network(self, data):
  173. SearchModelConditionParameter.check_dict_value_type(data, str)
  174. @validates("optimizer")
  175. def check_optimizer(self, data):
  176. SearchModelConditionParameter.check_dict_value_type(data, str)
  177. @validates("epoch")
  178. def check_epoch(self, data):
  179. SearchModelConditionParameter.check_dict_value_type(data, int)
  180. @validates("batch_size")
  181. def check_batch_size(self, data):
  182. SearchModelConditionParameter.check_dict_value_type(data, int)
  183. @validates("model_size")
  184. def check_model_size(self, data):
  185. SearchModelConditionParameter.check_dict_value_type(data, int)
  186. @validates("summary_dir")
  187. def check_summary_dir(self, data):
  188. SearchModelConditionParameter.check_dict_value_type(data, str)
  189. @pre_load
  190. def check_comparision(self, data, **kwargs):
  191. """Check comparision for all parameters in schema."""
  192. for attr, condition in data.items():
  193. if attr in ["limit", "offset", "sorted_name", "sorted_type", "lineage_type"]:
  194. continue
  195. if not isinstance(attr, str):
  196. raise LineageParamValueError('The search attribute not supported.')
  197. if attr not in FIELD_MAPPING and not attr.startswith(('metric/','user_defined/')):
  198. raise LineageParamValueError('The search attribute not supported.')
  199. if not isinstance(condition, dict):
  200. raise LineageParamTypeError("The search_condition element {} should be dict."
  201. .format(attr))
  202. for key in condition.keys():
  203. if key not in ["eq", "lt", "gt", "le", "ge", "in"]:
  204. raise LineageParamValueError("The compare condition should be in "
  205. "('eq', 'lt', 'gt', 'le', 'ge', 'in').")
  206. if attr.startswith('metric/'):
  207. if len(attr) == 7:
  208. raise LineageParamValueError(
  209. 'The search attribute not supported.'
  210. )
  211. try:
  212. SearchModelConditionParameter.check_param_value_type(condition)
  213. except ValidationError:
  214. raise MindInsightException(
  215. error=LineageErrors.LINEAGE_PARAM_METRIC_ERROR,
  216. message=LineageErrorMsg.LINEAGE_METRIC_ERROR.value.format(attr)
  217. )
  218. return data

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。