You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

db.py 27 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import pickle
  5. import sys
  6. import time
  7. from datetime import datetime
  8. import numpy as np
  9. import tensorflow as tf
  10. import gridfs
  11. import pymongo
  12. from tensorlayer import logging
  13. from tensorlayer.files import (
  14. assign_weights, del_folder, exists_or_mkdir, load_hdf5_to_weights, save_weights_to_hdf5, static_graph2net
  15. )
  class TensorHub(object):
      """MongoDB based manager for datasets, model architectures, parameters, logs and tasks.

      Parameters
      -------------
      ip : str
          Localhost or IP address.
      port : int
          Port number.
      dbname : str
          Database name.
      username : str or None
          User name. Pass ``None`` (or keep the default string ``'None'``) if you do not need authentication.
      password : str
          Password, only used when a user name is given.
      project_name : str or None
          Experiment key for this entire project, similar with the repository name of Github.

      Attributes
      ------------
      ip, port, dbname, username : see above
          Stored unchanged from the constructor arguments.
      project_name : str
          The given project name; if none is given, it is derived from ``sys.argv[0]``.
      db : pymongo database
          ``pymongo.MongoClient(ip, port)[dbname]``.
      dataset_fs, model_fs : gridfs.GridFS
          GridFS buckets used to store the pickled datasets and model weights.

      """

      # @deprecated_alias(db_name='dbname', user_name='username', end_support_version=2.1)
      def __init__(
          self, ip='localhost', port=27017, dbname='dbname', username='None', password='password', project_name=None
      ):
          self.ip = ip
          self.port = port
          self.dbname = dbname
          self.username = username

          print("[Database] Initializing ...")
          # connect mongodb
          client = pymongo.MongoClient(ip, port)
          self.db = client[dbname]

          # The default value is the *string* 'None', so both it and the real None mean "no credentials".
          # (The previous check was inverted and only authenticated when no user name was given.)
          if username is not None and username != 'None':
              # NOTE(review): Database.authenticate() was removed in PyMongo 4.0; newer drivers take the
              # credentials as MongoClient(..., username=, password=, authSource=) - confirm the driver version.
              self.db.authenticate(username, password)
          else:
              print("[Database] No username given, it works if authentication is not required")

          if project_name is None:
              # NOTE(review): splitting on the first '.' yields '' for paths such as './script.py'; kept as is
              # because project_name is used as a database key for already stored documents.
              self.project_name = sys.argv[0].split('.')[0]
              print("[Database] No project_name given, use {}".format(self.project_name))
          else:
              self.project_name = project_name

          # define file system (Buckets)
          self.dataset_fs = gridfs.GridFS(self.db, collection="datasetFilesystem")
          self.model_fs = gridfs.GridFS(self.db, collection="modelfs")
          # self.params_fs = gridfs.GridFS(self.db, collection="parametersFilesystem")
          # self.architecture_fs = gridfs.GridFS(self.db, collection="architectureFilesystem")

          print("[Database] Connected ")
          _s = "[Database] Info:\n"
          _s += " ip : {}\n".format(self.ip)
          _s += " port : {}\n".format(self.port)
          _s += " dbname : {}\n".format(self.dbname)
          _s += " username : {}\n".format(self.username)
          _s += " password : {}\n".format("*******")  # never echo the real password
          _s += " project_name : {}\n".format(self.project_name)
          self._s = _s
          print(self._s)

      def __str__(self):
          """Return the connection summary built in ``__init__``."""
          return self._s

      def _fill_project_info(self, args):
          """Add ``project_name`` to ``args`` in place (used for every document and query filter)."""
          args.update({'project_name': self.project_name})

      @staticmethod
      def _serialization(ps):
          """Serialize ``ps`` to bytes with pickle (highest protocol)."""
          return pickle.dumps(ps, protocol=pickle.HIGHEST_PROTOCOL)

      @staticmethod
      def _deserialization(ps):
          """Deserialize bytes produced by ``_serialization``.

          SECURITY: ``pickle.loads`` can execute arbitrary code; only use this with a trusted database.
          """
          return pickle.loads(ps)

      @staticmethod
      def _log_exception(e):
          """Log type, value, file name and line number of the exception currently being handled."""
          exc_type, exc_obj, exc_tb = sys.exc_info()
          fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
          logging.info("{} {} {} {} {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))

      # =========================== MODELS ================================
      def save_model(self, network=None, model_name='model', **kwargs):
          """Save model architecture and parameters into database, timestamp will be added automatically.

          Parameters
          ----------
          network : TensorLayer Model
              TensorLayer Model instance; its ``config`` and ``all_weights`` are stored.
          model_name : str
              The name/key of model.
          kwargs : other events
              Other events, such as name, accuracy, loss, step number and etc (optional).

          Examples
          ---------
          Save model architecture and parameters into database.
          >>> db.save_model(net, accuracy=0.8, loss=2.3, name='second_model')

          Load one model with parameters from database (run this in other script)
          >>> net = db.find_top_model(accuracy=0.8, loss=2.3)

          Find and load the latest model.
          >>> net = db.find_top_model(sort=[("time", pymongo.DESCENDING)])
          >>> net = db.find_top_model(sort=[("time", -1)])

          Find and load the oldest model.
          >>> net = db.find_top_model(sort=[("time", pymongo.ASCENDING)])
          >>> net = db.find_top_model(sort=[("time", 1)])

          Returns
          ---------
          boolean : True for success, False for fail.

          """
          kwargs.update({'model_name': model_name})
          self._fill_project_info(kwargs)  # put project_name into kwargs

          params = network.all_weights

          s = time.time()
          kwargs.update({'architecture': network.config})

          try:
              params_id = self.model_fs.put(self._serialization(params))
              kwargs.update({'params_id': params_id, 'time': datetime.utcnow()})
              self.db.Model.insert_one(kwargs)
              print("[Database] Save model: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
              return True
          except Exception as e:
              self._log_exception(e)
              print("[Database] Save model: FAIL")
              return False

      def find_top_model(self, sort=None, model_name='model', **kwargs):
          """Finds and returns a model architecture and its parameters from the database which matches the requirement.

          Parameters
          ----------
          sort : List of tuple
              PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
          model_name : str or None
              The name/key of model.
          kwargs : other events
              Other events, such as name, accuracy, loss, step number and etc (optional).

          Examples
          ---------
          - see ``save_model``.

          Returns
          ---------
          network : TensorLayer Model, or False
              The restored network with its weights assigned. False if no document matches or restoring fails.
              The extra fields of the document are currently NOT attached to the returned network.

          """
          kwargs.update({'model_name': model_name})
          self._fill_project_info(kwargs)

          s = time.time()

          d = self.db.Model.find_one(filter=kwargs, sort=sort)

          if d is None:
              print("[Database] FAIL! Cannot find model: {}".format(kwargs))
              return False

          params_id = d['params_id']
          graphs = d['architecture']
          _datetime = d['time']

          try:
              params = self._deserialization(self.model_fs.get(params_id).read())
              # TODO : restore model and load weights
              network = static_graph2net(graphs)
              assign_weights(weights=params, network=network)

              pc = self.db.Model.find(kwargs)
              print(
                  "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format(
                      kwargs, sort, _datetime, round(time.time() - s, 2)
                  )
              )

              # FIXME : not sure what's this for
              # put all informations of model into the TL layer
              # for key in d:
              #     network.__dict__.update({"_%s" % key: d[key]})

              # check whether more parameters match the requirement
              params_id_list = pc.distinct('params_id')
              n_params = len(params_id_list)
              if n_params != 1:
                  print(" Note that there are {} models match the kwargs".format(n_params))
              return network
          except Exception as e:
              self._log_exception(e)
              return False

      def delete_model(self, **kwargs):
          """Delete model.

          Parameters
          -----------
          kwargs : logging information
              Find items to delete, leave it empty to delete all models of this project.

          """
          self._fill_project_info(kwargs)
          self.db.Model.delete_many(kwargs)
          logging.info("[Database] Delete Model SUCCESS")

      # =========================== DATASET ===============================
      def save_dataset(self, dataset=None, dataset_name=None, **kwargs):
          """Saves one dataset into database, timestamp will be added automatically.

          Parameters
          ----------
          dataset : any type
              The dataset you want to store (it is pickled).
          dataset_name : str
              The name of dataset.
          kwargs : other events
              Other events, such as description, author and etc (optional).

          Examples
          ----------
          Save dataset
          >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

          Get dataset
          >>> dataset = db.find_top_dataset('mnist')

          Returns
          ---------
          boolean : Return True if save success, otherwise, return False.

          Raises
          ------
          ValueError : if ``dataset_name`` is None.

          """
          self._fill_project_info(kwargs)
          if dataset_name is None:
              raise ValueError("dataset_name is None, please give a dataset name")
          kwargs.update({'dataset_name': dataset_name})

          s = time.time()
          try:
              dataset_id = self.dataset_fs.put(self._serialization(dataset))
              kwargs.update({'dataset_id': dataset_id, 'time': datetime.utcnow()})
              self.db.Dataset.insert_one(kwargs)
              print("[Database] Save dataset: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
              return True
          except Exception as e:
              self._log_exception(e)
              print("[Database] Save dataset: FAIL")
              return False

      def find_top_dataset(self, dataset_name=None, sort=None, **kwargs):
          """Finds and returns a dataset from the database which matches the requirement.

          Parameters
          ----------
          dataset_name : str
              The name of dataset.
          sort : List of tuple
              PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
          kwargs : other events
              Other events, such as description, author and etc (optional).

          Examples
          ---------
          Save dataset
          >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

          Get dataset
          >>> dataset = db.find_top_dataset('mnist')
          >>> datasets = db.find_datasets('mnist')

          Returns
          --------
          dataset : the dataset or False
              Return False if nothing found or if loading fails.

          Raises
          ------
          ValueError : if ``dataset_name`` is None.

          """
          self._fill_project_info(kwargs)
          if dataset_name is None:
              raise ValueError("dataset_name is None, please give a dataset name")
          kwargs.update({'dataset_name': dataset_name})

          s = time.time()

          d = self.db.Dataset.find_one(filter=kwargs, sort=sort)

          if d is None:
              print("[Database] FAIL! Cannot find dataset: {}".format(kwargs))
              return False

          dataset_id = d['dataset_id']
          try:
              dataset = self._deserialization(self.dataset_fs.get(dataset_id).read())
              pc = self.db.Dataset.find(kwargs)
              print("[Database] Find one dataset SUCCESS, {} took: {}s".format(kwargs, round(time.time() - s, 2)))

              # check whether more datasets match the requirement
              dataset_id_list = pc.distinct('dataset_id')
              n_dataset = len(dataset_id_list)
              if n_dataset != 1:
                  print(" Note that there are {} datasets match the requirement".format(n_dataset))
              return dataset
          except Exception as e:
              self._log_exception(e)
              return False

      def find_datasets(self, dataset_name=None, **kwargs):
          """Finds and returns all datasets from the database which matches the requirement.
          In some case, the data in a dataset can be stored separately for better management.

          Parameters
          ----------
          dataset_name : str
              The name/key of dataset.
          kwargs : other events
              Other events, such as description, author and etc (optional).

          Returns
          --------
          datasets : list of datasets, or False if nothing found.

          Raises
          ------
          ValueError : if ``dataset_name`` is None.

          """
          self._fill_project_info(kwargs)
          if dataset_name is None:
              raise ValueError("dataset_name is None, please give a dataset name")
          kwargs.update({'dataset_name': dataset_name})

          s = time.time()
          # find() always returns a cursor (never None), so "nothing found" has to be detected on the id list.
          dataset_id_list = self.db.Dataset.find(kwargs).distinct('dataset_id')

          if not dataset_id_list:
              print("[Database] FAIL! Cannot find any dataset: {}".format(kwargs))
              return False

          # you may have multiple Buckets files
          dataset_list = [
              self._deserialization(self.dataset_fs.get(dataset_id).read()) for dataset_id in dataset_id_list
          ]

          print("[Database] Find {} datasets SUCCESS, took: {}s".format(len(dataset_list), round(time.time() - s, 2)))
          return dataset_list

      def delete_datasets(self, **kwargs):
          """Delete datasets.

          Parameters
          -----------
          kwargs : logging information
              Find items to delete, leave it empty to delete all datasets of this project.

          """
          self._fill_project_info(kwargs)
          self.db.Dataset.delete_many(kwargs)
          logging.info("[Database] Delete Dataset SUCCESS")

      # =========================== LOGGING ===============================
      def save_training_log(self, **kwargs):
          """Saves the training log, timestamp will be added automatically.

          Parameters
          -----------
          kwargs : logging information
              Events, such as accuracy, loss, step number and etc.

          Examples
          ---------
          >>> db.save_training_log(accuracy=0.33, loss=0.98)

          """
          self._fill_project_info(kwargs)
          kwargs.update({'time': datetime.utcnow()})
          self.db.TrainLog.insert_one(kwargs)
          logging.info("[Database] train log: " + self._print_dict(kwargs))

      def save_validation_log(self, **kwargs):
          """Saves the validation log, timestamp will be added automatically.

          Parameters
          -----------
          kwargs : logging information
              Events, such as accuracy, loss, step number and etc.

          Examples
          ---------
          >>> db.save_validation_log(accuracy=0.33, loss=0.98)

          """
          self._fill_project_info(kwargs)
          kwargs.update({'time': datetime.utcnow()})
          self.db.ValidLog.insert_one(kwargs)
          logging.info("[Database] valid log: " + self._print_dict(kwargs))

      def save_testing_log(self, **kwargs):
          """Saves the testing log, timestamp will be added automatically.

          Parameters
          -----------
          kwargs : logging information
              Events, such as accuracy, loss, step number and etc.

          Examples
          ---------
          >>> db.save_testing_log(accuracy=0.33, loss=0.98)

          """
          self._fill_project_info(kwargs)
          kwargs.update({'time': datetime.utcnow()})
          self.db.TestLog.insert_one(kwargs)
          logging.info("[Database] test log: " + self._print_dict(kwargs))

      def delete_training_log(self, **kwargs):
          """Deletes training log.

          Parameters
          -----------
          kwargs : logging information
              Find items to delete, leave it empty to delete all log.

          Examples
          ---------
          Save training log
          >>> db.save_training_log(accuracy=0.33)
          >>> db.save_training_log(accuracy=0.44)

          Delete logs that match the requirement
          >>> db.delete_training_log(accuracy=0.33)

          Delete all logs
          >>> db.delete_training_log()

          """
          self._fill_project_info(kwargs)
          self.db.TrainLog.delete_many(kwargs)
          logging.info("[Database] Delete TrainLog SUCCESS")

      def delete_validation_log(self, **kwargs):
          """Deletes validation log.

          Parameters
          -----------
          kwargs : logging information
              Find items to delete, leave it empty to delete all log.

          Examples
          ---------
          - see ``delete_training_log``.

          """
          self._fill_project_info(kwargs)
          self.db.ValidLog.delete_many(kwargs)
          logging.info("[Database] Delete ValidLog SUCCESS")

      def delete_testing_log(self, **kwargs):
          """Deletes testing log.

          Parameters
          -----------
          kwargs : logging information
              Find items to delete, leave it empty to delete all log.

          Examples
          ---------
          - see ``delete_training_log``.

          """
          self._fill_project_info(kwargs)
          self.db.TestLog.delete_many(kwargs)
          logging.info("[Database] Delete TestLog SUCCESS")

      # def find_training_logs(self, **kwargs):
      #     pass
      #
      # def find_validation_logs(self, **kwargs):
      #     pass
      #
      # def find_testing_logs(self, **kwargs):
      #     pass

      # =========================== Task ===================================
      def create_task(self, task_name=None, script=None, hyper_parameters=None, saved_result_keys=None, **kwargs):
          """Uploads a task to the database, timestamp will be added automatically.

          Parameters
          -----------
          task_name : str
              The task name.
          script : str
              File name of the python script; its content is read and stored in the task document.
          hyper_parameters : dictionary
              The hyper parameters pass into the script.
          saved_result_keys : list of str
              The keys of the task results to keep in the database when the task finishes.
          kwargs : other parameters
              Users customized parameters such as description, version number.

          Examples
          -----------
          Uploads a task
          >>> db.create_task(task_name='mnist', script='example/tutorial_mnist_simple.py', description='simple tutorial')

          Finds and runs the latest task
          >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.DESCENDING)])
          >>> db.run_top_task(task_name='mnist', sort=[("time", -1)])

          Finds and runs the oldest task
          >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.ASCENDING)])
          >>> db.run_top_task(task_name='mnist', sort=[("time", 1)])

          Raises
          ------
          TypeError : if ``task_name`` or ``script`` is not a string.

          """
          if not isinstance(task_name, str):
              raise TypeError("task_name should be string")
          if not isinstance(script, str):
              raise TypeError("script should be string")
          if hyper_parameters is None:
              hyper_parameters = {}
          if saved_result_keys is None:
              saved_result_keys = []

          self._fill_project_info(kwargs)
          kwargs.update({'time': datetime.utcnow()})
          kwargs.update({'hyper_parameters': hyper_parameters})
          kwargs.update({'saved_result_keys': saved_result_keys})

          # Close the file deterministically instead of relying on garbage collection.
          with open(script, 'rb') as script_file:
              _script = script_file.read()

          # task_name is stored so that run_top_task / check_unfinished_task can select tasks by name.
          kwargs.update({'task_name': task_name, 'status': 'pending', 'script': _script, 'result': {}})
          self.db.Task.insert_one(kwargs)
          logging.info("[Database] Saved Task - task_name: {} script: {}".format(task_name, script))

      def run_top_task(self, task_name=None, sort=None, **kwargs):
          """Finds and runs a pending task that in the first of the sorting list.

          The selected task is marked ``running`` before execution and ``finished`` afterwards. If the script
          raises, the exception propagates and the task stays in the ``running`` state.

          SECURITY: the stored script is executed with ``exec`` in this module's globals; only use this with a
          trusted database.

          Parameters
          -----------
          task_name : str
              The task name; only tasks created with this name are considered.
          sort : List of tuple
              PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
          kwargs : other parameters
              Users customized parameters such as description, version number.

          Examples
          ---------
          Monitors the database and pull tasks to run
          >>> while True:
          >>>     print("waiting task from distributor")
          >>>     db.run_top_task(task_name='mnist', sort=[("time", -1)])
          >>>     time.sleep(1)

          Returns
          --------
          boolean : True if a task was found and run, False if no pending task matches.

          Raises
          ------
          TypeError : if ``task_name`` is not a string.

          """
          if not isinstance(task_name, str):
              raise TypeError("task_name should be string")
          self._fill_project_info(kwargs)
          # Previously task_name was validated but never used in the query, so any pending task matched.
          kwargs.update({'task_name': task_name, 'status': 'pending'})

          # find task and set status to running
          task = self.db.Task.find_one_and_update(kwargs, {'$set': {'status': 'running'}}, sort=sort)

          # get task info e.g. hyper parameters, python script
          if task is None:
              logging.info("[Database] Find Task FAIL: key: {} sort: {}".format(task_name, sort))
              return False

          logging.info("[Database] Find Task SUCCESS: key: {} sort: {}".format(task_name, sort))
          _datetime = task['time']
          _script = task['script']
          _id = task['_id']
          _hyper_parameters = task['hyper_parameters']
          _saved_result_keys = task['saved_result_keys']

          # Hyper parameters are injected as module globals so that the executed script can read them.
          logging.info(" hyper parameters:")
          for key in _hyper_parameters:
              globals()[key] = _hyper_parameters[key]
              logging.info(" {}: {}".format(key, _hyper_parameters[key]))

          # run task
          s = time.time()
          logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
          _script = _script.decode('utf-8')
          with tf.Graph().as_default():  # clear all TF graphs
              exec(_script, globals())

          # set status to finished
          self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})

          # collect the requested results from the globals written by the script
          __result = {}
          for _key in _saved_result_keys:
              logging.info(" result: {}={} {}".format(_key, globals()[_key], type(globals()[_key])))
              __result.update({"%s" % _key: globals()[_key]})
          self.db.Task.find_one_and_update(
              {'_id': _id}, {'$set': {
                  'result': __result
              }}, return_document=pymongo.ReturnDocument.AFTER
          )
          logging.info(
              "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s".format(
                  task_name, sort, _datetime,
                  time.time() - s
              )
          )
          return True

      def delete_tasks(self, **kwargs):
          """Delete tasks.

          Parameters
          -----------
          kwargs : logging information
              Find items to delete, leave it empty to delete all tasks of this project.

          Examples
          ---------
          >>> db.delete_tasks()

          """
          self._fill_project_info(kwargs)
          self.db.Task.delete_many(kwargs)
          logging.info("[Database] Delete Task SUCCESS")

      def check_unfinished_task(self, task_name=None, **kwargs):
          """Checks whether there are pending or running tasks with the given name.

          Parameters
          -----------
          task_name : str
              The task name.
          kwargs : other parameters
              Users customized parameters such as description, version number.

          Examples
          ---------
          Wait until all tasks finish in user's local console
          >>> while db.check_unfinished_task(task_name='mnist'):
          >>>     time.sleep(1)
          >>> print("all tasks finished")
          >>> net = db.find_top_model(sort=[("test_accuracy", -1)])

          Returns
          --------
          boolean : True if at least one pending or running task matches, False otherwise.

          Raises
          ------
          TypeError : if ``task_name`` is not a string.

          """
          if not isinstance(task_name, str):
              raise TypeError("task_name should be string")
          self._fill_project_info(kwargs)
          # Restrict the check to the requested task name (it used to be ignored by the query).
          kwargs.update({'task_name': task_name, '$or': [{'status': 'pending'}, {'status': 'running'}]})

          # find task
          task = self.db.Task.find(kwargs)
          task_id_list = task.distinct('_id')
          n_task = len(task_id_list)

          if n_task == 0:
              logging.info("[Database] No unfinished task - task_name: {}".format(task_name))
              return False

          logging.info("[Database] Find {} unfinished task - task_name: {}".format(n_task, task_name))
          return True

      @staticmethod
      def _print_dict(args):
          """Format ``args`` as ``"key: value / key: value / "``, skipping the MongoDB ``_id`` field."""
          # Compare by value: ``is not`` on a string literal only worked through string interning.
          return ''.join(str(key) + ": " + str(value) + " / " for key, value in args.items() if key != '_id')

TensorLayer3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库。计划兼容 TensorFlow、PyTorch、MindSpore、Paddle。