# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Management of all events data.

This module provides access to all loaders.
It can read events data through the DataLoader.

This module also acts as a thread pool manager.
"""
import abc
import enum
import threading
import time
import datetime
import os
from typing import Iterable, Optional
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED

from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher

from mindinsight.conf import settings
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.enums import DataManagerStatus
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.exceptions import ParamValueError


@enum.unique
class _CacheStatus(enum.Enum):
    """Train job cache status."""
    NOT_IN_CACHE = "NOT_IN_CACHE"
    CACHING = "CACHING"
    CACHED = "CACHED"


class _BasicTrainJob:
    """
    Basic info about a train job.

    Args:
        train_id (str): ID of the train job.
        abs_summary_base_dir (str): The canonical path of the summary base directory. It should be the
            return value of realpath().
        abs_summary_dir (str): The canonical path of the summary directory. It should be the return value
            of realpath().
        create_time (DateTime): The create time of the summary directory.
        update_time (DateTime): The latest modify time of summary files directly in the summary directory.
    """
    def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time):
        self._train_id = train_id
        self._abs_summary_base_dir = abs_summary_base_dir
        self._abs_summary_dir = abs_summary_dir
        self._create_time = create_time
        self._update_time = update_time

    @property
    def summary_dir(self):
        """Get summary directory path."""
        return self._abs_summary_dir

    @property
    def train_id(self):
        """Get train id."""
        return self._train_id


class CachedTrainJob:
    """
    Cache item for BriefCacheManager.

    DetailCacheManager will also wrap its return value with this class.

    Args:
        basic_info (_BasicTrainJob): Basic info about the train job.
    """
    def __init__(self, basic_info: _BasicTrainJob):
        self._basic_info = basic_info
        self._last_access_time = datetime.datetime.utcnow()

        # Other cached content is stored here.
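        # For example, _DetailCacheManager.get_train_job() stores the datavisual
        # train job dict here under the DATAVISUAL_CACHE_KEY defined later in
        # this module.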
        self._content = {}

        self._cache_status = _CacheStatus.NOT_IN_CACHE

    @property
    def cache_status(self):
        """Get cache status."""
        return self._cache_status

    @cache_status.setter
    def cache_status(self, value):
        """Set cache status."""
        self._cache_status = value

    def update_access_time(self):
        """Update the last access time of this cache item."""
        self._last_access_time = datetime.datetime.utcnow()

    @property
    def last_access_time(self):
        """Get last access time for purposes such as LRU."""
        return self._last_access_time

    @property
    def summary_dir(self):
        """Get summary directory path."""
        return self._basic_info.summary_dir

    def set(self, key, value):
        """Set value to cache."""
        self._content[key] = value

    def get(self, key):
        """Get value from cache."""
        try:
            return self._content[key]
        except KeyError:
            raise ParamValueError("Invalid cache key({}).".format(key))

    @property
    def basic_info(self):
        """Get basic train job info."""
        return self._basic_info

    @basic_info.setter
    def basic_info(self, value):
        """Set basic train job info."""
        self._basic_info = value


class TrainJob:
    """
    Train job object.

    You must not create TrainJob objects manually.
    You should always get TrainJob objects from DataManager.

    Args:
        brief_train_job (CachedTrainJob): Brief info about the train job.
        detail_train_job (Optional[CachedTrainJob]): Detailed info about the train job. Default: None.
    """
    def __init__(self,
                 brief_train_job: CachedTrainJob,
                 detail_train_job: Optional[CachedTrainJob] = None):
        self._brief = brief_train_job
        self._detail = detail_train_job
        if self._detail is None:
            self._cache_status = _CacheStatus.NOT_IN_CACHE
        else:
            self._cache_status = self._detail.cache_status

    def has_detail(self):
        """Whether this train job has detailed info in cache."""
        return self._detail is not None

    def get_detail(self, key):
        """
        Get detail content.

        Args:
            key (Any): Cache key.

        Returns:
            Any, cache content.

        Raises:
            TrainJobDetailNotInCacheError: When this train job has no detail cache.
        """
        if not self.has_detail():
            raise exceptions.TrainJobDetailNotInCacheError()
        return self._detail.get(key)

    def get_brief(self, key):
        """
        Get brief content.

        Args:
            key (Any): Cache key.

        Returns:
            Any, cache content.
        """
        return self._brief.get(key)


class BaseCacheItemUpdater(abc.ABC):
    """Abstract base class for other modules to update cache content."""
    def update_item(self, cache_item: CachedTrainJob):
        """
        Update cache item in place.

        Args:
            cache_item (CachedTrainJob): The cache item to be processed.
        """
        raise NotImplementedError()


class _BaseCacheManager:
    """Base class for cache managers."""

    def __init__(self):
        # Use dict to remove duplicate updaters.
        self._updaters = {}

        self._lock = threading.Lock()
        # Key is train_id.
        self._cache_items = {}

    def size(self):
        """Get the number of used cache slots."""
        return len(self._cache_items)

    def register_cache_item_updater(self, updater: BaseCacheItemUpdater):
        """Register cache item updater."""
        self._updaters[updater.__class__.__qualname__] = updater

    def get_train_jobs(self):
        """Get a copy of the cached train jobs."""
        copied_train_jobs = dict(self._cache_items)
        return copied_train_jobs

    def get_train_job(self, train_id):
        """Get cached train job."""
        try:
            return self._cache_items[train_id]
        except KeyError:
            raise TrainJobNotExistError(train_id)

    def cache_train_job(self, train_id) -> bool:
        """
        Cache the given train job and update the train job's last access time.

        This method should return True if reload actions should be taken to cache the train job.

        Args:
            train_id (str): Train ID.
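
        Returns:
            bool, whether a reload action is needed to cache the train job.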
""" raise NotImplementedError() def delete_train_job(self, train_id): """Delete train job from cache.""" if train_id in self._cache_items: del self._cache_items[train_id] def has_content(self): """Whether this cache manager has train jobs.""" return bool(self._cache_items) def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]): """ Update cache according to given train jobs on disk. Different cache manager should implement different cache update policies in this method. Args: disk_train_jobs (Iterable[_BasicTrainJob]): Train jobs on disk. """ raise NotImplementedError() def _merge_with_disk(self, disk_train_jobs: Iterable[_BasicTrainJob]): """ Merge train jobs in cache with train jobs from disk This method will remove train jobs not on disk. Call this function with lock for thread safety. Args: disk_train_jobs (Iterable[_BasicTrainJob]): Basic train jobs info from disk. Returns: dict, a dict containing train jobs to be cached. """ new_cache_items = {} for train_job in disk_train_jobs: if train_job.train_id not in self._cache_items: new_cache_items[train_job.train_id] = CachedTrainJob(train_job) else: reused_train_job = self._cache_items[train_job.train_id] reused_train_job.basic_info = train_job new_cache_items[train_job.train_id] = reused_train_job return new_cache_items class _BriefCacheManager(_BaseCacheManager): """A cache manager that holds all disk train jobs on disk.""" def cache_train_job(self, train_id): """ Cache given train job. All disk train jobs are cached on every reload, so this method always return false. Args: train_id (str): Train Id. """ if train_id in self._cache_items: self._cache_items[train_id].update_access_time() return False def update_cache(self, disk_train_jobs): """Update cache.""" with self._lock: new_cache_items = self._merge_with_disk(disk_train_jobs) self._cache_items = new_cache_items for updater in self._updaters.values(): for cache_item in self._cache_items.values(): updater.update_item(cache_item) # Key for plugin tags. DATAVISUAL_PLUGIN_KEY = "tag_mapping" # Detail train job cache key for datavisual content. DATAVISUAL_CACHE_KEY = "datavisual" class _DetailCacheManager(_BaseCacheManager): """A cache manager that holds detailed info for most recently used train jobs.""" def __init__(self, loader_generators): super().__init__() self._loader_pool = {} self._deleted_id_list = [] self._loader_pool_mutex = threading.Lock() self._max_threads_count = 30 self._loader_generators = loader_generators def size(self): """ Get the number of items in this cache manager. To be implemented. Returns: int, the number of items in this cache manager. """ raise NotImplementedError() def loader_pool_size(self): """Get loader pool size.""" return len(self._loader_pool) def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]): """ Update cache. Will switch to using disk_train_jobs in the future. Args: disk_train_jobs (Iterable[_BasicTrainJob]): Basic info about train jobs on disk. 
""" self._generate_loaders() self._execute_load_data() def cache_train_job(self, train_id): """Cache given train job.""" loader = None need_reload = False with self._loader_pool_mutex: if self._is_loader_in_loader_pool(train_id, self._loader_pool): loader = self._loader_pool.get(train_id) if loader is None: for generator in self._loader_generators: tmp_loader = generator.generate_loader_by_train_id(train_id) if loader and loader.latest_update_time > tmp_loader.latest_update_time: continue loader = tmp_loader if loader is None: raise TrainJobNotExistError(train_id) self._add_loader(loader) need_reload = True self._update_loader_latest_update_time(loader.loader_id) return need_reload def get_train_jobs(self): """ Get train jobs To be implemented. """ def _add_loader(self, loader): """ Add a loader to load data. Args: loader (LoaderStruct): A object of `Loader`. """ if len(self._loader_pool) >= MAX_DATA_LOADER_SIZE: delete_number = len(self._loader_pool) - MAX_DATA_LOADER_SIZE + 1 sorted_loaders = sorted(self._loader_pool.items(), key=lambda loader: loader[1].latest_update_time) for index in range(delete_number): delete_loader_id = sorted_loaders[index][0] self._delete_loader(delete_loader_id) self._loader_pool.update({loader.loader_id: loader}) def _delete_loader(self, loader_id): """ Delete loader from loader pool by loader id. Args: loader_id (str): ID of loader. """ if self._loader_pool.get(loader_id) is not None: logger.debug("delete loader %s", loader_id) self._loader_pool.pop(loader_id) def _execute_loader(self, loader_id): """ Load data form data_loader. If there is something wrong by loading, add logs and delete the loader. Args: loader_id (str): An ID for `Loader`. """ try: with self._loader_pool_mutex: loader = self._loader_pool.get(loader_id, None) if loader is None: logger.debug("Loader %r has been deleted, will not load data.", loader_id) return loader.data_loader.load() except MindInsightException as ex: logger.warning("Data loader %r load data failed. " "Delete data_loader. Detail: %s", loader_id, ex) with self._loader_pool_mutex: self._delete_loader(loader_id) def _generate_loaders(self): """This function generates the loader from given path.""" loader_dict = {} for generator in self._loader_generators: loader_dict.update(generator.generate_loaders(self._loader_pool)) sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time) latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:] self._deal_loaders(latest_loaders) def _deal_loaders(self, latest_loaders): """ This function determines which loaders to keep or remove or added. It is based on the given dict of loaders. Args: latest_loaders (list[dict]): A list of . """ with self._loader_pool_mutex: for loader_id, loader in latest_loaders: if self._loader_pool.get(loader_id, None) is None: self._add_loader(loader) continue # If this loader was updated manually before, # its latest_update_time may bigger than update_time in summary. if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time: self._update_loader_latest_update_time(loader_id, loader.latest_update_time) def _execute_load_data(self): """Load data through multiple threads.""" threads_count = self._get_threads_count() if not threads_count: logger.info("Can not find any valid train log path to load, loader pool is empty.") return logger.info("Start to execute load data. 
        with ThreadPoolExecutor(max_workers=threads_count) as executor:
            futures = []
            loader_pool = self._get_snapshot_loader_pool()
            for loader_id in loader_pool:
                future = executor.submit(self._execute_loader, loader_id)
                futures.append(future)
            wait(futures, return_when=ALL_COMPLETED)

    def _get_threads_count(self):
        """
        Use the maximum number of threads available.

        Returns:
            int, number of threads.
        """
        threads_count = min(self._max_threads_count, len(self._loader_pool))
        return threads_count

    def delete_train_job(self, train_id):
        """
        Delete the train job with the given train id.

        Args:
            train_id (str): ID of the train job.
        """
        with self._loader_pool_mutex:
            self._delete_loader(train_id)

    def list_tensors(self, train_id, tag):
        """
        List tensors of the given train job and tag.

        If no tensor can be found for the given tag, an exception will be raised.

        Args:
            train_id (str): ID of the train job.
            tag (str): The tag name.

        Returns:
            NamedTuple, the tuple format is
                `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`.
                The value will contain the given tag data.
        """
        loader_pool = self._get_snapshot_loader_pool()
        if not self._is_loader_in_loader_pool(train_id, loader_pool):
            raise TrainJobNotExistError("Can not find the given train job in cache.")

        data_loader = loader_pool[train_id].data_loader
        events_data = data_loader.get_events_data()

        try:
            tensors = events_data.tensors(tag)
        except KeyError:
            error_msg = "Can not find any data in this train job by given tag."
            raise ParamValueError(error_msg)

        return tensors

    def _check_train_job_exist(self, train_id, loader_pool):
        """
        Check whether the train job exists; if not, raise an exception.

        Args:
            train_id (str): The given train job id.
            loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.

        Raises:
            TrainJobNotExistError: Can not find train job in data manager.
        """
        is_exist = False
        if train_id in loader_pool:
            return
        for generator in self._loader_generators:
            if generator.check_train_job_exist(train_id):
                is_exist = True
                break
        if not is_exist:
            raise TrainJobNotExistError("Can not find the train job in data manager.")

    def _is_loader_in_loader_pool(self, train_id, loader_pool):
        """
        Check whether the loader is in the given loader pool.

        Args:
            train_id (str): The given train job id.
            loader_pool (dict): See self._loader_pool.

        Returns:
            bool, True if the loader is in the loader pool.
        """
        if train_id in loader_pool:
            return True
        return False

    def _get_snapshot_loader_pool(self):
        """
        Create a snapshot of the data loader pool to avoid concurrent mutation and iteration issues.

        Returns:
            dict, a copy of `self._loader_pool`.
        """
        with self._loader_pool_mutex:
            return dict(self._loader_pool)

    def get_train_job(self, train_id):
        """
        Get train job by train ID.

        This method overrides the parent method.

        Args:
            train_id (str): Train ID of the train job.

        Returns:
            CachedTrainJob, a single train job; if no data can be found, None is returned.
        """
        self._check_train_job_exist(train_id, self._loader_pool)

        loader = self._get_loader(train_id)
        if loader is None:
            logger.warning("No valid summary log in train job %s, "
                           "or it is not in the cache.", train_id)
            return None

        train_job = loader.to_dict()
        train_job.pop('data_loader')

        plugin_data = {}
        for plugin_name in PluginNameEnum.list_members():
            job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name)
            if job is None:
                plugin_data[plugin_name] = []
            else:
                plugin_data[plugin_name] = job['tags']

        train_job.update({DATAVISUAL_PLUGIN_KEY: plugin_data})

        # Will fill basic_info value in future.
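        # Until that value is filled, the returned CachedTrainJob carries no
        # _BasicTrainJob, so callers should read content via get(DATAVISUAL_CACHE_KEY).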
        train_job_obj = CachedTrainJob(basic_info=None)
        train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job)

        # Will assign real value in future.
        train_job_obj.cache_status = _CacheStatus.CACHED

        return train_job_obj

    def _get_loader(self, train_id):
        """
        Get loader by train id.

        Args:
            train_id (str): Train ID.

        Returns:
            LoaderStruct, the loader.
        """
        loader = None
        with self._loader_pool_mutex:
            if self._is_loader_in_loader_pool(train_id, self._loader_pool):
                loader = self._loader_pool.get(train_id)

        return loader

    def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
        """
        Update the loader's latest_update_time.

        Args:
            loader_id (str): ID of the loader.
            latest_update_time (float): Timestamp. Defaults to the current time.
        """
        if latest_update_time is None:
            latest_update_time = time.time()
        self._loader_pool[loader_id].latest_update_time = latest_update_time

    def get_train_job_by_plugin(self, train_id, plugin_name):
        """
        Get a train job by train job id.

        If the given train job does not have the given plugin data, the tag list will be empty.

        Args:
            train_id (str): Get train job info by the given id.
            plugin_name (str): Get tags by the given plugin.

        Returns:
            TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}), a train job object.
        """
        self._check_train_job_exist(train_id, self._loader_pool)

        loader = self._get_loader(train_id)
        if loader is None:
            logger.warning("No valid summary log in train job %s, "
                           "or it is not in the cache.", train_id)
            return None

        name = loader.name
        data_loader = loader.data_loader

        tags = []
        try:
            events_data = data_loader.get_events_data()
            tags = events_data.list_tags_by_plugin(plugin_name)
        except KeyError:
            logger.debug("Plugin name %r does not exist "
                         "in train job %r, and set tags to empty list.", plugin_name, name)
        except AttributeError:
            logger.debug("Train job %r has been deleted or it has not loaded data, "
                         "and set tags to empty list.", name)

        result = dict(id=train_id, name=name, tags=tags)
        return result


class DataManager:
    """
    DataManager manages a pool of loaders that help access events data.

    Each loader handles the data of one set of events and corresponds to one events_data.
    The DataManager builds a pool including all the data_loaders.
    Each data_loader provides methods for extracting events information.
    """

    def __init__(self, summary_base_dir):
        """
        Initialize the pool of loaders and the dict of name-to-path.

        Args:
            summary_base_dir (str): Base summary directory.

        Note that self._status refers to `datavisual.common.enums.DataManagerStatus`.
        """
        self._summary_base_dir = os.path.realpath(summary_base_dir)
        self._status = DataManagerStatus.INIT.value
        self._status_mutex = threading.Lock()

        self._reload_interval = 3

        loader_generators = [DataLoaderGenerator(self._summary_base_dir)]
        self._detail_cache = _DetailCacheManager(loader_generators)
        self._brief_cache = _BriefCacheManager()

    def start_load_data(self,
                        reload_interval=settings.RELOAD_INTERVAL,
                        max_threads_count=MAX_DATA_LOADER_SIZE):
        """
        Start threads for loading data.

        Args:
            reload_interval (int): Interval in seconds between two reloads; 0 means load only once.
            max_threads_count (int): Max number of threads of execution.
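
        Examples:
            >>> # Minimal usage sketch; the summary path below is illustrative.
            >>> data_manager = DataManager(summary_base_dir="/path/to/summary_base_dir")
            >>> data_manager.start_load_data(reload_interval=3)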
""" logger.info("Start to load data, reload_interval: %s, " "max_threads_count: %s.", reload_interval, max_threads_count) DataManager.check_reload_interval(reload_interval) DataManager.check_max_threads_count(max_threads_count) self._reload_interval = reload_interval self._max_threads_count = max_threads_count thread = threading.Thread(target=self._reload_data, name='start_load_data_thread') thread.daemon = True thread.start() def _reload_data(self): """This function periodically loads the data.""" # Let gunicorn load other modules first. time.sleep(1) while True: self._load_data() if not self._reload_interval: break time.sleep(self._reload_interval) def reload_data(self): """ Reload the data once. This function needs to be used after `start_load_data` function. """ logger.debug("start to reload data") thread = threading.Thread(target=self._load_data, name='reload_data_thread') thread.daemon = False thread.start() def _load_data(self): """This function will load data once and ignore it if the status is loading.""" logger.info("Start to load data, reload interval: %r.", self._reload_interval) with self._status_mutex: if self.status == DataManagerStatus.LOADING.value: logger.debug("Current status is %s , will ignore to load data.", self.status) return self.status = DataManagerStatus.LOADING.value summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir) basic_train_jobs = [] for info in summaries_info: basic_train_jobs.append(_BasicTrainJob( train_id=info['relative_path'], abs_summary_base_dir=self._summary_base_dir, abs_summary_dir=os.path.realpath(os.path.join( self._summary_base_dir, info['relative_path'] )), create_time=info['create_time'], update_time=info['update_time'] )) self._brief_cache.update_cache(basic_train_jobs) self._detail_cache.update_cache(basic_train_jobs) if not self._brief_cache.has_content() and not self._detail_cache.has_content(): self.status = DataManagerStatus.INVALID.value else: self.status = DataManagerStatus.DONE.value logger.info("Load event data end, status: %r, and loader pool size is %r.", self.status, self._detail_cache.loader_pool_size()) @staticmethod def check_reload_interval(reload_interval): """ Check reload interval is valid. Args: reload_interval (int): Reload interval >= 0. """ if not isinstance(reload_interval, int): raise ParamValueError("The value of reload interval should be integer.") if reload_interval < 0: raise ParamValueError("The value of reload interval should be >= 0.") @staticmethod def check_max_threads_count(max_threads_count): """ Threads count should be a integer, and should > 0. Args: max_threads_count (int), should > 0. """ if not isinstance(max_threads_count, int): raise ParamValueError("The value of max threads count should be integer.") if max_threads_count <= 0: raise ParamValueError("The value of max threads count should be > 0.") def get_train_job_by_plugin(self, train_id, plugin_name): """ Get a train job by train job id. If the given train job does not has the given plugin data, the tag list will be empty. Args: train_id (str): Get train job info by the given id. plugin_name (str): Get tags by given plugin. Returns: TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}), a train job object. """ self._check_status_valid() return self._detail_cache.get_train_job_by_plugin(train_id, plugin_name) def delete_train_job(self, train_id, only_delete_from_cache=True): """ Delete train job with a train id. Args: train_id (str): ID for train job. 
""" if not only_delete_from_cache: raise NotImplementedError("Delete from both cache and disk is not supported.") self._brief_cache.delete_train_job(train_id) self._detail_cache.delete_train_job(train_id) def list_tensors(self, train_id, tag): """ List tensors of the given train job and tag. If the tensor can not find by the given tag, will raise exception. Args: train_id (str): ID for train job. tag (str): The tag name. Returns: NamedTuple, the tuple format is `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`. the value will contain the given tag data. """ self._check_status_valid() return self._detail_cache.list_tensors(train_id, tag) def _check_status_valid(self): """Check if the status is valid to load data.""" if self.status == DataManagerStatus.INIT.value: raise exceptions.SummaryLogIsLoading("Data is being loaded, current status: %s." % self._status) def get_train_job(self, train_id): """ Get train job by train ID. Args: train_id (str): Train ID for train job. Returns: dict, single train job, if can not find any data, will return None. """ self._check_status_valid() detail_train_job = self._detail_cache.get_train_job(train_id) brief_train_job = self._brief_cache.get_train_job(train_id) return TrainJob(brief_train_job, detail_train_job) def list_train_jobs(self): """ List train jobs. To be implemented. """ raise NotImplementedError() @property def status(self): """ Get the status of data manager. Returns: DataManagerStatus, the status of data manager. """ return self._status @status.setter def status(self, status): """Set data manger status.""" self._status = status def cache_train_job(self, train_id): """Cache given train job (async).""" brief_need_reload = self._brief_cache.cache_train_job(train_id) detail_need_reload = self._detail_cache.cache_train_job(train_id) if brief_need_reload or detail_need_reload: self.reload_data() def register_brief_cache_item_updater(self, updater: BaseCacheItemUpdater): """Register brief cache item updater for brief cache manager.""" self._brief_cache.register_cache_item_updater(updater) DATA_MANAGER = DataManager(settings.SUMMARY_BASE_DIR)