You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example_db.py 6.3 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import os
  2. import joblib
  3. import numpy as np
  4. from sklearn import svm
  5. from learnware.market import EasyMarket, BaseUserInfo
  6. from learnware.market import database_ops
  7. from learnware.learnware import Learnware
  8. import learnware.specification as specification
  9. from learnware.utils import get_module_by_module_path
  10. curr_root = os.path.dirname(os.path.abspath(__file__))
  11. semantic_specs = [
  12. {
  13. "Data": {"Values": ["Tabular"], "Type": "Class"},
  14. "Task": {"Values": ["Classification"], "Type": "Class"},
  15. "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
  16. "Scenario": {"Values": ["Nature"], "Type": "Tag"},
  17. "Description": {"Values": "", "Type": "String"},
  18. "Name": {"Values": "learnware_1", "Type": "String"},
  19. },
  20. {
  21. "Data": {"Values": ["Tabular"], "Type": "Class"},
  22. "Task": {"Values": ["Classification"], "Type": "Class"},
  23. "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
  24. "Scenario": {"Values": ["Business", "Nature"], "Type": "Tag"},
  25. "Description": {"Values": "", "Type": "String"},
  26. "Name": {"Values": "learnware_2", "Type": "String"},
  27. },
  28. {
  29. "Data": {"Values": ["Tabular"], "Type": "Class"},
  30. "Task": {"Values": ["Regression"], "Type": "Class"},
  31. "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
  32. "Scenario": {"Values": ["Business"], "Type": "Tag"},
  33. "Description": {"Values": "", "Type": "String"},
  34. "Name": {"Values": "learnware_3", "Type": "String"},
  35. },
  36. ]
  37. user_senmantic = {
  38. "Data": {"Values": ["Tabular"], "Type": "Class"},
  39. "Task": {"Values": ["Classification"], "Type": "Class"},
  40. "Device": {"Values": ["GPU"], "Type": "Class"},
  41. "Scenario": {"Values": ["Business"], "Type": "Tag"},
  42. "Description": {"Values": "", "Type": "String"},
  43. "Name": {"Values": "learnware", "Type": "String"},
  44. }
  45. def prepare_learnware(learnware_num=10):
  46. np.random.seed(2023)
  47. for i in range(learnware_num):
  48. dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i))
  49. os.makedirs(dir_path, exist_ok=True)
  50. print("Preparing Learnware: %d" % (i))
  51. data_X = np.random.randn(5000, 20) * i
  52. data_y = np.random.randn(5000)
  53. data_y = np.where(data_y > 0, 1, 0)
  54. clf = svm.SVC(kernel="linear")
  55. clf.fit(data_X, data_y)
  56. joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))
  57. spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
  58. spec.save(os.path.join(dir_path, "svm.json"))
  59. init_file = os.path.join(dir_path, "__init__.py")
  60. os.system(f"cp example_init.py {init_file}")
  61. yaml_file = os.path.join(dir_path, "learnware.yaml")
  62. os.system(f"cp example.yaml {yaml_file}")
  63. zip_file = dir_path + ".zip"
  64. os.system(f"zip -q -r -j {zip_file} {dir_path}")
  65. os.system(f"rm -r {dir_path}")
  66. def get_zip_path_list():
  67. root_path = os.path.join(curr_root, "learnware_pool")
  68. zip_path_list = [os.path.join(root_path, path) for path in os.listdir(root_path)]
  69. return zip_path_list
  70. def test_market():
  71. database_ops.clear_learnware_table()
  72. easy_market = EasyMarket()
  73. print("Total Item:", len(easy_market))
  74. zip_path_list = get_zip_path_list() # the path list for learnware .zip
  75. for idx, zip_path in enumerate(zip_path_list):
  76. semantic_spec = semantic_specs[idx % 3]
  77. semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
  78. semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
  79. easy_market.add_learnware(zip_path, semantic_spec)
  80. print("Total Item:", len(easy_market))
  81. curr_inds = easy_market._get_ids()
  82. print("Available ids:", curr_inds)
  83. easy_market.delete_learnware(curr_inds[3])
  84. easy_market.delete_learnware(curr_inds[2])
  85. curr_inds = easy_market._get_ids()
  86. print("Available ids:", curr_inds)
  87. def test_search_semantics():
  88. easy_market = EasyMarket()
  89. print("Total Item:", len(easy_market))
  90. root_path = "./learnware_pool"
  91. os.makedirs(root_path, exist_ok=True)
  92. test_learnware_num = 3
  93. prepare_learnware(test_learnware_num)
  94. test_folder = "./test_stat"
  95. zip_path_list = get_zip_path_list()
  96. idx, zip_path = 1, zip_path_list[1]
  97. unzip_dir = os.path.join(test_folder, f"{idx}")
  98. os.makedirs(unzip_dir, exist_ok=True)
  99. os.system(f"unzip -o -q {zip_path} -d {unzip_dir}")
  100. user_spec = specification.rkme.RKMEStatSpecification()
  101. user_spec.load(os.path.join(unzip_dir, "svm.json"))
  102. user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic)
  103. _, single_learnware_list, _ = easy_market.search_learnware(user_info)
  104. print("User info:", user_info.get_semantic_spec())
  105. print(f"search result of user{idx}:")
  106. for learnware in single_learnware_list:
  107. print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec())
  108. os.system(f"rm -r {test_folder}")
  109. def test_stat_search():
  110. easy_market = EasyMarket()
  111. print("Total Item:", len(easy_market))
  112. test_folder = "./test_stat"
  113. zip_path_list = get_zip_path_list()
  114. for idx, zip_path in enumerate(zip_path_list):
  115. unzip_dir = os.path.join(test_folder, f"{idx}")
  116. os.makedirs(unzip_dir, exist_ok=True)
  117. os.system(f"unzip -o -q {zip_path} -d {unzip_dir}")
  118. user_spec = specification.rkme.RKMEStatSpecification()
  119. user_spec.load(os.path.join(unzip_dir, "svm.json"))
  120. user_info = BaseUserInfo(
  121. id="user_0", semantic_spec=user_senmantic, stat_info={"RKMEStatSpecification": user_spec}
  122. )
  123. sorted_score_list, single_learnware_list, mixture_learnware_list = easy_market.search_learnware(user_info)
  124. print(f"search result of user{idx}:")
  125. for score, learnware in zip(sorted_score_list, single_learnware_list):
  126. print(f"score: {score}, learnware_id: {learnware.id}")
  127. mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
  128. print(f"mixture_learnware: {mixture_id}\n")
  129. os.system(f"rm -r {test_folder}")
  130. if __name__ == "__main__":
  131. learnware_num = 10
  132. prepare_learnware(learnware_num)
  133. test_market()
  134. test_stat_search()
  135. test_search_semantics()

基于学件范式，全流程地支持学件上传、检测、组织、查搜、部署和复用等功能。同时，该仓库作为北冥坞系统的引擎，支撑北冥坞系统的核心功能。