You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

data_manager.py 18 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Management of all events data.
  17. This module exists to all loaders.
  18. It can read events data through the DataLoader.
  19. This module also acts as a thread pool manager.
  20. """
  21. import threading
  22. import time
  23. from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED
  24. from mindinsight.conf import settings
  25. from mindinsight.datavisual.common import exceptions
  26. from mindinsight.datavisual.common.log import logger
  27. from mindinsight.datavisual.common.enums import DataManagerStatus
  28. from mindinsight.datavisual.common.enums import PluginNameEnum
  29. from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
  30. from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
  31. from mindinsight.utils.exceptions import MindInsightException
  32. from mindinsight.utils.exceptions import ParamValueError
  33. class DataManager:
  34. """
  35. DataManager manages a pool of loader which help access events data.
  36. Each loader helps deal the data of the events.
  37. A loader corresponds to an events_data.
  38. The DataManager build a pool including all the data_loader.
  39. The data_loader provides extracting
  40. method to get the information of events.
  41. """
  42. def __init__(self, loader_generators):
  43. """
  44. Initialize the pool of loader and the dict of name-to-path.
  45. Args:
  46. loader_generators (list[LoaderGenerator]): Loader generators help generate loaders.
  47. self._status: Refer `datavisual.common.enums.DataManagerStatus`.
  48. self._loader_pool: {'loader_id': <LoaderStruct>}.
  49. """
  50. self._loader_pool = {}
  51. self._deleted_id_list = []
  52. self._status = DataManagerStatus.INIT.value
  53. self._status_mutex = threading.Lock()
  54. self._loader_pool_mutex = threading.Lock()
  55. self._max_threads_count = 30
  56. self._reload_interval = 3
  57. self._loader_generators = loader_generators
  58. def _add_loader(self, loader):
  59. """
  60. Add a loader to load data.
  61. Args:
  62. loader (LoaderStruct): A object of `Loader`.
  63. """
  64. if len(self._loader_pool) >= MAX_DATA_LOADER_SIZE:
  65. delete_number = len(self._loader_pool) - MAX_DATA_LOADER_SIZE + 1
  66. sorted_loaders = sorted(self._loader_pool.items(),
  67. key=lambda loader: loader[1].latest_update_time)
  68. for index in range(delete_number):
  69. delete_loader_id = sorted_loaders[index][0]
  70. self._delete_loader(delete_loader_id)
  71. self._loader_pool.update({loader.loader_id: loader})
  72. def _delete_loader(self, loader_id):
  73. """
  74. Delete loader from loader pool by loader id.
  75. Args:
  76. loader_id (str): ID of loader.
  77. """
  78. if self._loader_pool.get(loader_id) is not None:
  79. logger.debug("delete loader %s", loader_id)
  80. self._loader_pool.pop(loader_id)
  81. def _execute_loader(self, loader_id):
  82. """
  83. Load data form data_loader.
  84. If there is something wrong by loading, add logs and delete the loader.
  85. Args:
  86. loader_id (str): An ID for `Loader`.
  87. """
  88. try:
  89. with self._loader_pool_mutex:
  90. loader = self._loader_pool.get(loader_id, None)
  91. if loader is None:
  92. logger.debug("Loader %r has been deleted, will not load data.", loader_id)
  93. return
  94. loader.data_loader.load()
  95. except MindInsightException as ex:
  96. logger.warning("Data loader %r load data failed. "
  97. "Delete data_loader. Detail: %s", loader_id, ex)
  98. with self._loader_pool_mutex:
  99. self._delete_loader(loader_id)
  100. def start_load_data(self,
  101. reload_interval=settings.RELOAD_INTERVAL,
  102. max_threads_count=MAX_DATA_LOADER_SIZE):
  103. """
  104. Start threads for loading data.
  105. Args:
  106. reload_interval (int): Time to reload data once.
  107. max_threads_count (int): Max number of threads of execution.
  108. """
  109. logger.info("Start to load data, reload_interval: %s, "
  110. "max_threads_count: %s.", reload_interval, max_threads_count)
  111. DataManager.check_reload_interval(reload_interval)
  112. DataManager.check_max_threads_count(max_threads_count)
  113. self._reload_interval = reload_interval
  114. self._max_threads_count = max_threads_count
  115. thread = threading.Thread(target=self._reload_data,
  116. name='start_load_data_thread')
  117. thread.daemon = True
  118. thread.start()
  119. def _reload_data(self):
  120. """This function periodically loads the data."""
  121. # Let gunicorn load other modules first.
  122. time.sleep(1)
  123. while True:
  124. self._load_data()
  125. if not self._reload_interval:
  126. break
  127. time.sleep(self._reload_interval)
  128. def reload_data(self):
  129. """
  130. Reload the data once.
  131. This function needs to be used after `start_load_data` function.
  132. """
  133. logger.debug("start to reload data")
  134. thread = threading.Thread(target=self._load_data,
  135. name='reload_data_thread')
  136. thread.daemon = False
  137. thread.start()
  138. def _load_data(self):
  139. """This function will load data once and ignore it if the status is loading."""
  140. logger.info("Start to load data, reload interval: %r.", self._reload_interval)
  141. with self._status_mutex:
  142. if self.status == DataManagerStatus.LOADING.value:
  143. logger.debug("Current status is %s , will ignore to load data.", self.status)
  144. return
  145. self.status = DataManagerStatus.LOADING.value
  146. self._generate_loaders()
  147. self._execute_load_data()
  148. if not self._loader_pool:
  149. self.status = DataManagerStatus.INVALID.value
  150. else:
  151. self.status = DataManagerStatus.DONE.value
  152. logger.info("Load event data end, status: %r, and loader pool size is %r.",
  153. self.status, len(self._loader_pool))
  154. def _generate_loaders(self):
  155. """This function generates the loader from given path."""
  156. loader_dict = {}
  157. for generator in self._loader_generators:
  158. loader_dict.update(generator.generate_loaders(self._loader_pool))
  159. sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time)
  160. latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:]
  161. self._deal_loaders(latest_loaders)
  162. def _deal_loaders(self, latest_loaders):
  163. """
  164. This function determines which loaders to keep or remove or added.
  165. It is based on the given dict of loaders.
  166. Args:
  167. latest_loaders (list[dict]): A list of <loader_id: LoaderStruct>.
  168. """
  169. with self._loader_pool_mutex:
  170. for loader_id, loader in latest_loaders:
  171. if self._loader_pool.get(loader_id, None) is None:
  172. self._add_loader(loader)
  173. continue
  174. # If this loader was updated manually before,
  175. # its latest_update_time may bigger than update_time in summary.
  176. if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time:
  177. self._update_loader_latest_update_time(loader_id, loader.latest_update_time)
  178. def _execute_load_data(self):
  179. """Load data through multiple threads."""
  180. threads_count = self._get_threads_count()
  181. if not threads_count:
  182. logger.info("Can not find any valid train log path to load, loader pool is empty.")
  183. return
  184. logger.info("Start to execute load data. threads_count: %s.", threads_count)
  185. with ThreadPoolExecutor(max_workers=threads_count) as executor:
  186. futures = []
  187. loader_pool = self._get_snapshot_loader_pool()
  188. for loader_id in loader_pool:
  189. future = executor.submit(self._execute_loader, loader_id)
  190. futures.append(future)
  191. wait(futures, return_when=ALL_COMPLETED)
  192. @staticmethod
  193. def check_reload_interval(reload_interval):
  194. """
  195. Check reload interval is valid.
  196. Args:
  197. reload_interval (int): Reload interval >= 0.
  198. """
  199. if not isinstance(reload_interval, int):
  200. raise ParamValueError("The value of reload interval should be integer.")
  201. if reload_interval < 0:
  202. raise ParamValueError("The value of reload interval should be >= 0.")
  203. @staticmethod
  204. def check_max_threads_count(max_threads_count):
  205. """
  206. Threads count should be a integer, and should > 0.
  207. Args:
  208. max_threads_count (int), should > 0.
  209. """
  210. if not isinstance(max_threads_count, int):
  211. raise ParamValueError("The value of max threads count should be integer.")
  212. if max_threads_count <= 0:
  213. raise ParamValueError("The value of max threads count should be > 0.")
  214. def _get_threads_count(self):
  215. """
  216. Use the maximum number of threads available.
  217. Returns:
  218. int, number of threads.
  219. """
  220. threads_count = min(self._max_threads_count, len(self._loader_pool))
  221. return threads_count
  222. def get_train_job_by_plugin(self, train_id, plugin_name):
  223. """
  224. Get a train job by train job id.
  225. If the given train job does not has the given plugin data, the tag list will be empty.
  226. Args:
  227. train_id (str): Get train job info by the given id.
  228. plugin_name (str): Get tags by given plugin.
  229. Returns:
  230. TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}),
  231. a train job object.
  232. """
  233. self._check_status_valid()
  234. self._check_train_job_exist(train_id, self._loader_pool)
  235. loader = self._get_loader(train_id)
  236. if loader is None:
  237. logger.warning("No valid summary log in train job %s, "
  238. "or it is not in the cache.", train_id)
  239. return None
  240. name = loader.name
  241. data_loader = loader.data_loader
  242. tags = []
  243. try:
  244. events_data = data_loader.get_events_data()
  245. tags = events_data.list_tags_by_plugin(plugin_name)
  246. except KeyError:
  247. logger.debug("Plugin name %r does not exist "
  248. "in train job %r, and set tags to empty list.", plugin_name, name)
  249. except AttributeError:
  250. logger.debug("Train job %r has been deleted or it has not loaded data, "
  251. "and set tags to empty list.", name)
  252. result = dict(id=train_id, name=name, tags=tags)
  253. return result
  254. def delete_train_job(self, train_id):
  255. """
  256. Delete train job with a train id.
  257. Args:
  258. train_id (str): ID for train job.
  259. """
  260. with self._loader_pool_mutex:
  261. self._delete_loader(train_id)
  262. def list_tensors(self, train_id, tag):
  263. """
  264. List tensors of the given train job and tag.
  265. If the tensor can not find by the given tag, will raise exception.
  266. Args:
  267. train_id (str): ID for train job.
  268. tag (str): The tag name.
  269. Returns:
  270. NamedTuple, the tuple format is `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`.
  271. the value will contain the given tag data.
  272. """
  273. self._check_status_valid()
  274. loader_pool = self._get_snapshot_loader_pool()
  275. if not self._is_loader_in_loader_pool(train_id, loader_pool):
  276. raise ParamValueError("Can not find any data in loader pool about the train job.")
  277. data_loader = loader_pool[train_id].data_loader
  278. events_data = data_loader.get_events_data()
  279. try:
  280. tensors = events_data.tensors(tag)
  281. except KeyError:
  282. error_msg = "Can not find any data in this train job by given tag."
  283. raise ParamValueError(error_msg)
  284. return tensors
  285. def _check_train_job_exist(self, train_id, loader_pool):
  286. """
  287. Check train job exist, if not exist, will raise exception.
  288. Args:
  289. train_id (str): The given train job id.
  290. loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.
  291. Raises:
  292. ParamValueError: Can not found train job in data manager.
  293. """
  294. is_exist = False
  295. if train_id in loader_pool:
  296. return
  297. for generator in self._loader_generators:
  298. if generator.check_train_job_exist(train_id):
  299. is_exist = True
  300. break
  301. if not is_exist:
  302. raise ParamValueError("Can not find the train job in data manager.")
  303. def _is_loader_in_loader_pool(self, train_id, loader_pool):
  304. """
  305. Check train job exist, if not exist, return False. Else, return True.
  306. Args:
  307. train_id (str): The given train job id.
  308. loader_pool (dict): See self._loader_pool.
  309. Returns:
  310. bool, if loader in loader pool, return True.
  311. """
  312. if train_id in loader_pool:
  313. return True
  314. return False
  315. def _get_snapshot_loader_pool(self):
  316. """
  317. Create a snapshot of data loader pool to avoid concurrent mutation and iteration issues.
  318. Returns:
  319. dict, a copy of `self._loader_pool`.
  320. """
  321. with self._loader_pool_mutex:
  322. return dict(self._loader_pool)
  323. def _check_status_valid(self):
  324. """Check if the status is valid to load data."""
  325. if self.status == DataManagerStatus.INIT.value:
  326. raise exceptions.SummaryLogIsLoading("Data is being loaded, "
  327. "current status: %s." % self._status)
  328. def get_single_train_job(self, train_id, manual_update=False):
  329. """
  330. Get train job by train ID.
  331. Args:
  332. train_id (str): Train ID for train job.
  333. manual_update (bool): If manual update, True.
  334. Returns:
  335. dict, single train job, if can not find any data, will return None.
  336. """
  337. self._check_status_valid()
  338. self._check_train_job_exist(train_id, self._loader_pool)
  339. loader = self._get_loader(train_id, manual_update)
  340. if loader is None:
  341. logger.warning("No valid summary log in train job %s, "
  342. "or it is not in the cache.", train_id)
  343. return None
  344. train_job = loader.to_dict()
  345. train_job.pop('data_loader')
  346. plugin_data = {}
  347. for plugin_name in PluginNameEnum.list_members():
  348. job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name)
  349. if job is None:
  350. plugin_data[plugin_name] = []
  351. else:
  352. plugin_data[plugin_name] = job['tags']
  353. train_job.update({'tag_mapping': plugin_data})
  354. return train_job
  355. def _get_loader(self, train_id, manual_update=False):
  356. """
  357. Get loader by train id.
  358. Args:
  359. train_id (str): Train Id.
  360. manual_update (bool): If manual, True. Else False.
  361. Returns:
  362. LoaderStruct, the loader.
  363. """
  364. loader = None
  365. is_reload = False
  366. with self._loader_pool_mutex:
  367. if self._is_loader_in_loader_pool(train_id, self._loader_pool):
  368. loader = self._loader_pool.get(train_id)
  369. if manual_update and loader is None:
  370. for generator in self._loader_generators:
  371. tmp_loader = generator.generate_loader_by_train_id(train_id)
  372. if loader and loader.latest_update_time > tmp_loader.latest_update_time:
  373. continue
  374. loader = tmp_loader
  375. if loader is None:
  376. return None
  377. self._add_loader(loader)
  378. is_reload = True
  379. if manual_update:
  380. self._update_loader_latest_update_time(loader.loader_id)
  381. if is_reload:
  382. self.reload_data()
  383. return loader
  384. def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
  385. """
  386. Update loader with latest_update_time.
  387. Args:
  388. loader_id (str): ID of loader.
  389. latest_update_time (float): Timestamp.
  390. """
  391. if latest_update_time is None:
  392. latest_update_time = time.time()
  393. self._loader_pool[loader_id].latest_update_time = latest_update_time
  394. @property
  395. def status(self):
  396. """
  397. Get the status of data manager.
  398. Returns:
  399. DataManagerStatus, the status of data manager.
  400. """
  401. return self._status
  402. @status.setter
  403. def status(self, status):
  404. """Set data manger status."""
  405. self._status = status
# Default loader generators: scan the configured summary base directory.
_loader_generators = [DataLoaderGenerator(settings.SUMMARY_BASE_DIR)]
# Module-level singleton instance shared across MindInsight to access events data.
DATA_MANAGER = DataManager(_loader_generators)

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。

Contributors (1)