|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Define schema of model lineage input parameters."""
- from marshmallow import Schema, fields, ValidationError, pre_load, validates
- from marshmallow.validate import Range, OneOf
-
- from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \
- LineageErrors
- from mindinsight.lineagemgr.common.exceptions.exceptions import \
- LineageParamTypeError, LineageParamValueError
- from mindinsight.lineagemgr.common.log import logger
- from mindinsight.lineagemgr.common.utils import enum_to_list
- from mindinsight.lineagemgr.querier.querier import LineageType
- from mindinsight.lineagemgr.querier.query_model import FIELD_MAPPING
- from mindinsight.utils.exceptions import MindInsightException
-
# MindSpore is an optional dependency: the imported classes are only used
# for isinstance() checks inside the schemas below.  When MindSpore is
# absent, the failure is logged and these names stay undefined, so any
# schema hook that reaches one of them will raise NameError at call time.
try:
    from mindspore.dataset.engine import Dataset
    from mindspore.nn import Cell, Optimizer
    from mindspore.common.tensor import Tensor
    from mindspore.train.callback import _ListCallback
except (ImportError, ModuleNotFoundError):
    logger.error('MindSpore Not Found!')
-
-
class RunContextArgs(Schema):
    """Define the parameter schema for RunContext.

    Each field mirrors an attribute of a MindSpore RunContext.  The
    ``fields.Function`` fields carry arbitrary MindSpore objects, so their
    types are verified explicitly in the ``@pre_load`` hooks below instead
    of by marshmallow's own field typing.

    Raises:
        ValidationError: If any MindSpore object has an unexpected type.
    """
    optimizer = fields.Function(allow_none=True)
    loss_fn = fields.Function(allow_none=True)
    net_outputs = fields.Function(allow_none=True)
    train_network = fields.Function(allow_none=True)
    train_dataset = fields.Function(allow_none=True)
    epoch_num = fields.Int(allow_none=True, validate=Range(min=1))
    batch_num = fields.Int(allow_none=True, validate=Range(min=0))
    cur_step_num = fields.Int(allow_none=True, validate=Range(min=0))
    parallel_mode = fields.Str(allow_none=True)
    device_number = fields.Int(allow_none=True, validate=Range(min=1))
    list_callback = fields.Function(allow_none=True)

    @pre_load
    def check_optimizer(self, data, **kwargs):
        """Check that optimizer, when given, is a MindSpore Optimizer."""
        optimizer = data.get("optimizer")
        if optimizer and not isinstance(optimizer, Optimizer):
            raise ValidationError({'optimizer': [
                "Parameter optimizer must be an instance of mindspore.nn.optim.Optimizer."
            ]})
        return data

    @pre_load
    def check_train_network(self, data, **kwargs):
        """Check that train_network, when given, is a MindSpore Cell."""
        train_network = data.get("train_network")
        if train_network and not isinstance(train_network, Cell):
            raise ValidationError({'train_network': [
                "Parameter train_network must be an instance of mindspore.nn.Cell."]})
        return data

    @pre_load
    def check_train_dataset(self, data, **kwargs):
        """Check that train_dataset, when given, is a MindSpore Dataset."""
        train_dataset = data.get("train_dataset")
        if train_dataset and not isinstance(train_dataset, Dataset):
            raise ValidationError({'train_dataset': [
                "Parameter train_dataset must be an instance of "
                "mindspore.dataengine.datasets.Dataset"]})
        return data

    @pre_load
    def check_loss(self, data, **kwargs):
        """Check that net_outputs, when given, is a MindSpore Tensor."""
        net_outputs = data.get("net_outputs")
        if net_outputs and not isinstance(net_outputs, Tensor):
            # Bug fix: the error was previously reported under the
            # misspelled key 'net_outpus', i.e. a non-existent field.
            raise ValidationError({'net_outputs': [
                "The parameter net_outputs is invalid. It should be a Tensor."
            ]})
        return data

    @pre_load
    def check_list_callback(self, data, **kwargs):
        """Check that list_callback, when given, is a _ListCallback."""
        list_callback = data.get("list_callback")
        if list_callback and not isinstance(list_callback, _ListCallback):
            raise ValidationError({'list_callback': [
                "Parameter list_callback must be an instance of "
                "mindspore.train.callback._ListCallback."
            ]})
        return data
-
-
class EvalParameter(Schema):
    """Define the parameter schema for Evaluation job."""
    valid_dataset = fields.Function(allow_none=True)
    metrics = fields.Dict(allow_none=True)

    @pre_load
    def check_valid_dataset(self, data, **kwargs):
        """Reject a valid_dataset that is not a MindSpore Dataset."""
        candidate = data.get("valid_dataset")
        if not candidate:
            # Absent or falsy values are left for field-level handling.
            return data
        if isinstance(candidate, Dataset):
            return data
        raise ValidationError({'valid_dataset': [
            "Parameter valid_dataset must be an instance of "
            "mindspore.dataengine.datasets.Dataset"]})
-
-
class SearchModelConditionParameter(Schema):
    """Define the search model condition parameter schema.

    Every searchable lineage field is declared as a ``Dict`` whose keys are
    comparison operators ('eq', 'lt', 'gt', 'le', 'ge', 'in') and whose
    values are the operands.  The ``@validates`` hooks check operand types
    per field, and the ``@pre_load`` hook validates the overall structure
    of the search condition, including dynamic ``metric/*`` and
    ``user_defined/*`` attributes.
    """
    summary_dir = fields.Dict()
    loss_function = fields.Dict()
    train_dataset_path = fields.Dict()
    train_dataset_count = fields.Dict()
    test_dataset_path = fields.Dict()
    test_dataset_count = fields.Dict()
    network = fields.Dict()
    optimizer = fields.Dict()
    learning_rate = fields.Dict()
    epoch = fields.Dict()
    batch_size = fields.Dict()
    loss = fields.Dict()
    model_size = fields.Dict()
    # Pagination: at most 100 records per page; offset capped at 100000.
    limit = fields.Int(validate=lambda n: 0 < n <= 100)
    offset = fields.Int(validate=lambda n: 0 <= n <= 100000)
    sorted_name = fields.Str()
    sorted_type = fields.Str(allow_none=True)
    lineage_type = fields.Str(
        validate=OneOf(enum_to_list(LineageType)),
        allow_none=True
    )

    @staticmethod
    def check_dict_value_type(data, value_type):
        """Check that every operand in ``data`` matches ``value_type``.

        Args:
            data (dict): Mapping of comparison operator to operand.
            value_type (type): Expected operand type (``str`` or ``int``).

        Raises:
            ValidationError: If an operand has the wrong type, or an int
                operand is outside [0, 2**63 - 1].
        """
        for key, value in data.items():
            if key == "in":
                # 'in' takes a collection of operands rather than a scalar.
                if not isinstance(value, (list, tuple)):
                    raise ValidationError("In operation's value must be list or tuple.")
            else:
                # bool is a subclass of int, so reject it explicitly and
                # before the range check (previously it slipped past
                # isinstance() and was only rejected afterwards).
                if isinstance(value, bool) or not isinstance(value, value_type):
                    raise ValidationError("Wrong value type.")
                if value_type is int and (value < 0 or value > pow(2, 63) - 1):
                    # Fixed message grammar ("should <=" -> "should be <=").
                    raise ValidationError("Int value should be <= pow(2, 63) - 1.")

    @staticmethod
    def check_param_value_type(data):
        """Check that every operand in ``data`` is a non-bool number.

        Args:
            data (dict): Mapping of comparison operator to operand.

        Raises:
            ValidationError: If an operand is not an int/float scalar, or
                the 'in' operand is not a list/tuple.
        """
        for key, value in data.items():
            if key == "in":
                if not isinstance(value, (list, tuple)):
                    raise ValidationError("In operation's value must be list or tuple.")
            else:
                # bool is a subclass of int, so reject it explicitly.
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    raise ValidationError("Wrong value type.")

    @validates("loss")
    def check_loss(self, data):
        """Check loss."""
        SearchModelConditionParameter.check_param_value_type(data)

    @validates("learning_rate")
    def check_learning_rate(self, data):
        """Check learning_rate."""
        SearchModelConditionParameter.check_param_value_type(data)

    @validates("loss_function")
    def check_loss_function(self, data):
        """Check loss_function operands are strings."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("train_dataset_path")
    def check_train_dataset_path(self, data):
        """Check train_dataset_path operands are strings."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("train_dataset_count")
    def check_train_dataset_count(self, data):
        """Check train_dataset_count operands are non-negative ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("test_dataset_path")
    def check_test_dataset_path(self, data):
        """Check test_dataset_path operands are strings."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("test_dataset_count")
    def check_test_dataset_count(self, data):
        """Check test_dataset_count operands are non-negative ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("network")
    def check_network(self, data):
        """Check network operands are strings."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("optimizer")
    def check_optimizer(self, data):
        """Check optimizer operands are strings."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @validates("epoch")
    def check_epoch(self, data):
        """Check epoch operands are non-negative ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("batch_size")
    def check_batch_size(self, data):
        """Check batch_size operands are non-negative ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("model_size")
    def check_model_size(self, data):
        """Check model_size operands are non-negative ints."""
        SearchModelConditionParameter.check_dict_value_type(data, int)

    @validates("summary_dir")
    def check_summary_dir(self, data):
        """Check summary_dir operands are strings."""
        SearchModelConditionParameter.check_dict_value_type(data, str)

    @pre_load
    def check_comparision(self, data, **kwargs):
        """Check comparison structure for all parameters in the schema.

        NOTE: the method name keeps its historical 'comparision' spelling;
        it is only invoked through the ``@pre_load`` decorator, so renaming
        it is not worth the interface risk.

        Raises:
            LineageParamValueError: If an attribute or operator is not
                supported.
            LineageParamTypeError: If a condition is not a dict.
            MindInsightException: If a metric/* condition has a bad operand.
        """
        for attr, condition in data.items():
            # Scalar parameters get their own field-level validation above.
            if attr in ["limit", "offset", "sorted_name", "sorted_type", "lineage_type"]:
                continue

            if not isinstance(attr, str):
                raise LineageParamValueError('The search attribute not supported.')

            # Besides the fixed fields, only dynamic metric/* and
            # user_defined/* attributes may be searched.
            if attr not in FIELD_MAPPING and not attr.startswith(('metric/', 'user_defined/')):
                raise LineageParamValueError('The search attribute not supported.')

            if not isinstance(condition, dict):
                raise LineageParamTypeError("The search_condition element {} should be dict."
                                            .format(attr))

            for key in condition.keys():
                if key not in ["eq", "lt", "gt", "le", "ge", "in"]:
                    raise LineageParamValueError("The compare condition should be in "
                                                 "('eq', 'lt', 'gt', 'le', 'ge', 'in').")

            if attr.startswith('metric/'):
                # A bare 'metric/' prefix (len == 7) carries no metric name.
                if len(attr) == 7:
                    raise LineageParamValueError(
                        'The search attribute not supported.'
                    )
                try:
                    SearchModelConditionParameter.check_param_value_type(condition)
                except ValidationError:
                    raise MindInsightException(
                        error=LineageErrors.LINEAGE_PARAM_METRIC_ERROR,
                        message=LineageErrorMsg.LINEAGE_METRIC_ERROR.value.format(attr)
                    )
        return data
|