From 990800239bd3cc03539f0901997eb39468038829 Mon Sep 17 00:00:00 2001 From: Li Hongzhang Date: Wed, 9 Sep 2020 15:50:44 +0800 Subject: [PATCH] add summary loading switch mechanism --- mindinsight/backend/data_manager/__init__.py | 8 +- mindinsight/common/hook/datavisual.py | 29 -- mindinsight/conf/constants.py | 33 +- mindinsight/conf/defaults.py | 1 - mindinsight/datavisual/common/enums.py | 7 - .../datavisual/data_transform/data_loader.py | 9 +- .../datavisual/data_transform/data_manager.py | 330 ++++++------------ .../data_transform/ms_data_loader.py | 153 ++++---- .../processors/train_task_manager.py | 5 +- .../datavisual/utils/crc32/__init__.pyi | 7 +- mindinsight/datavisual/utils/tools.py | 2 +- mindinsight/utils/computing_resource_mgr.py | 2 +- tests/st/func/datavisual/conftest.py | 12 +- .../lineagemgr/cache/test_lineage_cache.py | 5 +- .../data_transform/test_data_loader.py | 7 +- .../data_transform/test_data_manager.py | 37 +- .../data_transform/test_ms_data_loader.py | 7 +- .../processors/test_graph_processor.py | 12 +- .../processors/test_histogram_processor.py | 7 +- .../processors/test_images_processor.py | 7 +- .../processors/test_scalars_processor.py | 7 +- .../processors/test_tensor_processor.py | 7 +- .../processors/test_train_task_manager.py | 6 +- tests/utils/tools.py | 17 - 24 files changed, 255 insertions(+), 462 deletions(-) diff --git a/mindinsight/backend/data_manager/__init__.py b/mindinsight/backend/data_manager/__init__.py index e05a5b53..2602d911 100644 --- a/mindinsight/backend/data_manager/__init__.py +++ b/mindinsight/backend/data_manager/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================ """Trigger data manager load.""" +import time -from mindinsight.conf import settings from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater @@ -31,5 +31,7 @@ def init_module(app): # Just to suppress pylint warning about unused arg. logger.debug("App: %s", type(app)) DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater()) - DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL), - max_threads_count=int(settings.MAX_THREADS_COUNT)) + # Let gunicorn load other modules first. + time.sleep(1) + + DATA_MANAGER.start_load_data(auto_reload=True) diff --git a/mindinsight/common/hook/datavisual.py b/mindinsight/common/hook/datavisual.py index cf3db546..eca174b5 100644 --- a/mindinsight/common/hook/datavisual.py +++ b/mindinsight/common/hook/datavisual.py @@ -17,29 +17,9 @@ import argparse import os -from mindinsight.conf import settings from mindinsight.utils.hook import BaseHook -class ReloadIntervalAction(argparse.Action): - """Reload interval action class definition.""" - - def __call__(self, parser, namespace, values, option_string=None): - """ - Inherited __call__ method from argparse.Action. - - Args: - parser (ArgumentParser): Passed-in argument parser. - namespace (Namespace): Namespace object to hold arguments. - values (object): Argument values with type depending on argument definition. - option_string (str): Option string for specific argument name. 
- """ - reload_interval = values - if reload_interval < 0: - parser.error(f'{option_string} should be greater than or equal to 0') - setattr(namespace, self.dest, reload_interval) - - class SummaryBaseDirAction(argparse.Action): """Summary base dir action class definition.""" @@ -67,15 +47,6 @@ class Hook(BaseHook): Args: parser (ArgumentParser): Specify parser to which arguments are added. """ - parser.add_argument( - '--reload-interval', - type=int, - action=ReloadIntervalAction, - help=""" - data reload time(Seconds). It should be greater than 0 or equal to 0. - If it equals 0, load data only once. Default value is %s seconds. - """ % settings.RELOAD_INTERVAL) - parser.add_argument( '--summary-base-dir', type=str, diff --git a/mindinsight/conf/constants.py b/mindinsight/conf/constants.py index 391182ac..38c77c6b 100644 --- a/mindinsight/conf/constants.py +++ b/mindinsight/conf/constants.py @@ -14,35 +14,7 @@ # ============================================================================ """Constants module for mindinsight settings.""" import logging -import math -import os - - -_DEFAULT_MAX_THREADS_COUNT = 15 - - -def _calc_default_max_processes_cnt(): - """Calc default processes count.""" - - # We need to make sure every summary directory has a process to load data. - min_cnt = _DEFAULT_MAX_THREADS_COUNT - # Do not use too many processes to avoid system problems (eg. out of memory). - max_cnt = 45 - used_cpu_ratio = 0.75 - - cpu_count = os.cpu_count() - if cpu_count is None: - return min_cnt - - processes_cnt = math.floor(cpu_count * used_cpu_ratio) - - if processes_cnt < min_cnt: - return min_cnt - - if processes_cnt > max_cnt: - return max_cnt - - return processes_cnt +import multiprocessing #################################### @@ -77,8 +49,7 @@ API_PREFIX = '/v1/mindinsight' #################################### # Datavisual default settings. #################################### -MAX_THREADS_COUNT = _DEFAULT_MAX_THREADS_COUNT -MAX_PROCESSES_COUNT = _calc_default_max_processes_cnt() +MAX_PROCESSES_COUNT = max(min(int(multiprocessing.cpu_count() * 0.75), 45), 1) MAX_TAG_SIZE_PER_EVENTS_DATA = 300 DEFAULT_STEP_SIZES_PER_TAG = 500 diff --git a/mindinsight/conf/defaults.py b/mindinsight/conf/defaults.py index b554869e..14c7fd98 100644 --- a/mindinsight/conf/defaults.py +++ b/mindinsight/conf/defaults.py @@ -29,5 +29,4 @@ URL_PATH_PREFIX = '' #################################### # Datavisual default settings. #################################### -RELOAD_INTERVAL = 3 # Seconds SUMMARY_BASE_DIR = os.getcwd() diff --git a/mindinsight/datavisual/common/enums.py b/mindinsight/datavisual/common/enums.py index 4c18e4eb..b5f452be 100644 --- a/mindinsight/datavisual/common/enums.py +++ b/mindinsight/datavisual/common/enums.py @@ -32,13 +32,6 @@ class DataManagerStatus(BaseEnum): INVALID = 'INVALID' -class DetailCacheManagerStatus(BaseEnum): - """Data manager status.""" - INIT = 'INIT' - LOADING = 'LOADING' - DONE = 'DONE' - - class PluginNameEnum(BaseEnum): """Plugin Name Enum.""" IMAGE = 'image' diff --git a/mindinsight/datavisual/data_transform/data_loader.py b/mindinsight/datavisual/data_transform/data_loader.py index fdba646f..349f023c 100644 --- a/mindinsight/datavisual/data_transform/data_loader.py +++ b/mindinsight/datavisual/data_transform/data_loader.py @@ -34,11 +34,14 @@ class DataLoader: self._summary_dir = summary_dir self._loader = None - def load(self, computing_resource_mgr): + def load(self, executor=None): """Load the data when loader is exist. 
Args: - computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance. + executor (Optional[Executor]): The executor instance. + + Returns: + bool, True if the loader is finished loading. """ if self._loader is None: @@ -53,7 +56,7 @@ class DataLoader: logger.warning("No valid files can be loaded, summary_dir: %s.", self._summary_dir) raise exceptions.SummaryLogPathInvalid() - self._loader.load(computing_resource_mgr) + return self._loader.load(executor) def get_events_data(self): """ diff --git a/mindinsight/datavisual/data_transform/data_manager.py b/mindinsight/datavisual/data_transform/data_manager.py index 292066bb..ca6801af 100644 --- a/mindinsight/datavisual/data_transform/data_manager.py +++ b/mindinsight/datavisual/data_transform/data_manager.py @@ -27,15 +27,13 @@ import time import os from typing import Iterable, Optional -from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED - from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.conf import settings from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.enums import CacheStatus from mindinsight.datavisual.common.log import logger -from mindinsight.datavisual.common.enums import DataManagerStatus, DetailCacheManagerStatus +from mindinsight.datavisual.common.enums import DataManagerStatus from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.exceptions import TrainJobNotExistError from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE @@ -294,7 +292,8 @@ class BaseCacheItemUpdater(abc.ABC): class _BaseCacheManager: """Base class for cache manager.""" - def __init__(self): + def __init__(self, summary_base_dir): + self._summary_base_dir = summary_base_dir # Use dict to remove duplicate updaters. self._updaters = {} @@ -342,40 +341,17 @@ class _BaseCacheManager: """Whether this cache manager has train jobs.""" return bool(self._cache_items) - def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]): + def update_cache(self, executor): """ Update cache according to given train jobs on disk. Different cache manager should implement different cache update policies in this method. Args: - disk_train_jobs (Iterable[_BasicTrainJob]): Train jobs on disk. + executor (Executor): The Executor instance. """ raise NotImplementedError() - def _merge_with_disk(self, disk_train_jobs: Iterable[_BasicTrainJob]): - """ - Merge train jobs in cache with train jobs from disk - - This method will remove train jobs not on disk. Call this function with lock for thread safety. - - Args: - disk_train_jobs (Iterable[_BasicTrainJob]): Basic train jobs info from disk. - - Returns: - dict, a dict containing train jobs to be cached. 
- """ - new_cache_items = {} - for train_job in disk_train_jobs: - if train_job.train_id not in self._cache_items: - new_cache_items[train_job.train_id] = CachedTrainJob(train_job) - else: - reused_train_job = self._cache_items[train_job.train_id] - reused_train_job.basic_info = train_job - new_cache_items[train_job.train_id] = reused_train_job - - return new_cache_items - class _BriefCacheManager(_BaseCacheManager): """A cache manager that holds all disk train jobs on disk.""" @@ -394,15 +370,57 @@ class _BriefCacheManager(_BaseCacheManager): return False - def update_cache(self, disk_train_jobs): + def update_cache(self, executor): """Update cache.""" + logger.info('Start to update BriefCacheManager.') + summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir) + + basic_train_jobs = [] + for info in summaries_info: + profiler = info['profiler'] + basic_train_jobs.append(_BasicTrainJob( + train_id=info['relative_path'], + abs_summary_base_dir=self._summary_base_dir, + abs_summary_dir=os.path.realpath(os.path.join( + self._summary_base_dir, + info['relative_path'] + )), + create_time=info['create_time'], + update_time=info['update_time'], + profiler_dir=None if profiler is None else profiler['directory'], + profiler_type="" if profiler is None else profiler['profiler_type'], + )) + with self._lock: - new_cache_items = self._merge_with_disk(disk_train_jobs) + new_cache_items = self._merge_with_disk(basic_train_jobs) self._cache_items = new_cache_items for updater in self._updaters.values(): for cache_item in self._cache_items.values(): updater.update_item(cache_item) + def _merge_with_disk(self, disk_train_jobs: Iterable[_BasicTrainJob]): + """ + Merge train jobs in cache with train jobs from disk + + This method will remove train jobs not on disk. Call this function with lock for thread safety. + + Args: + disk_train_jobs (Iterable[_BasicTrainJob]): Basic train jobs info from disk. + + Returns: + dict, a dict containing train jobs to be cached. 
+ """ + new_cache_items = {} + for train_job in disk_train_jobs: + if train_job.train_id not in self._cache_items: + new_cache_items[train_job.train_id] = CachedTrainJob(train_job) + else: + reused_train_job = self._cache_items[train_job.train_id] + reused_train_job.basic_info = train_job + new_cache_items[train_job.train_id] = reused_train_job + + return new_cache_items + @property def cache_items(self): """Get cache items.""" @@ -417,21 +435,14 @@ DATAVISUAL_CACHE_KEY = "datavisual" class _DetailCacheManager(_BaseCacheManager): """A cache manager that holds detailed info for most recently used train jobs.""" - def __init__(self, loader_generators): - super().__init__() + def __init__(self, summary_base_dir): + super().__init__(summary_base_dir) self._loader_pool = {} self._deleted_id_list = [] self._loader_pool_mutex = threading.Lock() - self._max_threads_count = 30 - self._loader_generators = loader_generators - self._status = DetailCacheManagerStatus.INIT.value + self._loader_generators = [DataLoaderGenerator(summary_base_dir)] self._loading_mutex = threading.Lock() - @property - def status(self): - """Get loading status, if it is loading, return True.""" - return self._status - def has_content(self): """Whether this cache manager has train jobs.""" return bool(self._loader_pool) @@ -451,37 +462,22 @@ class _DetailCacheManager(_BaseCacheManager): """Get loader pool size.""" return len(self._loader_pool) - def _load_in_cache(self): - """Generate and execute loaders.""" - def load(): - self._generate_loaders() - self._execute_load_data() - try: - exception_wrapper(load()) - except UnknownError as ex: - logger.warning("Load event data failed. Detail: %s.", str(ex)) - finally: - self._status = DetailCacheManagerStatus.DONE.value - logger.info("Load event data end, status: %r, and loader pool size is %r.", - self._status, self.loader_pool_size()) - - def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]): + def update_cache(self, executor): """ Update cache. Will switch to using disk_train_jobs in the future. Args: - disk_train_jobs (Iterable[_BasicTrainJob]): Basic info about train jobs on disk. - + executor (Executor): The Executor instance. """ with self._loading_mutex: - if self._status == DetailCacheManagerStatus.LOADING.value: - logger.debug("Event data is loading, and loader pool size is %r.", self.loader_pool_size()) - return - self._status = DetailCacheManagerStatus.LOADING.value - thread = threading.Thread(target=self._load_in_cache, name="load_detail_in_cache") - thread.start() + load_in_cache = exception_wrapper(self._execute_load_data) + try: + while not load_in_cache(executor): + yield + except UnknownError as ex: + logger.warning("Load event data failed. Detail: %s.", str(ex)) def cache_train_job(self, train_id): """Cache given train job.""" @@ -501,11 +497,6 @@ class _DetailCacheManager(_BaseCacheManager): if loader is None: raise TrainJobNotExistError(train_id) - # Update cache status loader to CACHING if loader is NOT_IN_CACHE - # before triggering the next interval. - if loader.cache_status == CacheStatus.NOT_IN_CACHE: - loader.cache_status = CacheStatus.CACHING - self._add_loader(loader) need_reload = True @@ -546,7 +537,7 @@ class _DetailCacheManager(_BaseCacheManager): logger.debug("delete loader %s", loader_id) self._loader_pool.pop(loader_id) - def _execute_loader(self, loader_id, computing_resource_mgr): + def _execute_loader(self, loader_id, executor): """ Load data form data_loader. 
@@ -554,20 +545,25 @@ class _DetailCacheManager(_BaseCacheManager): Args: loader_id (str): An ID for `Loader`. - computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance. + executor (Executor): The Executor instance. + + Returns: + bool, True if the loader is finished loading. """ try: with self._loader_pool_mutex: loader = self._loader_pool.get(loader_id, None) if loader is None: logger.debug("Loader %r has been deleted, will not load data.", loader_id) - return + return True - loader.data_loader.load(computing_resource_mgr) - - # Update loader cache status to CACHED. - # Loader with cache status CACHED should remain the same cache status. - loader.cache_status = CacheStatus.CACHED + loader.cache_status = CacheStatus.CACHING + if loader.data_loader.load(executor): + # Update loader cache status to CACHED. + # Loader with cache status CACHED should remain the same cache status. + loader.cache_status = CacheStatus.CACHED + return True + return False except MindInsightException as ex: logger.warning("Data loader %r load data failed. " @@ -575,6 +571,7 @@ class _DetailCacheManager(_BaseCacheManager): with self._loader_pool_mutex: self._delete_loader(loader_id) + return True def _generate_loaders(self): """This function generates the loader from given path.""" @@ -607,38 +604,14 @@ class _DetailCacheManager(_BaseCacheManager): if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time: self._update_loader_latest_update_time(loader_id, loader.latest_update_time) - def _execute_load_data(self): + def _execute_load_data(self, executor): """Load data through multiple threads.""" - threads_count = self._get_threads_count() - if not threads_count: - logger.info("Can not find any valid train log path to load, loader pool is empty.") - return - - logger.info("Start to execute load data. threads_count: %s.", threads_count) - - with ComputingResourceManager( - executors_cnt=threads_count, - max_processes_cnt=settings.MAX_PROCESSES_COUNT) as computing_resource_mgr: - - with ThreadPoolExecutor(max_workers=threads_count) as executor: - futures = [] - loader_pool = self._get_snapshot_loader_pool() - for loader_id in loader_pool: - future = executor.submit(self._execute_loader, loader_id, computing_resource_mgr) - futures.append(future) - wait(futures, return_when=ALL_COMPLETED) - - def _get_threads_count(self): - """ - Use the maximum number of threads available. - - Returns: - int, number of threads. - - """ - threads_count = min(self._max_threads_count, len(self._loader_pool)) - - return threads_count + self._generate_loaders() + loader_pool = self._get_snapshot_loader_pool() + loaded = True + for loader_id in loader_pool: + loaded = self._execute_loader(loader_id, executor) and loaded + return loaded def delete_train_job(self, train_id): """ @@ -864,11 +837,8 @@ class DataManager: self._status = DataManagerStatus.INIT.value self._status_mutex = threading.Lock() - self._reload_interval = 3 - - loader_generators = [DataLoaderGenerator(self._summary_base_dir)] - self._detail_cache = _DetailCacheManager(loader_generators) - self._brief_cache = _BriefCacheManager() + self._detail_cache = _DetailCacheManager(self._summary_base_dir) + self._brief_cache = _BriefCacheManager(self._summary_base_dir) # This lock is used to make sure that only one self._load_data_in_thread() is running. 
# Because self._load_data_in_thread() will create process pool when loading files, we can not @@ -880,126 +850,58 @@ class DataManager: """Get summary base dir.""" return self._summary_base_dir - def start_load_data(self, - reload_interval=settings.RELOAD_INTERVAL, - max_threads_count=MAX_DATA_LOADER_SIZE): + def start_load_data(self, auto_reload=False): """ Start threads for loading data. - Args: - reload_interval (int): Time to reload data once. - max_threads_count (int): Max number of threads of execution. - - """ - logger.info("Start to load data, reload_interval: %s, " - "max_threads_count: %s.", reload_interval, max_threads_count) - DataManager.check_reload_interval(reload_interval) - DataManager.check_max_threads_count(max_threads_count) - - self._reload_interval = reload_interval - self._max_threads_count = max_threads_count - - thread = threading.Thread(target=self._reload_data_in_thread, - name='start_load_data_thread') - thread.daemon = True - thread.start() - - def _reload_data_in_thread(self): - """This function periodically loads the data.""" - # Let gunicorn load other modules first. - time.sleep(1) - while True: - self._load_data_in_thread_wrapper() - - if not self._reload_interval: - break - time.sleep(self._reload_interval) - - def reload_data(self): - """ - Reload the data once. - - This function needs to be used after `start_load_data` function. + Returns: + Thread, the background Thread instance. """ - logger.debug("start to reload data") + logger.info("Start to load data") thread = threading.Thread(target=self._load_data_in_thread_wrapper, - name='reload_data_thread') - thread.daemon = False + name='start_load_data_thread', + args=(auto_reload,), + daemon=True) + thread.daemon = True thread.start() + return thread - def _load_data_in_thread_wrapper(self): + def _load_data_in_thread_wrapper(self, auto_reload): """Wrapper for load data in thread.""" + if self._load_data_lock.locked(): + return try: with self._load_data_lock: - exception_wrapper(self._load_data()) + while True: + exception_wrapper(self._load_data)() + if not auto_reload: + break except UnknownError as exc: # Not raising the exception here to ensure that data reloading does not crash. 
logger.warning(exc.message) def _load_data(self): """This function will load data once and ignore it if the status is loading.""" - logger.info("Start to load data, reload interval: %r.", self._reload_interval) with self._status_mutex: if self.status == DataManagerStatus.LOADING.value: logger.debug("Current status is %s , will ignore to load data.", self.status) return self.status = DataManagerStatus.LOADING.value - summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir) - - basic_train_jobs = [] - for info in summaries_info: - profiler = info['profiler'] - basic_train_jobs.append(_BasicTrainJob( - train_id=info['relative_path'], - abs_summary_base_dir=self._summary_base_dir, - abs_summary_dir=os.path.realpath(os.path.join( - self._summary_base_dir, - info['relative_path'] - )), - create_time=info['create_time'], - update_time=info['update_time'], - profiler_dir=None if profiler is None else profiler['directory'], - profiler_type="" if profiler is None else profiler['profiler_type'], - )) - - self._brief_cache.update_cache(basic_train_jobs) - self._detail_cache.update_cache(basic_train_jobs) - - if not self._brief_cache.has_content() and not self._detail_cache.has_content() \ - and self._detail_cache.status == DetailCacheManagerStatus.DONE.value: - self.status = DataManagerStatus.INVALID.value - else: - self.status = DataManagerStatus.DONE.value - - logger.info("Load brief data end, and loader pool size is %r.", self._detail_cache.loader_pool_size()) - - @staticmethod - def check_reload_interval(reload_interval): - """ - Check reload interval is valid. - - Args: - reload_interval (int): Reload interval >= 0. - """ - if not isinstance(reload_interval, int): - raise ParamValueError("The value of reload interval should be integer.") - - if reload_interval < 0: - raise ParamValueError("The value of reload interval should be >= 0.") - - @staticmethod - def check_max_threads_count(max_threads_count): - """ - Threads count should be a integer, and should > 0. - - Args: - max_threads_count (int), should > 0. 
- """ - if not isinstance(max_threads_count, int): - raise ParamValueError("The value of max threads count should be integer.") - if max_threads_count <= 0: - raise ParamValueError("The value of max threads count should be > 0.") + with ComputingResourceManager(executors_cnt=1, + max_processes_cnt=settings.MAX_PROCESSES_COUNT) as computing_resource_mgr: + with computing_resource_mgr.get_executor() as executor: + self._brief_cache.update_cache(executor) + for _ in self._detail_cache.update_cache(executor): + self._brief_cache.update_cache(executor) + executor.wait_all_tasks_finish() + with self._status_mutex: + if not self._brief_cache.has_content() and not self._detail_cache.has_content(): + self.status = DataManagerStatus.INVALID.value + else: + self.status = DataManagerStatus.DONE.value + + logger.info("Load brief data end, and loader pool size is %r.", self._detail_cache.loader_pool_size()) def get_train_job_by_plugin(self, train_id, plugin_name): """ @@ -1093,7 +995,7 @@ class DataManager: brief_need_reload = self._brief_cache.cache_train_job(train_id) detail_need_reload = self._detail_cache.cache_train_job(train_id) if brief_need_reload or detail_need_reload: - self.reload_data() + self.start_load_data() def register_brief_cache_item_updater(self, updater: BaseCacheItemUpdater): """Register brief cache item updater for brief cache manager.""" @@ -1107,9 +1009,5 @@ class DataManager: """Get brief train job.""" return self._brief_cache.get_train_job(train_id) - def get_detail_cache_status(self): - """Get detail status, just for ut/st.""" - return self._detail_cache.status - DATA_MANAGER = DataManager(settings.SUMMARY_BASE_DIR) diff --git a/mindinsight/datavisual/data_transform/ms_data_loader.py b/mindinsight/datavisual/data_transform/ms_data_loader.py index 250a573e..f96780dd 100644 --- a/mindinsight/datavisual/data_transform/ms_data_loader.py +++ b/mindinsight/datavisual/data_transform/ms_data_loader.py @@ -39,6 +39,7 @@ from mindinsight.datavisual.data_transform.tensor_container import TensorContain from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 from mindinsight.datavisual.utils import crc32 +from mindinsight.utils.computing_resource_mgr import ComputingResourceManager, Executor from mindinsight.utils.exceptions import UnknownError HEADER_SIZE = 8 @@ -81,16 +82,44 @@ class MSDataLoader: "we will reload all files in path %s.", self._summary_dir) self.__init__(self._summary_dir) - def load(self, computing_resource_mgr): + def load(self, executor=None): """ Load all log valid files. When the file is reloaded, it will continue to load from where it left off. Args: - computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance. + executor (Optional[executor]): The Executor instance. + + Returns: + bool, True if the train job is finished loading. """ logger.debug("Start to load data in ms data loader.") + if isinstance(executor, Executor): + return self._load(executor) + + if executor is not None: + raise TypeError("'executor' should be an Executor instance or None.") + + with ComputingResourceManager() as mgr: + with mgr.get_executor() as new_executor: + while not self._load(new_executor): + pass + new_executor.wait_all_tasks_finish() + return True + + def _load(self, executor): + """ + Load all log valid files. + + When the file is reloaded, it will continue to load from where it left off. 
+ + Args: + executor (executor): The Executor instance. + + Returns: + bool, True if the train job is finished loading. + """ filenames = self.filter_valid_files() if not filenames: logger.warning("No valid files can be loaded, summary_dir: %s.", self._summary_dir) @@ -99,9 +128,10 @@ class MSDataLoader: self._valid_filenames = filenames self._check_files_deleted(filenames, old_filenames) - with computing_resource_mgr.get_executor() as executor: - for parser in self._parser_list: - parser.parse_files(executor, filenames, events_data=self._events_data) + finished = True + for parser in self._parser_list: + finished = parser.parse_files(executor, filenames, events_data=self._events_data) and finished + return finished def filter_valid_files(self): """ @@ -127,9 +157,8 @@ class _Parser: """Parsed base class.""" def __init__(self, summary_dir): - self._latest_filename = '' - self._latest_mtime = 0 self._summary_dir = summary_dir + self._latest_filename = '' def parse_files(self, executor, filenames, events_data): """ @@ -142,12 +171,6 @@ class _Parser: """ raise NotImplementedError - def sort_files(self, filenames): - """Sort by modify time increments and filenames increments.""" - filenames = sorted(filenames, key=lambda file: ( - FileHandler.file_stat(FileHandler.join(self._summary_dir, file)).mtime, file)) - return filenames - def filter_files(self, filenames): """ Gets a list of files that this parsing class can parse. @@ -160,30 +183,14 @@ class _Parser: """ raise NotImplementedError - def _set_latest_file(self, filename): - """ - Check if the file's modification time is newer than the last time it was loaded, and if so, set the time. - - Args: - filename (str): The file name that needs to be checked and set. - - Returns: - bool, Returns True if the file was modified earlier than the last time it was loaded, or False. - """ - mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime - if mtime < self._latest_mtime or \ - (mtime == self._latest_mtime and filename <= self._latest_filename): - return False - - self._latest_mtime = mtime - self._latest_filename = filename - - return True - class _PbParser(_Parser): """This class is used to parse pb file.""" + def __init__(self, summary_dir): + super(_PbParser, self).__init__(summary_dir) + self._latest_mtime = 0 + def parse_files(self, executor, filenames, events_data): pb_filenames = self.filter_files(filenames) pb_filenames = self.sort_files(pb_filenames) @@ -198,6 +205,8 @@ class _PbParser(_Parser): continue events_data.add_tensor_event(tensor_event) + return False + return True def filter_files(self, filenames): """ @@ -208,9 +217,38 @@ class _PbParser(_Parser): Returns: list[str], filename list. + + Returns: + bool, True if all the pb files are finished loading. """ return list(filter(lambda filename: re.search(r'\.pb$', filename), filenames)) + def sort_files(self, filenames): + """Sort by modify time increments and filenames increments.""" + filenames = sorted(filenames, key=lambda file: ( + FileHandler.file_stat(FileHandler.join(self._summary_dir, file)).mtime, file)) + return filenames + + def _set_latest_file(self, filename): + """ + Check if the file's modification time is newer than the last time it was loaded, and if so, set the time. + + Args: + filename (str): The file name that needs to be checked and set. + + Returns: + bool, Returns True if the file was modified earlier than the last time it was loaded, or False. 
+ """ + mtime = FileHandler.file_stat(FileHandler.join(self._summary_dir, filename)).mtime + if mtime < self._latest_mtime or \ + (mtime == self._latest_mtime and filename <= self._latest_filename): + return False + + self._latest_mtime = mtime + self._latest_filename = filename + + return True + def _parse_pb_file(self, filename): """ Parse pb file and write content to `EventsData`. @@ -270,16 +308,18 @@ class _SummaryParser(_Parser): executor (Executor): The executor instance. filenames (list[str]): File name list. events_data (EventsData): The container of event data. + + Returns: + bool, True if all the summary files are finished loading. """ self._events_data = events_data summary_files = self.filter_files(filenames) summary_files = self.sort_files(summary_files) + if self._latest_filename in summary_files: + index = summary_files.index(self._latest_filename) + summary_files = summary_files[index:] for filename in summary_files: - if self._latest_filename and \ - (self._compare_summary_file(self._latest_filename, filename)): - continue - file_path = FileHandler.join(self._summary_dir, filename) if filename != self._latest_filename: @@ -291,15 +331,18 @@ class _SummaryParser(_Parser): if new_size == self._latest_file_size: continue - self._latest_file_size = new_size try: - self._load_single_file(self._summary_file_handler, executor) + if not self._load_single_file(self._summary_file_handler, executor): + self._latest_file_size = self._summary_file_handler.offset + else: + self._latest_file_size = new_size # Wait for data in this file to be processed to avoid loading multiple files at the same time. - executor.wait_all_tasks_finish() - logger.info("Parse summary file finished, file path: %s.", file_path) + logger.info("Parse summary file offset %d, file path: %s.", self._latest_file_size, file_path) + return False except UnknownError as ex: logger.warning("Parse summary file failed, detail: %r," "file path: %s.", str(ex), file_path) + return True def filter_files(self, filenames): """ @@ -322,6 +365,9 @@ class _SummaryParser(_Parser): Args: file_handler (FileHandler): A file handler. executor (Executor): The executor instance. + + Returns: + bool, True if the summary file is finished loading. """ while True: start_offset = file_handler.offset @@ -329,7 +375,7 @@ class _SummaryParser(_Parser): event_str = self._event_load(file_handler) if event_str is None: file_handler.reset_offset(start_offset) - break + return True if len(event_str) > MAX_EVENT_STRING: logger.warning("file_path: %s, event string: %d exceeds %d and drop it.", file_handler.file_path, len(event_str), MAX_EVENT_STRING) @@ -358,15 +404,16 @@ class _SummaryParser(_Parser): raise future.add_done_callback(_add_tensor_event_callback) + return False except exceptions.CRCFailedError: file_handler.reset_offset(start_offset) logger.warning("Check crc faild and ignore this file, file_path=%s, " "offset=%s.", file_handler.file_path, file_handler.offset) - break + return True except (OSError, DecodeError, exceptions.MindInsightException) as ex: logger.warning("Parse log file fail, and ignore this file, detail: %r," "file path: %s.", str(ex), file_handler.file_path) - break + return True except Exception as ex: logger.exception(ex) raise UnknownError(str(ex)) @@ -509,24 +556,6 @@ class _SummaryParser(_Parser): return ret_tensor_events - @staticmethod - def _compare_summary_file(current_file, dst_file): - """ - Compare the creation times of the two summary log files. - - Args: - current_file (str): Must be the summary log file path. 
- dst_file (str): Must be the summary log file path. - - Returns: - bool, returns True if the current file is new, or False if not. - """ - current_time = int(re.search(r'summary\.(\d+)', current_file)[1]) - dst_time = int(re.search(r'summary\.(\d+)', dst_file)[1]) - if current_time > dst_time or (current_time == dst_time and current_file > dst_file): - return True - return False - def sort_files(self, filenames): """Sort by creating time increments and filenames decrement.""" filenames = sorted(filenames, diff --git a/mindinsight/datavisual/processors/train_task_manager.py b/mindinsight/datavisual/processors/train_task_manager.py index 178739bf..4aef0e0e 100644 --- a/mindinsight/datavisual/processors/train_task_manager.py +++ b/mindinsight/datavisual/processors/train_task_manager.py @@ -189,10 +189,7 @@ class TrainTaskManager(BaseProcessor): logger.warning('Train job %s not existed', train_id) continue - if train_job.cache_status == CacheStatus.NOT_IN_CACHE: - self._data_manager.cache_train_job(train_id) - # Update loader cache status to CACHING for consistency in response. - train_job.cache_status = CacheStatus.CACHING + self._data_manager.cache_train_job(train_id) cache_result.append(dict( train_id=train_id, diff --git a/mindinsight/datavisual/utils/crc32/__init__.pyi b/mindinsight/datavisual/utils/crc32/__init__.pyi index 0e98770b..d9e930f4 100644 --- a/mindinsight/datavisual/utils/crc32/__init__.pyi +++ b/mindinsight/datavisual/utils/crc32/__init__.pyi @@ -13,11 +13,14 @@ # limitations under the License. # ============================================================================ """crc32 type stub module.""" +from typing import Union +ByteStr = Union[bytes, str] -def CheckValueAgainstData(crc_value: bytes, data: bytes, size: int) -> bool: + +def CheckValueAgainstData(crc_value: ByteStr, data: ByteStr, size: int) -> bool: """Check crc_value against new crc value from data to see if data is currupted.""" -def GetMaskCrc32cValue(data: bytes, n: int) -> int: +def GetMaskCrc32cValue(data: ByteStr, n: int) -> int: """Get masked crc value from data.""" diff --git a/mindinsight/datavisual/utils/tools.py b/mindinsight/datavisual/utils/tools.py index 05b4b4ea..e7cdac1e 100644 --- a/mindinsight/datavisual/utils/tools.py +++ b/mindinsight/datavisual/utils/tools.py @@ -221,7 +221,7 @@ def if_nan_inf_to_none(name, value): def exception_wrapper(func): def wrapper(*args, **kwargs): try: - func(*args, **kwargs) + return func(*args, **kwargs) except Exception as exc: logger.exception(exc) raise UnknownError(str(exc)) diff --git a/mindinsight/utils/computing_resource_mgr.py b/mindinsight/utils/computing_resource_mgr.py index c5035530..21aeed9a 100644 --- a/mindinsight/utils/computing_resource_mgr.py +++ b/mindinsight/utils/computing_resource_mgr.py @@ -37,7 +37,7 @@ class ComputingResourceManager: executors_cnt (int): Number of executors to be provided by this class. max_processes_cnt (int): Max number of processes to be used for computing. 
""" - def __init__(self, executors_cnt, max_processes_cnt): + def __init__(self, executors_cnt=1, max_processes_cnt=4): self._max_processes_cnt = max_processes_cnt self._executors_cnt = executors_cnt self._lock = threading.Lock() diff --git a/tests/st/func/datavisual/conftest.py b/tests/st/func/datavisual/conftest.py index 0fb0ec8c..f7556e25 100644 --- a/tests/st/func/datavisual/conftest.py +++ b/tests/st/func/datavisual/conftest.py @@ -29,7 +29,6 @@ from mindinsight.datavisual.data_transform.loader_generators.loader_generator im from mindinsight.datavisual.utils import tools from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done from . import constants from . import globals as gbl @@ -59,8 +58,7 @@ def init_summary_logs(): summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST, constants.SUMMARY_DIR_PREFIX) mock_data_manager = DataManager(summary_base_dir) - mock_data_manager.start_load_data(reload_interval=0) - check_loading_done(mock_data_manager) + mock_data_manager.start_load_data().join() summaries_metadata.update( log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_SECOND, @@ -72,10 +70,7 @@ def init_summary_logs(): summaries_metadata.update( log_operations.create_reservoir_log(summary_base_dir, constants.RESERVOIR_DIR_NAME, constants.RESERVOIR_STEP_NUM)) - mock_data_manager.start_load_data(reload_interval=0) - - # Sleep 1 sec to make sure the status of mock_data_manager changed to LOADING. - check_loading_done(mock_data_manager, first_sleep_time=1) + mock_data_manager.start_load_data().join() # Maximum number of loads is `MAX_DATA_LOADER_SIZE`. for i in range(len(summaries_metadata) - MAX_DATA_LOADER_SIZE): @@ -98,8 +93,7 @@ def populate_globals(): def client(): """This fixture is flask client.""" - gbl.mock_data_manager.start_load_data(reload_interval=0) - check_loading_done(gbl.mock_data_manager) + gbl.mock_data_manager.start_load_data().join() data_manager.DATA_MANAGER = gbl.mock_data_manager diff --git a/tests/st/func/lineagemgr/cache/test_lineage_cache.py b/tests/st/func/lineagemgr/cache/test_lineage_cache.py index 69bc6550..8e53f2d3 100644 --- a/tests/st/func/lineagemgr/cache/test_lineage_cache.py +++ b/tests/st/func/lineagemgr/cache/test_lineage_cache.py @@ -32,7 +32,7 @@ from ..test_model import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RUN, \ LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 from ..conftest import BASE_SUMMARY_DIR from .....ut.lineagemgr.querier import event_data -from .....utils.tools import check_loading_done, assert_equal_lineages +from .....utils.tools import assert_equal_lineages @pytest.mark.usefixtures("create_summary_dir") @@ -42,8 +42,7 @@ class TestModelApi(TestCase): def setup_class(cls): data_manager = DataManager(BASE_SUMMARY_DIR) data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater()) - data_manager.start_load_data(reload_interval=0) - check_loading_done(data_manager) + data_manager.start_load_data().join() cls._data_manger = data_manager diff --git a/tests/ut/datavisual/data_transform/test_data_loader.py b/tests/ut/datavisual/data_transform/test_data_loader.py index 359ef07c..5ab023d3 100644 --- a/tests/ut/datavisual/data_transform/test_data_loader.py +++ b/tests/ut/datavisual/data_transform/test_data_loader.py @@ -27,7 +27,6 @@ import pytest from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid from mindinsight.datavisual.data_transform import data_loader from 
mindinsight.datavisual.data_transform.data_loader import DataLoader -from mindinsight.utils.computing_resource_mgr import ComputingResourceManager from ..mock import MockLogger @@ -58,7 +57,7 @@ class TestDataLoader: """Test loading method with empty file list.""" loader = DataLoader(self._summary_dir) with pytest.raises(SummaryLogPathInvalid): - loader.load(ComputingResourceManager(1, 1)) + loader.load() assert 'No valid files can be loaded' in str(MockLogger.log_msg['warning']) def test_load_with_invalid_file_list(self): @@ -67,7 +66,7 @@ class TestDataLoader: self._generate_files(self._summary_dir, file_list) loader = DataLoader(self._summary_dir) with pytest.raises(SummaryLogPathInvalid): - loader.load(ComputingResourceManager(1, 1)) + loader.load() assert 'No valid files can be loaded' in str(MockLogger.log_msg['warning']) def test_load_success(self): @@ -78,6 +77,6 @@ class TestDataLoader: file_list = ['summary.001', 'summary.002'] self._generate_files(dir_path, file_list) dataloader = DataLoader(dir_path) - dataloader.load(ComputingResourceManager(1, 1)) + dataloader.load() assert dataloader._loader is not None shutil.rmtree(dir_path) diff --git a/tests/ut/datavisual/data_transform/test_data_manager.py b/tests/ut/datavisual/data_transform/test_data_manager.py index 32a3ddf8..f8e8d821 100644 --- a/tests/ut/datavisual/data_transform/test_data_manager.py +++ b/tests/ut/datavisual/data_transform/test_data_manager.py @@ -38,7 +38,6 @@ from mindinsight.datavisual.data_transform.loader_generators.loader_struct impor from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.utils.exceptions import ParamValueError -from ....utils.tools import check_loading_done from ..mock import MockLogger @@ -90,31 +89,9 @@ class TestDataManager: data_manager.logger = MockLogger mock_manager = data_manager.DataManager(summary_base_dir) - mock_manager.start_load_data(reload_interval=0) + mock_manager.start_load_data().join() - check_loading_done(mock_manager) - - assert MockLogger.log_msg['info'] == "Load event data end, status: 'DONE', " \ - "and loader pool size is '3'." - shutil.rmtree(summary_base_dir) - - @pytest.mark.parametrize('params', [{ - 'reload_interval': '30' - }, { - 'reload_interval': -1 - }, { - 'reload_interval': 30, - 'max_threads_count': '20' - }, { - 'reload_interval': 30, - 'max_threads_count': 0 - }]) - def test_start_load_data_with_invalid_params(self, params): - """Test start_load_data with invalid reload_interval or invalid max_threads_count.""" - summary_base_dir = tempfile.mkdtemp() - d_manager = DataManager(summary_base_dir) - with pytest.raises(ParamValueError): - d_manager.start_load_data(**params) + assert MockLogger.log_msg['info'] == "Load brief data end, and loader pool size is '3'." 
shutil.rmtree(summary_base_dir) def test_list_tensors_success(self): @@ -201,10 +178,9 @@ class TestDataManager: mock_generate_loaders.return_value = loader_dict mock_data_manager = data_manager.DataManager(summary_base_dir) - mock_data_manager._detail_cache._execute_load_data = Mock() + mock_data_manager._detail_cache._execute_loader = Mock() - mock_data_manager.start_load_data(reload_interval=0) - check_loading_done(mock_data_manager, 3) + mock_data_manager.start_load_data().join() current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys() assert sorted(current_loader_ids) == sorted(expected_loader_ids) @@ -215,11 +191,8 @@ class TestDataManager: expected_loader_ids.extend(list(loader_dict.keys())) expected_loader_ids = expected_loader_ids[-MAX_DATA_LOADER_SIZE:] - # Make sure to finish loading, make it init. - mock_data_manager._detail_cache._status = DataManagerStatus.INIT.value mock_generate_loaders.return_value = loader_dict - mock_data_manager.start_load_data(reload_interval=0) - check_loading_done(mock_data_manager) + mock_data_manager.start_load_data().join() current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys() assert sorted(current_loader_ids) == sorted(expected_loader_ids) diff --git a/tests/ut/datavisual/data_transform/test_ms_data_loader.py b/tests/ut/datavisual/data_transform/test_ms_data_loader.py index 41275e44..c8530615 100644 --- a/tests/ut/datavisual/data_transform/test_ms_data_loader.py +++ b/tests/ut/datavisual/data_transform/test_ms_data_loader.py @@ -30,7 +30,6 @@ from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader from mindinsight.datavisual.data_transform.ms_data_loader import _PbParser from mindinsight.datavisual.data_transform.events_data import TensorEvent from mindinsight.datavisual.common.enums import PluginNameEnum -from mindinsight.utils.computing_resource_mgr import ComputingResourceManager from ..mock import MockLogger from ....utils.log_generators.graph_pb_generator import create_graph_pb_file @@ -86,7 +85,7 @@ class TestMsDataLoader: write_file(file1, SCALAR_RECORD) ms_loader = MSDataLoader(summary_dir) ms_loader._latest_summary_filename = 'summary.00' - ms_loader.load(ComputingResourceManager(1, 1)) + ms_loader.load() shutil.rmtree(summary_dir) tag = ms_loader.get_events_data().list_tags_by_plugin('scalar') tensors = ms_loader.get_events_data().tensors(tag[0]) @@ -99,7 +98,7 @@ class TestMsDataLoader: file2 = os.path.join(summary_dir, 'summary.02') write_file(file2, SCALAR_RECORD) ms_loader = MSDataLoader(summary_dir) - ms_loader.load(ComputingResourceManager(1, 1)) + ms_loader.load() shutil.rmtree(summary_dir) assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning']) @@ -125,7 +124,7 @@ class TestMsDataLoader: summary_dir = tempfile.mkdtemp() create_graph_pb_file(output_dir=summary_dir, filename=filename) ms_loader = MSDataLoader(summary_dir) - ms_loader.load(ComputingResourceManager(1, 1)) + ms_loader.load() events_data = ms_loader.get_events_data() plugins = events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value) shutil.rmtree(summary_dir) diff --git a/tests/ut/datavisual/processors/test_graph_processor.py b/tests/ut/datavisual/processors/test_graph_processor.py index 231a9545..da0d33f3 100644 --- a/tests/ut/datavisual/processors/test_graph_processor.py +++ b/tests/ut/datavisual/processors/test_graph_processor.py @@ -35,7 +35,7 @@ from mindinsight.datavisual.utils import crc32 from mindinsight.utils.exceptions import ParamValueError from ....utils.log_operations 
import LogOperations -from ....utils.tools import check_loading_done, compare_result_with_file, delete_files_or_dirs +from ....utils.tools import compare_result_with_file, delete_files_or_dirs from ..mock import MockLogger @@ -74,10 +74,7 @@ class TestGraphProcessor: self._generated_path.append(summary_base_dir) self._mock_data_manager = data_manager.DataManager(summary_base_dir) - self._mock_data_manager.start_load_data(reload_interval=0) - - # wait for loading done - check_loading_done(self._mock_data_manager, time_limit=5) + self._mock_data_manager.start_load_data().join() @pytest.fixture(scope='function') def load_no_graph_record(self): @@ -93,10 +90,7 @@ class TestGraphProcessor: self._generated_path.append(summary_base_dir) self._mock_data_manager = data_manager.DataManager(summary_base_dir) - self._mock_data_manager.start_load_data(reload_interval=0) - - # wait for loading done - check_loading_done(self._mock_data_manager, time_limit=5) + self._mock_data_manager.start_load_data().join() @pytest.mark.usefixtures('load_graph_record') def test_get_nodes_with_not_exist_train_id(self): diff --git a/tests/ut/datavisual/processors/test_histogram_processor.py b/tests/ut/datavisual/processors/test_histogram_processor.py index 0e5a46ce..47813f1c 100644 --- a/tests/ut/datavisual/processors/test_histogram_processor.py +++ b/tests/ut/datavisual/processors/test_histogram_processor.py @@ -31,7 +31,7 @@ from mindinsight.datavisual.processors.histogram_processor import HistogramProce from mindinsight.datavisual.utils import crc32 from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs +from ....utils.tools import delete_files_or_dirs from ..mock import MockLogger @@ -72,10 +72,7 @@ class TestHistogramProcessor: self._generated_path.append(summary_base_dir) self._mock_data_manager = data_manager.DataManager(summary_base_dir) - self._mock_data_manager.start_load_data(reload_interval=0) - - # wait for loading done - check_loading_done(self._mock_data_manager, time_limit=5) + self._mock_data_manager.start_load_data().join() @pytest.mark.usefixtures('load_histogram_record') def test_get_histograms_with_not_exist_id(self): diff --git a/tests/ut/datavisual/processors/test_images_processor.py b/tests/ut/datavisual/processors/test_images_processor.py index 99846314..2ca3da01 100644 --- a/tests/ut/datavisual/processors/test_images_processor.py +++ b/tests/ut/datavisual/processors/test_images_processor.py @@ -31,7 +31,7 @@ from mindinsight.datavisual.processors.images_processor import ImageProcessor from mindinsight.datavisual.utils import crc32 from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs, get_image_tensor_from_bytes +from ....utils.tools import delete_files_or_dirs, get_image_tensor_from_bytes from ..mock import MockLogger @@ -81,10 +81,7 @@ class TestImagesProcessor: self._generated_path.append(summary_base_dir) self._mock_data_manager = data_manager.DataManager(summary_base_dir) - self._mock_data_manager.start_load_data(reload_interval=0) - - # wait for loading done - check_loading_done(self._mock_data_manager, time_limit=5) + self._mock_data_manager.start_load_data().join() @pytest.fixture(scope='function') def load_image_record(self): diff --git a/tests/ut/datavisual/processors/test_scalars_processor.py b/tests/ut/datavisual/processors/test_scalars_processor.py index 3dc29978..cf11cb3f 100644 --- a/tests/ut/datavisual/processors/test_scalars_processor.py +++ 
b/tests/ut/datavisual/processors/test_scalars_processor.py @@ -31,7 +31,7 @@ from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor from mindinsight.datavisual.utils import crc32 from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs +from ....utils.tools import delete_files_or_dirs from ..mock import MockLogger @@ -73,10 +73,7 @@ class TestScalarsProcessor: self._generated_path.append(summary_base_dir) self._mock_data_manager = data_manager.DataManager(summary_base_dir) - self._mock_data_manager.start_load_data(reload_interval=0) - - # wait for loading done - check_loading_done(self._mock_data_manager, time_limit=5) + self._mock_data_manager.start_load_data().join() @pytest.mark.usefixtures('load_scalar_record') def test_get_metadata_list_with_not_exist_id(self): diff --git a/tests/ut/datavisual/processors/test_tensor_processor.py b/tests/ut/datavisual/processors/test_tensor_processor.py index a3cdcef1..b41ed461 100644 --- a/tests/ut/datavisual/processors/test_tensor_processor.py +++ b/tests/ut/datavisual/processors/test_tensor_processor.py @@ -38,7 +38,7 @@ from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamMissError from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs +from ....utils.tools import delete_files_or_dirs from ..mock import MockLogger @@ -79,10 +79,7 @@ class TestTensorProcessor: self._generated_path.append(summary_base_dir) self._mock_data_manager = data_manager.DataManager(summary_base_dir) - self._mock_data_manager.start_load_data(reload_interval=0) - - # wait for loading done - check_loading_done(self._mock_data_manager, time_limit=3) + self._mock_data_manager.start_load_data().join() @pytest.mark.usefixtures('load_tensor_record') def test_get_tensors_with_not_exist_id(self): diff --git a/tests/ut/datavisual/processors/test_train_task_manager.py b/tests/ut/datavisual/processors/test_train_task_manager.py index fefce9f1..83d7e7c9 100644 --- a/tests/ut/datavisual/processors/test_train_task_manager.py +++ b/tests/ut/datavisual/processors/test_train_task_manager.py @@ -31,7 +31,7 @@ from mindinsight.datavisual.processors.train_task_manager import TrainTaskManage from mindinsight.datavisual.utils import crc32 from ....utils.log_operations import LogOperations -from ....utils.tools import check_loading_done, delete_files_or_dirs +from ....utils.tools import delete_files_or_dirs from ..mock import MockLogger @@ -97,9 +97,7 @@ class TestTrainTaskManager: self._generated_path.append(self._root_dir) self._mock_data_manager = data_manager.DataManager(self._root_dir) - self._mock_data_manager.start_load_data(reload_interval=0) - - check_loading_done(self._mock_data_manager, time_limit=30) + self._mock_data_manager.start_load_data().join() @pytest.mark.usefixtures('load_data') def test_get_single_train_task_with_not_exists_train_id(self): diff --git a/tests/utils/tools.py b/tests/utils/tools.py index 0d347f89..6c15cb39 100644 --- a/tests/utils/tools.py +++ b/tests/utils/tools.py @@ -18,7 +18,6 @@ Description: This file is used for some common util. 
import io import os import shutil -import time import json from urllib.parse import urlencode @@ -26,8 +25,6 @@ from urllib.parse import urlencode import numpy as np from PIL import Image -from mindinsight.datavisual.common.enums import DetailCacheManagerStatus - def get_url(url, params): """ @@ -54,20 +51,6 @@ def delete_files_or_dirs(path_list): os.remove(path) -def check_loading_done(data_manager, time_limit=15, first_sleep_time=0): - """If loading data for more than `time_limit` seconds, exit.""" - if first_sleep_time > 0: - time.sleep(first_sleep_time) - start_time = time.time() - while data_manager.get_detail_cache_status() != DetailCacheManagerStatus.DONE.value: - time_used = time.time() - start_time - if time_used > time_limit: - break - time.sleep(0.1) - continue - return bool(data_manager.get_detail_cache_status == DetailCacheManagerStatus.DONE.value) - - def get_image_tensor_from_bytes(image_string): """Get image tensor from bytes.""" img = Image.open(io.BytesIO(image_string))
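
A minimal usage sketch of the reworked loading entry point introduced by this patch, assuming an installed mindinsight package and a hypothetical summary base directory (the path below is illustrative, not part of the change). With this patch, start_load_data() no longer takes reload_interval/max_threads_count and instead returns the background loading thread, so one-shot callers (as in the updated tests) join the thread, while the long-running gunicorn backend passes auto_reload=True to keep reloading.

    from mindinsight.datavisual.data_transform.data_manager import DataManager

    # Hypothetical summary base directory; substitute a real one.
    data_manager = DataManager('/tmp/summary_base_dir')

    # One-shot load, as the updated unit tests do: wait on the returned thread.
    data_manager.start_load_data().join()

    # Long-running mode, as backend/data_manager/__init__.py does after this patch:
    # keep reloading summaries in a daemon thread.
    data_manager.start_load_data(auto_reload=True)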