You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

train_task_manager.py 6.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Train task manager."""
  16. from mindinsight.utils.exceptions import ParamTypeError
  17. from mindinsight.datavisual.common.log import logger
  18. from mindinsight.datavisual.common import exceptions
  19. from mindinsight.datavisual.common.enums import PluginNameEnum
  20. from mindinsight.datavisual.common.enums import CacheStatus
  21. from mindinsight.datavisual.common.validation import Validation
  22. from mindinsight.datavisual.processors.base_processor import BaseProcessor
  23. from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
  24. class TrainTaskManager(BaseProcessor):
  25. """Train task manager."""
  26. def get_single_train_task(self, plugin_name, train_id):
  27. """
  28. get single train task.
  29. Args:
  30. plugin_name (str): Plugin name, refer `PluginNameEnum`.
  31. train_id (str): Specify a training job to query.
  32. Returns:
  33. {'train_jobs': list[TrainJob]}, refer to restful api.
  34. """
  35. Validation.check_param_empty(plugin_name=plugin_name, train_id=train_id)
  36. Validation.check_plugin_name(plugin_name=plugin_name)
  37. train_job = self._data_manager.get_train_job_by_plugin(train_id=train_id, plugin_name=plugin_name)
  38. if train_job is None:
  39. raise exceptions.TrainJobNotExistError()
  40. return dict(train_jobs=[train_job])
  41. def get_plugins(self, train_id, manual_update=True):
  42. """
  43. Queries the plug-in data for the specified training job
  44. Args:
  45. train_id (str): Specify a training job to query.
  46. manual_update (bool): Specifies whether to refresh automatically.
  47. Returns:
  48. dict, refer to restful api.
  49. """
  50. Validation.check_param_empty(train_id=train_id)
  51. if manual_update:
  52. self._data_manager.cache_train_job(train_id)
  53. train_job = self._data_manager.get_train_job(train_id)
  54. try:
  55. data_visual_content = train_job.get_detail(DATAVISUAL_CACHE_KEY)
  56. plugins = data_visual_content.get(DATAVISUAL_PLUGIN_KEY)
  57. except exceptions.TrainJobDetailNotInCacheError:
  58. plugins = []
  59. if not plugins:
  60. default_result = dict()
  61. for plugin_name in PluginNameEnum.list_members():
  62. default_result.update({plugin_name: list()})
  63. return dict(plugins=default_result)
  64. return dict(
  65. plugins=plugins
  66. )
  67. def query_train_jobs(self, offset=0, limit=10):
  68. """
  69. Query train jobs.
  70. Args:
  71. offset (int): Specify page number. Default is 0.
  72. limit (int): Specify page size. Default is 10.
  73. Returns:
  74. tuple, return quantity of total train jobs and list of train jobs specified by offset and limit.
  75. """
  76. brief_cache = self._data_manager.get_brief_cache()
  77. brief_train_jobs = list(brief_cache.get_train_jobs().values())
  78. brief_train_jobs.sort(key=lambda x: x.basic_info.update_time, reverse=True)
  79. total = len(brief_train_jobs)
  80. start = offset * limit
  81. end = (offset + 1) * limit
  82. train_jobs = []
  83. train_ids = [train_job.basic_info.train_id for train_job in brief_train_jobs[start:end]]
  84. for train_id in train_ids:
  85. try:
  86. train_job = self._data_manager.get_train_job(train_id)
  87. except exceptions.TrainJobNotExistError:
  88. logger.warning('Train job %s not existed', train_id)
  89. continue
  90. basic_info = train_job.get_basic_info()
  91. train_job_item = dict(
  92. train_id=basic_info.train_id,
  93. relative_path=basic_info.train_id,
  94. create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  95. update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'),
  96. profiler_dir=basic_info.profiler_dir,
  97. cache_status=train_job.cache_status.value,
  98. )
  99. if train_job.cache_status == CacheStatus.CACHED:
  100. plugins = self.get_plugins(train_id)
  101. else:
  102. plugins = dict(plugins={
  103. 'graph': [],
  104. 'scalar': [],
  105. 'image': [],
  106. 'histogram': [],
  107. })
  108. train_job_item.update(plugins)
  109. train_jobs.append(train_job_item)
  110. return total, train_jobs
  111. def cache_train_jobs(self, train_ids):
  112. """
  113. Cache train jobs.
  114. Args:
  115. train_ids (list): Specify list of train_ids to be cached.
  116. Returns:
  117. dict, indicates train job ID and its current cache status.
  118. Raises:
  119. ParamTypeError, if the given train_ids parameter is not in valid type.
  120. """
  121. if not isinstance(train_ids, list):
  122. logger.error("train_ids must be list.")
  123. raise ParamTypeError('train_ids', list)
  124. cache_result = []
  125. for train_id in train_ids:
  126. if not isinstance(train_id, str):
  127. logger.error("train_id must be str.")
  128. raise ParamTypeError('train_id', str)
  129. try:
  130. train_job = self._data_manager.get_train_job(train_id)
  131. except exceptions.TrainJobNotExistError:
  132. logger.warning('Train job %s not existed', train_id)
  133. continue
  134. if train_job.cache_status == CacheStatus.NOT_IN_CACHE:
  135. self._data_manager.cache_train_job(train_id)
  136. # Update loader cache status to CACHING for consistency in response.
  137. train_job.cache_status = CacheStatus.CACHING
  138. cache_result.append(dict(
  139. train_id=train_id,
  140. cache_status=train_job.cache_status.value,
  141. ))
  142. return cache_result