You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cache_item_updater.py 4.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Cache item updater."""
  16. import os
  17. from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob
  18. from mindinsight.lineagemgr.common.log import logger
  19. from mindinsight.lineagemgr.common.exceptions.exceptions import LineageFileNotFoundError
  20. from mindinsight.lineagemgr.common.validator.validate import validate_train_id, validate_added_info
  21. from mindinsight.lineagemgr.lineage_parser import LineageParser, LINEAGE
  22. from mindinsight.utils.exceptions import ParamValueError
  23. def update_lineage_object(data_manager, train_id, added_info: dict):
  24. """Update lineage objects about tag and remark."""
  25. validate_train_id(train_id)
  26. validate_added_info(added_info)
  27. cache_item = data_manager.get_brief_train_job(train_id)
  28. lineage_item = cache_item.get(key=LINEAGE, raise_exception=False)
  29. if lineage_item is None:
  30. logger.warning("Cannot update the lineage for tran job %s, because it does not exist.", train_id)
  31. raise ParamValueError("Cannot update the lineage for tran job %s, because it does not exist." % train_id)
  32. cached_added_info = lineage_item.super_lineage_obj.added_info
  33. new_added_info = dict(cached_added_info)
  34. for key, value in added_info.items():
  35. new_added_info.update({key: value})
  36. with cache_item.lock_key(LINEAGE):
  37. cache_item.get(key=LINEAGE).super_lineage_obj.added_info = new_added_info
  38. class LineageCacheItemUpdater(BaseCacheItemUpdater):
  39. """Cache item updater for lineage info."""
  40. def update_item(self, cache_item: CachedTrainJob):
  41. """Update cache item in place."""
  42. summary_base_dir = cache_item.summary_base_dir
  43. summary_dir = cache_item.abs_summary_dir
  44. # The summary_base_dir and summary_dir have been normalized in data_manager.
  45. if summary_base_dir == summary_dir:
  46. relative_path = "./"
  47. else:
  48. relative_path = f'./{os.path.basename(summary_dir)}'
  49. try:
  50. lineage_parser = self._lineage_parsing(cache_item)
  51. except LineageFileNotFoundError:
  52. self._delete_lineage_in_cache(cache_item, LINEAGE, relative_path)
  53. return
  54. super_lineage_obj = lineage_parser.super_lineage_obj
  55. if super_lineage_obj is None:
  56. logger.debug("There is no lineage to update in train job %s.", relative_path)
  57. return
  58. cache_item.set(key=LINEAGE, value=lineage_parser)
  59. def _lineage_parsing(self, cache_item):
  60. """Parse summaries and return lineage parser."""
  61. train_id = cache_item.train_id
  62. summary_dir = cache_item.abs_summary_dir
  63. update_time = cache_item.basic_info.update_time
  64. cached_lineage_item = cache_item.get(key=LINEAGE, raise_exception=False)
  65. if cached_lineage_item is None:
  66. lineage_parser = LineageParser(train_id, summary_dir, update_time)
  67. else:
  68. lineage_parser = cached_lineage_item
  69. with cache_item.lock_key(LINEAGE):
  70. lineage_parser.update_time = update_time
  71. lineage_parser.load()
  72. return lineage_parser
  73. def _delete_lineage_in_cache(self, cache_item, key, relative_path):
  74. with cache_item.lock_key(key):
  75. try:
  76. cache_item.delete(key=key)
  77. logger.info("Parse failed, delete the tran job %s.", relative_path)
  78. except ParamValueError:
  79. logger.debug("Parse failed, and it is not in cache, "
  80. "no need to delete the train job %s.", relative_path)