|
- #! /usr/bin/python
- # -*- coding: utf-8 -*-
-
- import os
- import pickle
- import sys
- import time
- from datetime import datetime
-
- import numpy as np
- import tensorflow as tf
-
- import gridfs
- import pymongo
- from tensorlayer import logging
- from tensorlayer.files import (
- assign_weights, del_folder, exists_or_mkdir, load_hdf5_to_weights, save_weights_to_hdf5, static_graph2net
- )
-
-
class TensorHub(object):
    """A MongoDB-based manager that helps you manage datasets, network architectures, parameters and logs.

    Parameters
    -------------
    ip : str
        Localhost or IP address.
    port : int
        Port number.
    dbname : str
        Database name.
    username : str or None
        User name. Set to None (or keep the default, the string ``'None'``) if you do not need
        authentication.
    password : str
        Password. Only used when a username is given.
    project_name : str or None
        Experiment key for this entire project, similar to the repository name on GitHub. If None,
        the path of the running script (``sys.argv[0]``) without its extension is used.

    Attributes
    ------------
    ip, port, dbname, username : see above
        The given input parameters.
    project_name : str
        The given project name, or the script-derived name if none was given.
    db : pymongo database
        The ``dbname`` database of a ``pymongo.MongoClient``.
    dataset_fs, model_fs : gridfs.GridFS
        GridFS buckets (collections ``datasetFilesystem`` and ``modelfs``) that hold the pickled
        datasets and model weights.

    Notes
    -----
    Datasets and weights read back from the database are unpickled (``pickle.loads``) and task
    scripts are executed (``exec``); only connect to databases whose contents you trust.
    """

    # @deprecated_alias(db_name='dbname', user_name='username', end_support_version=2.1)
    def __init__(
        self, ip='localhost', port=27017, dbname='dbname', username='None', password='password', project_name=None
    ):
        self.ip = ip
        self.port = port
        self.dbname = dbname
        self.username = username

        print("[Database] Initializing ...")
        # The default username is the *string* 'None' (kept for backward compatibility), so it is
        # treated exactly like a real None: connect without authentication.
        if username is None or username == 'None':
            print("[Database] No username given, it works if authentication is not required")
            client = pymongo.MongoClient(ip, port)
        else:
            # Credentials are handed to the client (Database.authenticate no longer exists in
            # PyMongo 4); authSource=dbname authenticates against this database. The password is
            # never printed.
            client = pymongo.MongoClient(ip, port, username=username, password=password, authSource=dbname)
        self.db = client[dbname]

        if project_name is None:
            self.project_name = self._script_project_name(sys.argv[0])
            print("[Database] No project_name given, use {}".format(self.project_name))
        else:
            self.project_name = project_name

        # define file system (Buckets)
        self.dataset_fs = gridfs.GridFS(self.db, collection="datasetFilesystem")
        self.model_fs = gridfs.GridFS(self.db, collection="modelfs")

        print("[Database] Connected ")
        _s = "[Database] Info:\n"
        _s += " ip : {}\n".format(self.ip)
        _s += " port : {}\n".format(self.port)
        _s += " dbname : {}\n".format(self.dbname)
        _s += " username : {}\n".format(self.username)
        _s += " password : {}\n".format("*******")
        _s += " project_name : {}\n".format(self.project_name)
        self._s = _s
        print(self._s)

    def __str__(self):
        """Return the connection summary built in ``__init__`` (with the password masked)."""
        return self._s

    @staticmethod
    def _script_project_name(script_path):
        """Derive a default project name by stripping the final extension from ``script_path``.

        Splitting on the first ``'.'`` would return ``''`` for ``'./train.py'`` and truncate paths
        whose directories contain dots; ``os.path.splitext`` only removes the real extension.
        """
        return os.path.splitext(script_path)[0]

    def _fill_project_info(self, args):
        """Set ``args['project_name']`` in place so every stored or queried document is scoped to this project."""
        args['project_name'] = self.project_name

    @staticmethod
    def _serialization(ps):
        """Serialize ``ps`` to bytes with pickle (highest protocol)."""
        return pickle.dumps(ps, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def _deserialization(ps):
        """Deserialize pickled bytes.

        SECURITY: ``pickle.loads`` can execute arbitrary code; the bytes come from the database,
        so only use trusted databases.
        """
        return pickle.loads(ps)

    @staticmethod
    def _log_exception(e):
        """Log the exception currently being handled with the file name and line where it was raised.

        Must be called from inside an ``except`` block so that ``sys.exc_info()`` is populated.
        """
        exc_type, exc_obj, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logging.info("{} {} {} {} {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))

    def _delete_documents(self, collection, label, kwargs):
        """Delete every document of this project in ``collection`` matching ``kwargs``, then log ``label``."""
        self._fill_project_info(kwargs)
        collection.delete_many(kwargs)
        logging.info("[Database] Delete {} SUCCESS".format(label))

    # =========================== MODELS ================================
    def save_model(self, network=None, model_name='model', **kwargs):
        """Save model architecture and parameters into database, timestamp will be added automatically.

        Parameters
        ----------
        network : TensorLayer Model
            TensorLayer Model instance. Its ``config`` is stored as the architecture and its
            ``all_weights`` are pickled into GridFS.
        model_name : str
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional). They are
            stored as fields of the model document and can be used as filters later.

        Examples
        ---------
        Save model architecture and parameters into database.
        >>> db.save_model(net, accuracy=0.8, loss=2.3, name='second_model')

        Load one model with parameters from database (run this in other script)
        >>> net = db.find_top_model(accuracy=0.8, loss=2.3)

        Find and load the latest model.
        >>> net = db.find_top_model(sort=[("time", pymongo.DESCENDING)])
        >>> net = db.find_top_model(sort=[("time", -1)])

        Find and load the oldest model.
        >>> net = db.find_top_model(sort=[("time", pymongo.ASCENDING)])
        >>> net = db.find_top_model(sort=[("time", 1)])

        Returns
        ---------
        boolean : True for success, False for fail.
        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)  # put project_name into kwargs

        params = network.all_weights

        s = time.time()

        # The timestamp is added together with params_id, right before the insert.
        kwargs.update({'architecture': network.config})

        try:
            params_id = self.model_fs.put(self._serialization(params))
            kwargs.update({'params_id': params_id, 'time': datetime.utcnow()})
            self.db.Model.insert_one(kwargs)
            print("[Database] Save model: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
            return True
        except Exception as e:
            self._log_exception(e)
            print("[Database] Save model: FAIL")
            return False

    def find_top_model(self, sort=None, model_name='model', **kwargs):
        """Finds and returns a model architecture and its parameters from the database which matches the requirement.

        Parameters
        ----------
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        model_name : str or None
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional), used as
            query filters.

        Examples
        ---------
        - see ``save_model``.

        Returns
        ---------
        network : TensorLayer Model or False
            The network rebuilt from the stored architecture with the stored weights assigned, or
            False if no document matches or loading fails. Other fields of the document (e.g.
            accuracy) are currently NOT attached to the returned network.
        """
        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)

        s = time.time()

        d = self.db.Model.find_one(filter=kwargs, sort=sort)
        if d is None:
            print("[Database] FAIL! Cannot find model: {}".format(kwargs))
            return False

        params_id = d['params_id']
        graphs = d['architecture']
        _datetime = d['time']
        try:
            params = self._deserialization(self.model_fs.get(params_id).read())
            network = static_graph2net(graphs)
            assign_weights(weights=params, network=network)

            pc = self.db.Model.find(kwargs)
            print(
                "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".format(
                    kwargs, sort, _datetime, round(time.time() - s, 2)
                )
            )

            # check whether more parameters match the requirement
            n_params = len(pc.distinct('params_id'))
            if n_params != 1:
                print(" Note that there are {} models match the kwargs".format(n_params))
            return network
        except Exception as e:
            self._log_exception(e)
            return False

    def delete_model(self, **kwargs):
        """Delete models.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all models of this project.
        """
        self._delete_documents(self.db.Model, "Model", kwargs)

    # =========================== DATASET ===============================
    def save_dataset(self, dataset=None, dataset_name=None, **kwargs):
        """Saves one dataset into database, timestamp will be added automatically.

        Parameters
        ----------
        dataset : any type
            The dataset you want to store; it is pickled into GridFS.
        dataset_name : str
            The name of dataset.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Examples
        ----------
        Save dataset
        >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

        Get dataset
        >>> dataset = db.find_top_dataset('mnist')

        Returns
        ---------
        boolean : Return True if save success, otherwise, return False.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.
        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()
        try:
            dataset_id = self.dataset_fs.put(self._serialization(dataset))
            kwargs.update({'dataset_id': dataset_id, 'time': datetime.utcnow()})
            self.db.Dataset.insert_one(kwargs)
            print("[Database] Save dataset: SUCCESS, took: {}s".format(round(time.time() - s, 2)))
            return True
        except Exception as e:
            self._log_exception(e)
            print("[Database] Save dataset: FAIL")
            return False

    def find_top_dataset(self, dataset_name=None, sort=None, **kwargs):
        """Finds and returns a dataset from the database which matches the requirement.

        Parameters
        ----------
        dataset_name : str
            The name of dataset.
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Examples
        ---------
        Save dataset
        >>> db.save_dataset([X_train, y_train, X_test, y_test], 'mnist', description='this is a tutorial')

        Get dataset
        >>> dataset = db.find_top_dataset('mnist')
        >>> datasets = db.find_datasets('mnist')

        Returns
        --------
        dataset : the dataset or False
            Return False if nothing found or loading fails.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.
        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()

        d = self.db.Dataset.find_one(filter=kwargs, sort=sort)
        if d is None:
            print("[Database] FAIL! Cannot find dataset: {}".format(kwargs))
            return False

        dataset_id = d['dataset_id']
        try:
            dataset = self._deserialization(self.dataset_fs.get(dataset_id).read())
            pc = self.db.Dataset.find(kwargs)
            print("[Database] Find one dataset SUCCESS, {} took: {}s".format(kwargs, round(time.time() - s, 2)))

            # check whether more datasets match the requirement
            n_dataset = len(pc.distinct('dataset_id'))
            if n_dataset != 1:
                print(" Note that there are {} datasets match the requirement".format(n_dataset))
            return dataset
        except Exception as e:
            self._log_exception(e)
            return False

    def find_datasets(self, dataset_name=None, **kwargs):
        """Finds and returns all datasets from the database which matches the requirement.
        In some case, the data in a dataset can be stored separately for better management.

        Parameters
        ----------
        dataset_name : str
            The name/key of dataset.
        kwargs : other events
            Other events, such as description, author and etc (optional).

        Returns
        --------
        datasets : list or False
            One deserialized object per distinct stored file, or False if nothing found.

        Raises
        ------
        Exception
            If ``dataset_name`` is None.
        """
        self._fill_project_info(kwargs)
        if dataset_name is None:
            raise Exception("dataset_name is None, please give a dataset name")
        kwargs.update({'dataset_name': dataset_name})

        s = time.time()
        # Collection.find() always returns a cursor (never None), so "nothing found" must be
        # detected from the list of matching file ids.
        dataset_id_list = self.db.Dataset.find(kwargs).distinct('dataset_id')
        if not dataset_id_list:
            print("[Database] FAIL! Cannot find any dataset: {}".format(kwargs))
            return False

        # you may have multiple Buckets files
        dataset_list = [
            self._deserialization(self.dataset_fs.get(dataset_id).read()) for dataset_id in dataset_id_list
        ]

        print("[Database] Find {} datasets SUCCESS, took: {}s".format(len(dataset_list), round(time.time() - s, 2)))
        return dataset_list

    def delete_datasets(self, **kwargs):
        """Delete datasets.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all datasets of this project.
        """
        self._delete_documents(self.db.Dataset, "Dataset", kwargs)

    # =========================== LOGGING ===============================
    def _save_log(self, collection, tag, kwargs):
        """Insert one log document into ``collection`` and echo it through ``logging.info``.

        ``project_name`` and a UTC ``time`` stamp are added to ``kwargs`` before the insert.
        """
        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        collection.insert_one(kwargs)
        logging.info("[Database] {} log: ".format(tag) + self._print_dict(kwargs))

    def save_training_log(self, **kwargs):
        """Saves the training log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_training_log(accuracy=0.33, loss=0.98)
        """
        self._save_log(self.db.TrainLog, "train", kwargs)

    def save_validation_log(self, **kwargs):
        """Saves the validation log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_validation_log(accuracy=0.33, loss=0.98)
        """
        self._save_log(self.db.ValidLog, "valid", kwargs)

    def save_testing_log(self, **kwargs):
        """Saves the testing log, timestamp will be added automatically.

        Parameters
        -----------
        kwargs : logging information
            Events, such as accuracy, loss, step number and etc.

        Examples
        ---------
        >>> db.save_testing_log(accuracy=0.33, loss=0.98)
        """
        self._save_log(self.db.TestLog, "test", kwargs)

    def delete_training_log(self, **kwargs):
        """Deletes training log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all training logs of this project.

        Examples
        ---------
        Save training log
        >>> db.save_training_log(accuracy=0.33)
        >>> db.save_training_log(accuracy=0.44)

        Delete logs that match the requirement
        >>> db.delete_training_log(accuracy=0.33)

        Delete all logs
        >>> db.delete_training_log()
        """
        self._delete_documents(self.db.TrainLog, "TrainLog", kwargs)

    def delete_validation_log(self, **kwargs):
        """Deletes validation log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all validation logs of this project.

        Examples
        ---------
        - see ``delete_training_log``.
        """
        self._delete_documents(self.db.ValidLog, "ValidLog", kwargs)

    def delete_testing_log(self, **kwargs):
        """Deletes testing log.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all testing logs of this project.

        Examples
        ---------
        - see ``delete_training_log``.
        """
        self._delete_documents(self.db.TestLog, "TestLog", kwargs)

    # =========================== Task ===================================
    def create_task(self, task_name=None, script=None, hyper_parameters=None, saved_result_keys=None, **kwargs):
        """Uploads a task to the database, timestamp will be added automatically.

        Parameters
        -----------
        task_name : str
            The task name; stored with the task so ``run_top_task`` can select it.
        script : str
            File name of the python script. Its raw bytes are stored in the task document.
        hyper_parameters : dictionary
            The hyper parameters pass into the script.
        saved_result_keys : list of str
            The keys of the task results to keep in the database when the task finishes.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        -----------
        Uploads a task
        >>> db.create_task(task_name='mnist', script='example/tutorial_mnist_simple.py', description='simple tutorial')

        Finds and runs the latest task
        >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.DESCENDING)])
        >>> db.run_top_task(task_name='mnist', sort=[("time", -1)])

        Finds and runs the oldest task
        >>> db.run_top_task(task_name='mnist', sort=[("time", pymongo.ASCENDING)])
        >>> db.run_top_task(task_name='mnist', sort=[("time", 1)])

        Raises
        ------
        Exception
            If ``task_name`` or ``script`` is not a string.
        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        if not isinstance(script, str):
            raise Exception("script should be string")
        if hyper_parameters is None:
            hyper_parameters = {}
        if saved_result_keys is None:
            saved_result_keys = []

        self._fill_project_info(kwargs)
        kwargs.update({'time': datetime.utcnow()})
        kwargs.update({'hyper_parameters': hyper_parameters})
        kwargs.update({'saved_result_keys': saved_result_keys})

        # Close the script file deterministically instead of leaking the handle.
        with open(script, 'rb') as f:
            _script = f.read()

        kwargs.update({'task_name': task_name, 'status': 'pending', 'script': _script, 'result': {}})
        self.db.Task.insert_one(kwargs)
        logging.info("[Database] Saved Task - task_name: {} script: {}".format(task_name, script))

    def run_top_task(self, task_name=None, sort=None, **kwargs):
        """Finds and runs a pending task that in the first of the sorting list.

        Parameters
        -----------
        task_name : str
            The task name; only pending tasks created with this name are considered.
        sort : List of tuple
            PyMongo sort comment, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        ---------
        Monitors the database and pull tasks to run
        >>> while True:
        >>>     print("waiting task from distributor")
        >>>     db.run_top_task(task_name='mnist', sort=[("time", -1)])
        >>>     time.sleep(1)

        Returns
        --------
        boolean : True if a task was found and run, False if no pending task matches.

        Raises
        ------
        Exception
            If ``task_name`` is not a string. Exceptions raised by the task script propagate.
        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        self._fill_project_info(kwargs)
        kwargs.update({'task_name': task_name, 'status': 'pending'})

        # find a pending task and set its status to running in the same update
        task = self.db.Task.find_one_and_update(kwargs, {'$set': {'status': 'running'}}, sort=sort)
        if task is None:
            logging.info("[Database] Find Task FAIL: key: {} sort: {}".format(task_name, sort))
            return False

        logging.info("[Database] Find Task SUCCESS: key: {} sort: {}".format(task_name, sort))
        _datetime = task['time']
        _script = task['script']
        _id = task['_id']
        _hyper_parameters = task['hyper_parameters']
        _saved_result_keys = task['saved_result_keys']

        # Hyper parameters are injected as module globals so the executed script can read them by name.
        logging.info(" hyper parameters:")
        for key in _hyper_parameters:
            globals()[key] = _hyper_parameters[key]
            logging.info(" {}: {}".format(key, _hyper_parameters[key]))

        # run task
        s = time.time()
        logging.info("[Database] Start Task: key: {} sort: {} push time: {}".format(task_name, sort, _datetime))
        _script = _script.decode('utf-8')
        # SECURITY: this executes arbitrary code stored in the database inside this module's globals.
        # NOTE: if the script raises, the exception propagates and the task stays in 'running' status.
        with tf.Graph().as_default():  # clear all TF graphs
            exec(_script, globals())

        # set status to finished
        self.db.Task.find_one_and_update({'_id': _id}, {'$set': {'status': 'finished'}})

        # store the requested results, read from the globals the script defined
        result = {}
        for _key in _saved_result_keys:
            value = globals()[_key]
            logging.info(" result: {}={} {}".format(_key, value, type(value)))
            result[str(_key)] = value
        self.db.Task.find_one_and_update(
            {'_id': _id}, {'$set': {
                'result': result
            }}, return_document=pymongo.ReturnDocument.AFTER
        )
        logging.info(
            "[Database] Finished Task: task_name - {} sort: {} push time: {} took: {}s".format(
                task_name, sort, _datetime, round(time.time() - s, 2)
            )
        )
        return True

    def delete_tasks(self, **kwargs):
        """Delete tasks.

        Parameters
        -----------
        kwargs : logging information
            Find items to delete, leave it empty to delete all tasks of this project.

        Examples
        ---------
        >>> db.delete_tasks()
        """
        self._delete_documents(self.db.Task, "Task", kwargs)

    def check_unfinished_task(self, task_name=None, **kwargs):
        """Checks whether any task with this name is still pending or running.

        Parameters
        -----------
        task_name : str
            The task name.
        kwargs : other parameters
            Users customized parameters such as description, version number.

        Examples
        ---------
        Wait until all tasks finish in user's local console

        >>> while db.check_unfinished_task(task_name='mnist'):
        >>>     time.sleep(1)
        >>> print("all tasks finished")
        >>> net = db.find_top_model(sort=[("test_accuracy", -1)])

        Returns
        --------
        boolean : True if at least one matching task is pending or running, otherwise False.

        Raises
        ------
        Exception
            If ``task_name`` is not a string.
        """
        if not isinstance(task_name, str):
            raise Exception("task_name should be string")
        self._fill_project_info(kwargs)

        kwargs.update({'task_name': task_name, '$or': [{'status': 'pending'}, {'status': 'running'}]})

        n_task = len(self.db.Task.find(kwargs).distinct('_id'))

        if n_task == 0:
            logging.info("[Database] No unfinished task - task_name: {}".format(task_name))
            return False

        logging.info("[Database] Find {} unfinished task - task_name: {}".format(n_task, task_name))
        return True

    @staticmethod
    def _print_dict(args):
        """Format ``args`` as concatenated ``'key: value / '`` pairs, skipping the ``_id`` field.

        ``_id`` is present after PyMongo's ``insert_one`` has filled it in. Keys are compared with
        ``!=``: ``is not`` tests object identity, which is unreliable for strings.
        """
        return ''.join(str(key) + ": " + str(value) + " / " for key, value in args.items() if key != '_id')
|