
data_manager.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Management of all events data.

This module exists to manage all loaders.
It can read events data through the DataLoader.

This module also acts as a thread pool manager.
"""
import threading
import time

from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED

from mindinsight.conf import settings
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common.enums import DataManagerStatus
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.exceptions import ParamValueError


class DataManager:
    """
    DataManager manages a pool of loaders which help access events data.

    Each loader deals with the data of one set of events, and a loader
    corresponds to an events_data. The DataManager builds a pool containing
    all the data_loader objects, and each data_loader provides extraction
    methods to get information from its events.
    """

    def __init__(self, loader_generators):
        """
        Initialize the pool of loaders and the dict of name-to-path.

        Args:
            loader_generators (list[LoaderGenerator]): Loader generators help generate loaders.

        self._status: Refer to `datavisual.common.enums.DataManagerStatus`.
        self._loader_pool: {'loader_id': <LoaderStruct>}.
        """
        self._loader_pool = {}
        self._deleted_id_list = []
        self._status = DataManagerStatus.INIT.value
        self._status_mutex = threading.Lock()
        self._loader_pool_mutex = threading.Lock()
        self._max_threads_count = 30
        self._reload_interval = 3

        self._loader_generators = loader_generators

    def _add_loader(self, loader):
        """
        Add a loader to load data.

        Args:
            loader (LoaderStruct): An object of `Loader`.
        """
        if len(self._loader_pool) >= MAX_DATA_LOADER_SIZE:
            # Evict the least recently updated loaders to keep the pool within MAX_DATA_LOADER_SIZE.
            delete_number = len(self._loader_pool) - MAX_DATA_LOADER_SIZE + 1
            sorted_loaders = sorted(self._loader_pool.items(),
                                    key=lambda loader: loader[1].latest_update_time)
            for index in range(delete_number):
                delete_loader_id = sorted_loaders[index][0]
                self._delete_loader(delete_loader_id)
        self._loader_pool.update({loader.loader_id: loader})

    def _delete_loader(self, loader_id):
        """
        Delete loader from loader pool by loader id.

        Args:
            loader_id (str): ID of loader.
        """
        if self._loader_pool.get(loader_id) is not None:
            logger.debug("delete loader %s", loader_id)
            self._loader_pool.pop(loader_id)

    def _execute_loader(self, loader_id):
        """
        Load data from data_loader.

        If something goes wrong while loading, log it and delete the loader.

        Args:
            loader_id (str): An ID for `Loader`.
        """
        try:
            with self._loader_pool_mutex:
                loader = self._loader_pool.get(loader_id, None)
                if loader is None:
                    logger.debug("Loader %r has been deleted, will not load data.", loader_id)
                    return
            loader.data_loader.load()
        except MindInsightException as ex:
            logger.warning("Data loader %r load data failed. "
                           "Delete data_loader. Detail: %s", loader_id, ex)

            with self._loader_pool_mutex:
                self._delete_loader(loader_id)

    def start_load_data(self,
                        reload_interval=settings.RELOAD_INTERVAL,
                        max_threads_count=MAX_DATA_LOADER_SIZE):
        """
        Start threads for loading data.

        Args:
            reload_interval (int): Time to reload data once.
            max_threads_count (int): Max number of threads of execution.
        """
        logger.info("Start to load data, reload_interval: %s, "
                    "max_threads_count: %s.", reload_interval, max_threads_count)
        DataManager.check_reload_interval(reload_interval)
        DataManager.check_max_threads_count(max_threads_count)

        self._reload_interval = reload_interval
        self._max_threads_count = max_threads_count

        thread = threading.Thread(target=self._reload_data,
                                  name='start_load_data_thread')
        thread.daemon = True
        thread.start()

    def _reload_data(self):
        """This function periodically loads the data."""
        # Let gunicorn load other modules first.
        time.sleep(1)

        while True:
            self._load_data()

            if not self._reload_interval:
                break

            time.sleep(self._reload_interval)

    def reload_data(self):
        """
        Reload the data once.

        This function needs to be used after the `start_load_data` function.
        """
        logger.debug("start to reload data")
        thread = threading.Thread(target=self._load_data,
                                  name='reload_data_thread')
        thread.daemon = False
        thread.start()

    def _load_data(self):
        """Load data once; the call is ignored if the status is already LOADING."""
        logger.info("Start to load data, reload interval: %r.", self._reload_interval)
        with self._status_mutex:
            if self.status == DataManagerStatus.LOADING.value:
                logger.debug("Current status is %s , will ignore to load data.", self.status)
                return
            self.status = DataManagerStatus.LOADING.value

        self._generate_loaders()
        self._execute_load_data()

        if not self._loader_pool:
            self.status = DataManagerStatus.INVALID.value
        else:
            self.status = DataManagerStatus.DONE.value

        logger.info("Load event data end, status: %r, and loader pool size is %r.",
                    self.status, len(self._loader_pool))

    def _generate_loaders(self):
        """Generate loaders from the paths provided by the loader generators."""
        loader_dict = {}
        for generator in self._loader_generators:
            loader_dict.update(generator.generate_loaders(self._loader_pool))

        sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time)
        latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:]
        self._deal_loaders(latest_loaders)

    def _deal_loaders(self, latest_loaders):
        """
        Determine which loaders to keep, add, or update, based on the given loaders.

        Args:
            latest_loaders (list[dict]): A list of <loader_id: LoaderStruct>.
        """
        with self._loader_pool_mutex:
            for loader_id, loader in latest_loaders:
                if self._loader_pool.get(loader_id, None) is None:
                    self._add_loader(loader)
                    continue

                # If this loader was updated manually before, its latest_update_time
                # may be bigger than the update time recorded in the summary.
                if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time:
                    self._update_loader_latest_update_time(loader_id, loader.latest_update_time)

    def _execute_load_data(self):
        """Load data through multiple threads."""
        threads_count = self._get_threads_count()
        if not threads_count:
            logger.info("Can not find any valid train log path to load, loader pool is empty.")
            return

        logger.info("Start to execute load data. threads_count: %s.", threads_count)

        with ThreadPoolExecutor(max_workers=threads_count) as executor:
            futures = []
            loader_pool = self._get_snapshot_loader_pool()
            for loader_id in loader_pool:
                future = executor.submit(self._execute_loader, loader_id)
                futures.append(future)
            wait(futures, return_when=ALL_COMPLETED)

    @staticmethod
    def check_reload_interval(reload_interval):
        """
        Check whether the reload interval is valid.

        Args:
            reload_interval (int): Reload interval >= 0.
        """
        if not isinstance(reload_interval, int):
            raise ParamValueError("The value of reload interval should be integer.")

        if reload_interval < 0:
            raise ParamValueError("The value of reload interval should be >= 0.")

    @staticmethod
    def check_max_threads_count(max_threads_count):
        """
        Threads count should be an integer and should be > 0.

        Args:
            max_threads_count (int): Should be > 0.
        """
        if not isinstance(max_threads_count, int):
            raise ParamValueError("The value of max threads count should be integer.")
        if max_threads_count <= 0:
            raise ParamValueError("The value of max threads count should be > 0.")

    def _get_threads_count(self):
        """
        Use the maximum number of threads available.

        Returns:
            int, number of threads.
        """
        threads_count = min(self._max_threads_count, len(self._loader_pool))

        return threads_count

    def get_train_job_by_plugin(self, train_id, plugin_name):
        """
        Get a train job by train job id.

        If the given train job does not have the given plugin data, the tag list will be empty.

        Args:
            train_id (str): Get train job info by the given id.
            plugin_name (str): Get tags by given plugin.

        Returns:
            TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}),
                a train job object.
        """
        self._check_status_valid()
        self._check_train_job_exist(train_id, self._loader_pool)

        loader = self._get_loader(train_id)
        if loader is None:
            logger.warning("No valid summary log in train job %s, "
                           "or it is not in the cache.", train_id)
            return None

        name = loader.name
        data_loader = loader.data_loader

        tags = []
        try:
            events_data = data_loader.get_events_data()
            tags = events_data.list_tags_by_plugin(plugin_name)
        except KeyError:
            logger.debug("Plugin name %r does not exist "
                         "in train job %r, and set tags to empty list.", plugin_name, name)
        except AttributeError:
            logger.debug("Train job %r has been deleted or it has not loaded data, "
                         "and set tags to empty list.", name)

        result = dict(id=train_id, name=name, tags=tags)
        return result

    def delete_train_job(self, train_id):
        """
        Delete the train job with a train id.

        Args:
            train_id (str): ID for train job.
        """
        with self._loader_pool_mutex:
            self._delete_loader(train_id)

    def list_tensors(self, train_id, tag):
        """
        List tensors of the given train job and tag.

        If no tensor can be found for the given tag, an exception is raised.

        Args:
            train_id (str): ID for train job.
            tag (str): The tag name.

        Returns:
            NamedTuple, the tuple format is `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`.
                The value will contain the given tag data.
        """
        self._check_status_valid()
        loader_pool = self._get_snapshot_loader_pool()
        if not self._is_loader_in_loader_pool(train_id, loader_pool):
            raise TrainJobNotExistError("Can not find the given train job in cache.")

        data_loader = loader_pool[train_id].data_loader
        events_data = data_loader.get_events_data()

        try:
            tensors = events_data.tensors(tag)
        except KeyError:
            error_msg = "Can not find any data in this train job by given tag."
            raise ParamValueError(error_msg)

        return tensors

    def _check_train_job_exist(self, train_id, loader_pool):
        """
        Check whether the train job exists; raise an exception if it does not.

        Args:
            train_id (str): The given train job id.
            loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.

        Raises:
            TrainJobNotExistError: Can not find train job in data manager.
        """
        is_exist = False
        if train_id in loader_pool:
            return
        for generator in self._loader_generators:
            if generator.check_train_job_exist(train_id):
                is_exist = True
                break
        if not is_exist:
            raise TrainJobNotExistError("Can not find the train job in data manager.")

    def _is_loader_in_loader_pool(self, train_id, loader_pool):
        """
        Check whether the train job exists in the loader pool.

        Args:
            train_id (str): The given train job id.
            loader_pool (dict): See self._loader_pool.

        Returns:
            bool, True if the loader is in the loader pool.
        """
        if train_id in loader_pool:
            return True
        return False

    def _get_snapshot_loader_pool(self):
        """
        Create a snapshot of the data loader pool to avoid concurrent mutation and iteration issues.

        Returns:
            dict, a copy of `self._loader_pool`.
        """
        with self._loader_pool_mutex:
            return dict(self._loader_pool)

    def _check_status_valid(self):
        """Check if the status is valid to load data."""
        if self.status == DataManagerStatus.INIT.value:
            raise exceptions.SummaryLogIsLoading("Data is being loaded, current status: %s." % self._status)

    def get_single_train_job(self, train_id, manual_update=False):
        """
        Get train job by train ID.

        Args:
            train_id (str): Train ID for train job.
            manual_update (bool): True if this is a manual update.

        Returns:
            dict, the single train job; returns None if no data can be found.
        """
        self._check_status_valid()
        self._check_train_job_exist(train_id, self._loader_pool)

        loader = self._get_loader(train_id, manual_update)
        if loader is None:
            logger.warning("No valid summary log in train job %s, "
                           "or it is not in the cache.", train_id)
            return None

        train_job = loader.to_dict()
        train_job.pop('data_loader')

        plugin_data = {}
        for plugin_name in PluginNameEnum.list_members():
            job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name)
            if job is None:
                plugin_data[plugin_name] = []
            else:
                plugin_data[plugin_name] = job['tags']

        train_job.update({'tag_mapping': plugin_data})

        return train_job

    def _get_loader(self, train_id, manual_update=False):
        """
        Get loader by train id.

        Args:
            train_id (str): Train Id.
            manual_update (bool): True if manual, otherwise False.

        Returns:
            LoaderStruct, the loader.
        """
        loader = None
        is_reload = False
        with self._loader_pool_mutex:
            if self._is_loader_in_loader_pool(train_id, self._loader_pool):
                loader = self._loader_pool.get(train_id)

            if manual_update and loader is None:
                for generator in self._loader_generators:
                    tmp_loader = generator.generate_loader_by_train_id(train_id)
                    if loader and loader.latest_update_time > tmp_loader.latest_update_time:
                        continue
                    loader = tmp_loader

                if loader is None:
                    return None

                self._add_loader(loader)
                is_reload = True

            if manual_update:
                self._update_loader_latest_update_time(loader.loader_id)

        if is_reload:
            self.reload_data()

        return loader

    def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
        """
        Update loader with latest_update_time.

        Args:
            loader_id (str): ID of loader.
            latest_update_time (float): Timestamp.
        """
        if latest_update_time is None:
            latest_update_time = time.time()
        self._loader_pool[loader_id].latest_update_time = latest_update_time

    @property
    def status(self):
        """
        Get the status of the data manager.

        Returns:
            DataManagerStatus, the status of the data manager.
        """
        return self._status

    @status.setter
    def status(self, status):
        """Set data manager status."""
        self._status = status


_loader_generators = [DataLoaderGenerator(settings.SUMMARY_BASE_DIR)]
DATA_MANAGER = DataManager(_loader_generators)
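
The module-level DATA_MANAGER singleton above is what the MindInsight web layer queries. The following is a minimal usage sketch, not part of the file: it assumes MindInsight is installed, that summary files already exist under settings.SUMMARY_BASE_DIR, that the module is importable as mindinsight.datavisual.data_transform.data_manager (as the sibling imports in the file suggest), and that './run1' is a hypothetical train_id for one of those runs.

    import time

    from mindinsight.datavisual.common.enums import DataManagerStatus
    from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER

    # Load the summary data once; reload_interval=0 means no periodic reload.
    DATA_MANAGER.start_load_data(reload_interval=0, max_threads_count=4)

    # Loading happens in a daemon thread; wait until the first pass finishes.
    while DATA_MANAGER.status in (DataManagerStatus.INIT.value,
                                  DataManagerStatus.LOADING.value):
        time.sleep(1)

    # './run1' is a hypothetical train_id; an unknown id raises TrainJobNotExistError.
    train_job = DATA_MANAGER.get_single_train_job('./run1', manual_update=True)
    if train_job is not None:
        print(train_job['tag_mapping'])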

MindInsight provides MindSpore with simple and easy-to-use tuning and debugging capabilities. During training, data such as scalars, tensors, images, computational graphs, model hyperparameters, and training time can be recorded to files and then viewed and analyzed through the MindInsight visualization pages.
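
As a rough illustration of the recording side that feeds this module, the sketch below shows how a MindSpore training script of roughly this era might write summary files for MindInsight to load. It is an assumption-laden sketch, not code from this repository: `network`, `loss`, `optimizer`, and `train_dataset` are hypothetical placeholders, and the `SummaryCollector` import path may differ between MindSpore versions.

    # Sketch only; network, loss, optimizer and train_dataset are hypothetical
    # placeholders that a real training script would define.
    from mindspore import Model
    from mindspore.train.callback import SummaryCollector

    model = Model(network, loss_fn=loss, optimizer=optimizer)

    # Scalars, the computational graph and other summary data are written under
    # summary_dir during training; when that directory sits under
    # settings.SUMMARY_BASE_DIR, the DataManager above discovers and caches it.
    summary_collector = SummaryCollector(summary_dir='./summary_base_dir/run1')
    model.train(10, train_dataset, callbacks=[summary_collector])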