
main.py 7.4 kB

import os
import fire
import zipfile
import numpy as np
from tqdm import tqdm
from shutil import copyfile, rmtree

import learnware
from learnware.market import EasyMarket, BaseUserInfo
from learnware.market import database_ops
from learnware.learnware import Learnware, JobSelectorReuser, AveragingReuser
import learnware.specification as specification

from pfs import Dataloader
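
# Upload-side semantic specification template; "Name" and "Description" are
# overwritten for each learnware before it is added to the market.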
semantic_specs = [
    {
        "Data": {"Values": ["Tabular"], "Type": "Class"},
        "Task": {"Values": ["Classification"], "Type": "Class"},
        "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
        "Scenario": {"Values": ["Business"], "Type": "Tag"},
        "Description": {"Values": "", "Type": "String"},
        "Name": {"Values": "learnware_1", "Type": "String"},
    }
]

# Semantic requirement supplied by the simulated user at search time; the
# statistical (RKME) part is attached separately in test() below.
user_semantic = {
    "Data": {"Values": ["Tabular"], "Type": "Class"},
    "Task": {"Values": ["Classification"], "Type": "Class"},
    "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
    "Scenario": {"Values": ["Business"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
}
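

# End-to-end workflow driven by python-fire: prepare learnwares, populate the
# market, then evaluate search and reuse strategies on the PFS tasks.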
class PFSDatasetWorkflow:
    def _init_pfs_dataset(self):
        # Regenerate the raw PFS data and retrain one model per algorithm.
        pfs = Dataloader()
        pfs.regenerate_data()
        algo_list = ["ridge", "lgb"]
        for algo in algo_list:
            pfs.set_algo(algo)
            pfs.retrain_models()

    def _init_learnware_market(self):
        """Initialize the learnware market and upload all packaged learnwares."""
        learnware.init()
        easy_market = EasyMarket(market_id="pfs", rebuild=True)
        print("Total Item:", len(easy_market))

        zip_path_list = []
        curr_root = os.path.dirname(os.path.abspath(__file__))
        curr_root = os.path.join(curr_root, "learnware_pool")
        for zip_path in os.listdir(curr_root):
            zip_path_list.append(os.path.join(curr_root, zip_path))

        # The single template in semantic_specs is mutated in place: each upload
        # gets its own "Name" and "Description" before being added.
        for idx, zip_path in enumerate(zip_path_list):
            semantic_spec = semantic_specs[0]
            semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
            semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
            easy_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(easy_market))
        curr_inds = easy_market._get_ids()
        print("Available ids:", curr_inds)

    def prepare_learnware(self, regenerate_flag=False):
        if regenerate_flag:
            self._init_pfs_dataset()
        pfs = Dataloader()
        idx_list = pfs.get_idx_list()
        algo_list = ["lgb"]  # ["ridge", "lgb"]
        curr_root = os.path.dirname(os.path.abspath(__file__))
        curr_root = os.path.join(curr_root, "learnware_pool")
        os.makedirs(curr_root, exist_ok=True)
        for idx in tqdm(idx_list):
            train_x, train_y, test_x, test_y = pfs.get_idx_data(idx)
            # The RKME statistical specification is computed from the training
            # data and shared by every algorithm trained on this task.
            spec = specification.utils.generate_rkme_spec(X=train_x, gamma=0.1, cuda_idx=0)
            for algo in algo_list:
                pfs.set_algo(algo)
                dir_path = os.path.join(curr_root, f"{algo}_{idx}")
                os.makedirs(dir_path, exist_ok=True)
                spec_path = os.path.join(dir_path, "rkme.json")
                spec.save(spec_path)
                model_path = pfs.get_model_path(idx)
                model_file = os.path.join(dir_path, "model.out")
                copyfile(model_path, model_file)
                init_file = os.path.join(dir_path, "__init__.py")
                copyfile("example_init.py", init_file)
                yaml_file = os.path.join(dir_path, "learnware.yaml")
                copyfile("example.yaml", yaml_file)
                # Pack the learnware directory into a flat, uncompressed zip and
                # remove the staging directory afterwards.
                zip_file = dir_path + ".zip"
                with zipfile.ZipFile(zip_file, "w") as zip_obj:
                    for foldername, subfolders, filenames in os.walk(dir_path):
                        for filename in filenames:
                            file_path = os.path.join(foldername, filename)
                            zip_info = zipfile.ZipInfo(filename)
                            zip_info.compress_type = zipfile.ZIP_STORED
                            with open(file_path, "rb") as file:
                                zip_obj.writestr(zip_info, file.read())
                rmtree(dir_path)

    def test(self, regenerate_flag=False):
        self.prepare_learnware(regenerate_flag)
        self._init_learnware_market()
        easy_market = EasyMarket(market_id="pfs")
        print("Total Item:", len(easy_market))

        pfs = Dataloader()
        idx_list = pfs.get_idx_list()
        os.makedirs("./user_spec", exist_ok=True)
        single_score_list = []
        random_score_list = []
        job_selector_score_list = []
        ensemble_score_list = []
        for idx in idx_list:
            train_x, train_y, test_x, test_y = pfs.get_idx_data(idx)
            # Build the user's statistical requirement from the test data.
            user_spec = specification.utils.generate_rkme_spec(X=test_x, gamma=0.1, cuda_idx=0)
            user_spec_path = f"./user_spec/user_{idx}.json"
            user_spec.save(user_spec_path)
            user_info = BaseUserInfo(
                id=f"user_{idx}", semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}
            )
            (
                sorted_score_list,
                single_learnware_list,
                mixture_score,
                mixture_learnware_list,
            ) = easy_market.search_learnware(user_info)

            print(f"search result of user{idx}:")
            print(
                f"single model num: {len(sorted_score_list)}, max_score: {sorted_score_list[0]}, min_score: {sorted_score_list[-1]}"
            )
            # Evaluate every returned learnware individually.
            loss_list = []
            for score, learnware in zip(sorted_score_list, single_learnware_list):
                pred_y = learnware.predict(test_x)
                loss_list.append(pfs.score(test_y, pred_y))
            print(
                f"Top1-score: {sorted_score_list[0]}, learnware_id: {single_learnware_list[0].id}, loss: {loss_list[0]}"
            )
            mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
            print(f"mixture_score: {mixture_score}, mixture_learnware: {mixture_id}")

            # Reuse strategy 1: route each test point to the most suitable learnware.
            reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list, use_herding=False)
            job_selector_predict_y = reuse_job_selector.predict(user_data=test_x)
            job_selector_score = pfs.score(test_y, job_selector_predict_y)
            print(f"mixture reuse loss (job selector): {job_selector_score}")

            # Reuse strategy 2: average the predictions of the mixture.
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list)
            ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
            ensemble_score = pfs.score(test_y, ensemble_predict_y)
            print(f"mixture reuse loss (ensemble): {ensemble_score}\n")

            single_score_list.append(loss_list[0])
            random_score_list.append(np.mean(loss_list))
            job_selector_score_list.append(job_selector_score)
            ensemble_score_list.append(ensemble_score)

        print(f"Single search score: {np.mean(single_score_list)}")
        print(f"Job selector score: {np.mean(job_selector_score_list)}")
        print(f"Average ensemble score: {np.mean(ensemble_score_list)}")
        print(f"Random search score: {np.mean(random_score_list)}")


if __name__ == "__main__":
    fire.Fire(PFSDatasetWorkflow)
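
Because the script hands PFSDatasetWorkflow to python-fire, its public methods are exposed as subcommands. A possible invocation, assuming the pfs module, example_init.py, and example.yaml sit next to main.py:

    python main.py prepare_learnware --regenerate_flag=True
    python main.py test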

Based on the learnware paradigm, this repository provides end-to-end support for learnware upload, testing, organization, search, deployment, and reuse. It also serves as the engine of the Beimingwu (北冥坞) system, powering that system's core functionality.
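
As a condensed illustration of that upload-search-reuse loop, the sketch below strips main.py above down to its three market calls; it assumes semantic_spec, user_semantic, user_spec, and test_x are built as in main.py, and the zip path is a placeholder:

    import learnware
    from learnware.market import EasyMarket, BaseUserInfo
    from learnware.learnware import AveragingReuser

    learnware.init()
    market = EasyMarket(market_id="demo", rebuild=True)
    # Upload: a packaged learnware zip plus its semantic specification.
    market.add_learnware("demo_learnware.zip", semantic_spec)
    # Search: semantic requirement plus an RKME statistical specification.
    user_info = BaseUserInfo(
        id="user_0", semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}
    )
    scores, singles, mix_score, mixture = market.search_learnware(user_info)
    # Reuse: average the predictions of the recommended mixture.
    pred_y = AveragingReuser(learnware_list=mixture).predict(user_data=test_x)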