You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_task_manager.py 7.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Train task manager."""
  16. from mindinsight.utils.exceptions import ParamTypeError
  17. from mindinsight.datavisual.common.log import logger
  18. from mindinsight.datavisual.common import exceptions
  19. from mindinsight.datavisual.common.enums import PluginNameEnum
  20. from mindinsight.datavisual.common.enums import CacheStatus
  21. from mindinsight.datavisual.common.exceptions import QueryStringContainsNullByteError
  22. from mindinsight.datavisual.common.validation import Validation
  23. from mindinsight.datavisual.utils.utils import contains_null_byte
  24. from mindinsight.datavisual.processors.base_processor import BaseProcessor
  25. from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
  26. class TrainTaskManager(BaseProcessor):
  27. """Train task manager."""
  28. def get_single_train_task(self, plugin_name, train_id):
  29. """
  30. get single train task.
  31. Args:
  32. plugin_name (str): Plugin name, refer `PluginNameEnum`.
  33. train_id (str): Specify a training job to query.
  34. Returns:
  35. {'train_jobs': list[TrainJob]}, refer to restful api.
  36. """
  37. Validation.check_param_empty(plugin_name=plugin_name, train_id=train_id)
  38. Validation.check_plugin_name(plugin_name=plugin_name)
  39. train_job = self._data_manager.get_train_job_by_plugin(train_id=train_id, plugin_name=plugin_name)
  40. if train_job is None:
  41. raise exceptions.TrainJobNotExistError()
  42. return dict(train_jobs=[train_job])
  43. def get_plugins(self, train_id, manual_update=True):
  44. """
  45. Queries the plug-in data for the specified training job
  46. Args:
  47. train_id (str): Specify a training job to query.
  48. manual_update (bool): Specifies whether to refresh automatically.
  49. Returns:
  50. dict, refer to restful api.
  51. """
  52. Validation.check_param_empty(train_id=train_id)
  53. if contains_null_byte(train_id=train_id):
  54. raise QueryStringContainsNullByteError("train job id: {} contains null byte.".format(train_id))
  55. if manual_update:
  56. self._data_manager.cache_train_job(train_id)
  57. train_job = self._data_manager.get_train_job(train_id)
  58. try:
  59. data_visual_content = train_job.get_detail(DATAVISUAL_CACHE_KEY)
  60. plugins = data_visual_content.get(DATAVISUAL_PLUGIN_KEY)
  61. except exceptions.TrainJobDetailNotInCacheError:
  62. plugins = []
  63. if not plugins:
  64. default_result = dict()
  65. for plugin_name in PluginNameEnum.list_members():
  66. default_result.update({plugin_name: list()})
  67. return dict(plugins=default_result)
  68. for plugin_name, value in plugins.items():
  69. plugins[plugin_name] = sorted(value)
  70. return dict(
  71. plugins=plugins
  72. )
  73. def query_train_jobs(self, offset=0, limit=10, request_train_id=None):
  74. """
  75. Query train jobs.
  76. Args:
  77. offset (int): Specify page number. Default is 0.
  78. limit (int): Specify page size. Default is 10.
  79. request_train_id (str): Specify train id. Default is None.
  80. Returns:
  81. tuple, return quantity of total train jobs and list of train jobs specified by offset and limit.
  82. """
  83. if request_train_id is not None:
  84. train_job_item = self._get_train_job_item(request_train_id)
  85. if train_job_item is None:
  86. return 0, []
  87. return 1, [train_job_item]
  88. brief_cache = self._data_manager.get_brief_cache()
  89. brief_train_jobs = list(brief_cache.get_train_jobs().values())
  90. brief_train_jobs.sort(key=lambda x: x.basic_info.update_time, reverse=True)
  91. total = len(brief_train_jobs)
  92. start = offset * limit
  93. end = (offset + 1) * limit
  94. train_jobs = []
  95. train_ids = [train_job.basic_info.train_id for train_job in brief_train_jobs[start:end]]
  96. for train_id in train_ids:
  97. train_job_item = self._get_train_job_item(train_id)
  98. if train_job_item is None:
  99. continue
  100. train_jobs.append(train_job_item)
  101. return total, train_jobs
  102. def _get_train_job_item(self, train_id):
  103. """
  104. Get train job item.
  105. Args:
  106. train_id (str): Specify train id.
  107. Returns:
  108. dict, a dict of train job item.
  109. """
  110. try:
  111. train_job = self._data_manager.get_train_job(train_id)
  112. except exceptions.TrainJobNotExistError:
  113. logger.warning('Train job %s not existed', train_id)
  114. return None
  115. basic_info = train_job.get_basic_info()
  116. train_job_item = dict(
  117. train_id=basic_info.train_id,
  118. relative_path=basic_info.train_id,
  119. create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  120. update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'),
  121. profiler_dir=basic_info.profiler_dir,
  122. cache_status=train_job.cache_status.value,
  123. profiler_type=basic_info.profiler_type,
  124. summary_files=basic_info.summary_files,
  125. graph_files=basic_info.graph_files,
  126. lineage_files=basic_info.lineage_files
  127. )
  128. if train_job.cache_status != CacheStatus.NOT_IN_CACHE:
  129. plugins = self.get_plugins(train_id, manual_update=False)
  130. else:
  131. plugins = dict(plugins={plugin: [] for plugin in PluginNameEnum.list_members()})
  132. train_job_item.update(plugins)
  133. return train_job_item
  134. def cache_train_jobs(self, train_ids):
  135. """
  136. Cache train jobs.
  137. Args:
  138. train_ids (list): Specify list of train_ids to be cached.
  139. Returns:
  140. dict, indicates train job ID and its current cache status.
  141. Raises:
  142. ParamTypeError, if the given train_ids parameter is not in valid type.
  143. """
  144. if not isinstance(train_ids, list):
  145. logger.error("train_ids must be list.")
  146. raise ParamTypeError('train_ids', list)
  147. cache_result = []
  148. for train_id in train_ids:
  149. if not isinstance(train_id, str):
  150. logger.error("train_id must be str.")
  151. raise ParamTypeError('train_id', str)
  152. try:
  153. train_job = self._data_manager.get_train_job(train_id)
  154. except exceptions.TrainJobNotExistError:
  155. logger.warning('Train job %s not existed', train_id)
  156. continue
  157. self._data_manager.cache_train_job(train_id)
  158. cache_result.append(dict(
  159. train_id=train_id,
  160. cache_status=train_job.cache_status.value,
  161. ))
  162. return cache_result