| @@ -14,9 +14,9 @@ | |||
| # ============================================================================ | |||
| """Enums.""" | |||
| from enum import Enum | |||
| import enum | |||
| class BaseEnum(Enum): | |||
| class BaseEnum(enum.Enum): | |||
| @classmethod | |||
| def list_members(cls): | |||
| @@ -38,3 +38,11 @@ class PluginNameEnum(BaseEnum): | |||
| SCALAR = 'scalar' | |||
| GRAPH = 'graph' | |||
| HISTOGRAM = 'histogram' | |||
| @enum.unique | |||
| class CacheStatus(enum.Enum): | |||
| """Train job cache status.""" | |||
| NOT_IN_CACHE = "NOT_IN_CACHE" | |||
| CACHING = "CACHING" | |||
| CACHED = "CACHED" | |||
| @@ -22,7 +22,6 @@ This module also acts as a thread pool manager. | |||
| """ | |||
| import abc | |||
| import datetime | |||
| import enum | |||
| import threading | |||
| import time | |||
| import os | |||
| @@ -34,6 +33,7 @@ from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import CacheStatus | |||
| from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.common.enums import DataManagerStatus | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| @@ -44,14 +44,6 @@ from mindinsight.utils.exceptions import MindInsightException | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| @enum.unique | |||
| class CacheStatus(enum.Enum): | |||
| """Train job cache status.""" | |||
| NOT_IN_CACHE = "NOT_IN_CACHE" | |||
| CACHING = "CACHING" | |||
| CACHED = "CACHED" | |||
| class _BasicTrainJob: | |||
| """ | |||
| Basic info about train job. | |||
| @@ -267,6 +259,11 @@ class TrainJob: | |||
| """Get cache status.""" | |||
| return self._cache_status | |||
| @cache_status.setter | |||
| def cache_status(self, cache_status): | |||
| """Set cache status.""" | |||
| self._cache_status = cache_status | |||
| class BaseCacheItemUpdater(abc.ABC): | |||
| """Abstract base class for other modules to update cache content.""" | |||
| @@ -464,6 +461,11 @@ class _DetailCacheManager(_BaseCacheManager): | |||
| if loader is None: | |||
| raise TrainJobNotExistError(train_id) | |||
| # Update cache status loader to CACHING if loader is NOT_IN_CACHE | |||
| # before triggering the next interval. | |||
| if loader.cache_status == CacheStatus.NOT_IN_CACHE: | |||
| loader.cache_status = CacheStatus.CACHING | |||
| self._add_loader(loader) | |||
| need_reload = True | |||
| @@ -520,7 +522,13 @@ class _DetailCacheManager(_BaseCacheManager): | |||
| if loader is None: | |||
| logger.debug("Loader %r has been deleted, will not load data.", loader_id) | |||
| return | |||
| loader.data_loader.load() | |||
| # Update loader cache status to CACHED. | |||
| # Loader with cache status CACHED should remain the same cache status. | |||
| loader.cache_status = CacheStatus.CACHED | |||
| except MindInsightException as ex: | |||
| logger.warning("Data loader %r load data failed. " | |||
| "Delete data_loader. Detail: %s", loader_id, ex) | |||
| @@ -711,8 +719,7 @@ class _DetailCacheManager(_BaseCacheManager): | |||
| train_job_obj = CachedTrainJob(basic_info=None) | |||
| train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job) | |||
| # Will assign real value in future. | |||
| train_job_obj.cache_status = CacheStatus.CACHED | |||
| train_job_obj.cache_status = loader.cache_status | |||
| return train_job_obj | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Loader struct.""" | |||
| from mindinsight.datavisual.common.enums import CacheStatus | |||
| class LoaderStruct: | |||
| @@ -27,6 +28,7 @@ class LoaderStruct: | |||
| self._path = path | |||
| self._latest_update_time = latest_update_time | |||
| self._data_loader = data_loader | |||
| self._cache_status = CacheStatus.NOT_IN_CACHE | |||
| @property | |||
| def loader_id(self): | |||
| @@ -48,11 +50,21 @@ class LoaderStruct: | |||
| """Get data loader.""" | |||
| return self._data_loader | |||
| @property | |||
| def cache_status(self): | |||
| """Get cache status of loader.""" | |||
| return self._cache_status | |||
| @latest_update_time.setter | |||
| def latest_update_time(self, latest_update_time): | |||
| """Set the latest update time of loader.""" | |||
| self._latest_update_time = latest_update_time | |||
| @cache_status.setter | |||
| def cache_status(self, cache_status): | |||
| """Set cache status of loader.""" | |||
| self._cache_status = cache_status | |||
| def to_dict(self): | |||
| """Transform LoaderStruct to dict.""" | |||
| return dict( | |||
| @@ -17,10 +17,10 @@ | |||
| from mindinsight.datavisual.common.log import logger | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.datavisual.common.enums import CacheStatus | |||
| from mindinsight.datavisual.common.validation import Validation | |||
| from mindinsight.datavisual.processors.base_processor import BaseProcessor | |||
| from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY | |||
| from mindinsight.datavisual.data_transform.data_manager import CacheStatus | |||
| class TrainTaskManager(BaseProcessor): | |||
| @@ -132,23 +132,22 @@ class TrainTaskManager(BaseProcessor): | |||
| Returns: | |||
| dict, indicates train job ID and its current cache status. | |||
| """ | |||
| brief_cache = self._data_manager.get_brief_cache() | |||
| brief_train_jobs = brief_cache.get_train_jobs() | |||
| for train_id in train_ids: | |||
| brief_train_job = brief_train_jobs.get(train_id) | |||
| if brief_train_job is None: | |||
| raise exceptions.TrainJobNotExistError(f'Train id {train_id} not exists') | |||
| cache_result = [] | |||
| for train_id in train_ids: | |||
| brief_train_job = brief_train_jobs.get(train_id) | |||
| if brief_train_job.cache_status.value == CacheStatus.NOT_IN_CACHE.value: | |||
| try: | |||
| train_job = self._data_manager.get_train_job(train_id) | |||
| except exceptions.TrainJobNotExistError: | |||
| logger.warning('Train job %s not existed', train_id) | |||
| continue | |||
| if train_job.cache_status == CacheStatus.NOT_IN_CACHE: | |||
| self._data_manager.cache_train_job(train_id) | |||
| # Update loader cache status to CACHING for consistency in response. | |||
| train_job.cache_status = CacheStatus.CACHING | |||
| cache_result.append({ | |||
| 'train_id': train_id, | |||
| 'cache_status': brief_train_job.cache_status.value, | |||
| }) | |||
| cache_result.append(dict( | |||
| train_id=train_id, | |||
| cache_status=train_job.cache_status.value, | |||
| )) | |||
| return cache_result | |||