|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Train task manager."""
-
- from mindinsight.utils.exceptions import ParamTypeError
- from mindinsight.datavisual.common.log import logger
- from mindinsight.datavisual.common import exceptions
- from mindinsight.datavisual.common.enums import PluginNameEnum
- from mindinsight.datavisual.common.enums import CacheStatus
- from mindinsight.datavisual.common.exceptions import QueryStringContainsNullByteError
- from mindinsight.datavisual.common.validation import Validation
- from mindinsight.datavisual.utils.utils import contains_null_byte
- from mindinsight.datavisual.processors.base_processor import BaseProcessor
- from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
-
-
class TrainTaskManager(BaseProcessor):
    """
    Train task manager.

    Answers train-job queries (single job by plug-in, plug-in listing, paged
    job listing) and cache requests through ``self._data_manager``.
    """

    @staticmethod
    def _default_plugins():
        """
        Build the fallback plug-in mapping.

        Returns:
            dict, maps every member of `PluginNameEnum` to a new empty list.
        """
        return {plugin_name: [] for plugin_name in PluginNameEnum.list_members()}

    def get_single_train_task(self, plugin_name, train_id):
        """
        Get single train task.

        Args:
            plugin_name (str): Plugin name, refer `PluginNameEnum`.
            train_id (str): Specify a training job to query.

        Returns:
            {'train_jobs': list[TrainJob]}, refer to restful api.

        Raises:
            TrainJobNotExistError: If the data manager returns None for the
                given train id and plugin name.
        """
        Validation.check_param_empty(plugin_name=plugin_name, train_id=train_id)
        Validation.check_plugin_name(plugin_name=plugin_name)
        train_job = self._data_manager.get_train_job_by_plugin(train_id=train_id, plugin_name=plugin_name)
        if train_job is None:
            raise exceptions.TrainJobNotExistError()
        return dict(train_jobs=[train_job])

    def get_plugins(self, train_id, manual_update=True):
        """
        Query the plug-in data for the specified training job.

        Args:
            train_id (str): Specify a training job to query.
            manual_update (bool): If True, ask the data manager to cache the
                train job before its plug-in data is read. Default is True.

        Returns:
            dict, refer to restful api. The `plugins` entry falls back to an
            empty list per `PluginNameEnum` member when the job detail is not
            in cache or holds no plug-in data.

        Raises:
            QueryStringContainsNullByteError: If `train_id` contains a null byte.
        """
        Validation.check_param_empty(train_id=train_id)
        if contains_null_byte(train_id=train_id):
            raise QueryStringContainsNullByteError("train job id: {} contains null byte.".format(train_id))

        if manual_update:
            self._data_manager.cache_train_job(train_id)

        train_job = self._data_manager.get_train_job(train_id)

        try:
            data_visual_content = train_job.get_detail(DATAVISUAL_CACHE_KEY)
            plugins = data_visual_content.get(DATAVISUAL_PLUGIN_KEY)
        except exceptions.TrainJobDetailNotInCacheError:
            plugins = None

        # Missing or empty plug-in data is reported as "every plug-in, no tags".
        if not plugins:
            plugins = self._default_plugins()

        return dict(plugins=plugins)

    def query_train_jobs(self, offset=0, limit=10, request_train_id=None):
        """
        Query train jobs.

        Args:
            offset (int): Specify page number. Default is 0.
            limit (int): Specify page size. Default is 10.
            request_train_id (str): Specify train id. Default is None.

        Returns:
            tuple, return quantity of total train jobs and list of train jobs specified by offset and limit.
            When `request_train_id` is given, the result is `(1, [item])` if that job
            exists and `(0, [])` otherwise.
        """
        if request_train_id is not None:
            train_job_item = self._get_train_job_item(request_train_id)
            if train_job_item is None:
                return 0, []
            return 1, [train_job_item]

        brief_cache = self._data_manager.get_brief_cache()
        # Most recently updated jobs come first.
        brief_train_jobs = sorted(
            brief_cache.get_train_jobs().values(),
            key=lambda job: job.basic_info.update_time,
            reverse=True,
        )
        total = len(brief_train_jobs)

        # `offset` counts pages, not items.
        start = offset * limit
        end = (offset + 1) * limit
        train_ids = [train_job.basic_info.train_id for train_job in brief_train_jobs[start:end]]

        train_jobs = []
        for train_id in train_ids:
            train_job_item = self._get_train_job_item(train_id)
            # A job that can no longer be fetched is skipped; `total` still
            # reflects the size of the brief cache.
            if train_job_item is not None:
                train_jobs.append(train_job_item)

        return total, train_jobs

    def _get_train_job_item(self, train_id):
        """
        Get train job item.

        Args:
            train_id (str): Specify train id.

        Returns:
            dict, a dict of train job item, or None if the train job does not exist.
        """
        try:
            train_job = self._data_manager.get_train_job(train_id)
        except exceptions.TrainJobNotExistError:
            logger.warning('Train job %s not existed', train_id)
            return None

        basic_info = train_job.get_basic_info()
        train_job_item = dict(
            train_id=basic_info.train_id,
            relative_path=basic_info.train_id,
            create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'),
            update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'),
            profiler_dir=basic_info.profiler_dir,
            cache_status=train_job.cache_status.value,
            profiler_type=basic_info.profiler_type,
            summary_files=basic_info.summary_files,
            graph_files=basic_info.graph_files,
            lineage_files=basic_info.lineage_files
        )

        # Plug-in data is only looked up for jobs that are not marked
        # NOT_IN_CACHE; manual_update=False keeps this lookup from triggering caching.
        if train_job.cache_status != CacheStatus.NOT_IN_CACHE:
            plugins = self.get_plugins(train_id, manual_update=False)
        else:
            plugins = dict(plugins=self._default_plugins())

        train_job_item.update(plugins)
        return train_job_item

    def cache_train_jobs(self, train_ids):
        """
        Cache train jobs.

        Args:
            train_ids (list): Specify list of train_ids to be cached.

        Returns:
            list[dict], each dict indicates a train job ID and its current cache status.
            Train ids that do not exist are logged and left out of the result.

        Raises:
            ParamTypeError, if the given train_ids parameter is not in valid type.
        """
        if not isinstance(train_ids, list):
            logger.error("train_ids must be list.")
            raise ParamTypeError('train_ids', list)

        # Check every element before any caching is requested, so that an
        # invalid request is rejected without partial side effects.
        for train_id in train_ids:
            if not isinstance(train_id, str):
                logger.error("train_id must be str.")
                raise ParamTypeError('train_id', str)

        cache_result = []
        for train_id in train_ids:
            try:
                train_job = self._data_manager.get_train_job(train_id)
            except exceptions.TrainJobNotExistError:
                logger.warning('Train job %s not existed', train_id)
                continue

            self._data_manager.cache_train_job(train_id)

            cache_result.append(dict(
                train_id=train_id,
                cache_status=train_job.cache_status.value,
            ))

        return cache_result
|