
main.py 8.9 kB

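"""End-to-end image learnware market demo on CIFAR-10.

The script splits the training set across simulated uploaders, trains one
model per uploader, packages each model together with an RKME image
specification into a learnware zip, inserts all of them into an EasyMarket,
and then evaluates market search and averaging-based reuse for simulated
users. It assumes (based on the paths used below, not checked here) that the
raw dataset lives under ./data/origin_data/cifar10 and that the learnware
template files (example_init.py, example_yaml.yaml, model.py) live under
./example_files.
"""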
import os
import random
import time
import zipfile
from shutil import copyfile, rmtree

import numpy as np
import torch
from tqdm import tqdm

import learnware.specification as specification
from get_data import *
from learnware.learnware import Learnware
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo, EasyMarket, database_ops
from learnware.reuse.averaging import AveragingReuser
from learnware.specification.image import RKMEImageSpecification
from utils import ImageDataLoader, eval_prediction, generate_uploader, generate_user, train

logger = get_module_logger("image_test", level="INFO")
origin_data_root = "./data/origin_data"
processed_data_root = "./data/processed_data"
tmp_dir = "./data/tmp"
learnware_pool_dir = "./data/learnware_pool"
dataset = "cifar10"
n_uploaders = 30
n_users = 20
n_classes = 10

data_root = os.path.join(origin_data_root, dataset)
data_save_root = os.path.join(processed_data_root, dataset)
user_save_root = os.path.join(data_save_root, "user")
uploader_save_root = os.path.join(data_save_root, "uploader")
model_save_root = os.path.join(data_save_root, "uploader_model")
os.makedirs(data_root, exist_ok=True)
os.makedirs(user_save_root, exist_ok=True)
os.makedirs(uploader_save_root, exist_ok=True)
os.makedirs(model_save_root, exist_ok=True)
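
# Semantic specification template shared by all uploaded learnwares; the
# per-learnware "Name" and "Description" values are filled in inside
# prepare_market().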
semantic_specs = [
    {
        "Data": {"Values": ["Tabular"], "Type": "Class"},
        "Task": {"Values": ["Classification"], "Type": "Class"},
        "Library": {"Values": ["Pytorch"], "Type": "Class"},
        "Scenario": {"Values": ["Business"], "Type": "Tag"},
        "Description": {"Values": "", "Type": "String"},
        "Name": {"Values": "learnware_1", "Type": "String"},
        "Output": {"Dimension": 10},
    }
]

user_semantic = {
    "Data": {"Values": ["Tabular"], "Type": "Class"},
    "Task": {"Values": ["Classification"], "Type": "Class"},
    "Library": {"Values": ["Pytorch"], "Type": "Class"},
    "Scenario": {"Values": ["Business"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
}
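
# Split the raw train/test data into per-uploader and per-user shards on disk.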
def prepare_data():
    if dataset == "cifar10":
        X_train, y_train, X_test, y_test = get_cifar10(data_root)
    elif dataset == "mnist":
        X_train, y_train, X_test, y_test = get_mnist(data_root)
    else:
        return
    generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
    generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)
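
# Train one classifier per uploader shard and save its weights to disk.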
def prepare_model():
    dataloader = ImageDataLoader(data_save_root, train=True)
    for i in range(n_uploaders):
        logger.info("Train on uploader: %d" % (i))
        X, y = dataloader.get_idx_data(i)
        model = train(X, y, out_classes=n_classes)
        model_save_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
        torch.save(model.state_dict(), model_save_path)
        logger.info("Model saved to '%s'" % (model_save_path))
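
# Package a single uploader's model as a learnware zip: estimate an RKME image
# specification from a random subsample of the uploader's data, then bundle it
# with the model weights, learnware.yaml, __init__.py, and model.py.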
def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_root, zip_name):
    os.makedirs(save_root, exist_ok=True)
    tmp_spec_path = os.path.join(save_root, "rkme.json")
    tmp_model_path = os.path.join(save_root, "conv_model.pth")
    tmp_yaml_path = os.path.join(save_root, "learnware.yaml")
    tmp_init_path = os.path.join(save_root, "__init__.py")
    tmp_model_file_path = os.path.join(save_root, "model.py")
    model_file_path = "./example_files/model.py"

    # Computing the specification from the whole dataset is too costly,
    # so estimate it on a random subsample.
    X = np.load(data_path)
    indices = np.random.choice(len(X), size=2000, replace=False)
    X_sampled = X[indices]

    st = time.time()
    user_spec = RKMEImageSpecification(cuda_idx=0)
    user_spec.generate_stat_spec_from_data(X=X_sampled)
    ed = time.time()
    logger.info("Stat spec generated in %.3f s" % (ed - st))
    user_spec.save(tmp_spec_path)

    copyfile(model_path, tmp_model_path)
    copyfile(yaml_path, tmp_yaml_path)
    copyfile(init_file_path, tmp_init_path)
    copyfile(model_file_path, tmp_model_file_path)

    zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
    with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
        zip_obj.write(tmp_spec_path, "rkme.json")
        zip_obj.write(tmp_model_path, "conv_model.pth")
        zip_obj.write(tmp_yaml_path, "learnware.yaml")
        zip_obj.write(tmp_init_path, "__init__.py")
        zip_obj.write(tmp_model_file_path, "model.py")
    rmtree(save_root)
    logger.info("New learnware saved to %s" % (zip_file_name))
    return zip_file_name
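
# Rebuild the market from scratch and add one learnware per uploader.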
def prepare_market():
    image_market = EasyMarket(market_id="cifar10", rebuild=True)
    try:
        rmtree(learnware_pool_dir)
    except FileNotFoundError:
        pass
    os.makedirs(learnware_pool_dir, exist_ok=True)
    for i in tqdm(range(n_uploaders), total=n_uploaders, desc="Preparing..."):
        data_path = os.path.join(uploader_save_root, "uploader_%d_X.npy" % (i))
        model_path = os.path.join(model_save_root, "uploader_%d.pth" % (i))
        init_file_path = "./example_files/example_init.py"
        yaml_file_path = "./example_files/example_yaml.yaml"
        new_learnware_path = prepare_learnware(
            data_path, model_path, init_file_path, yaml_file_path, tmp_dir, "%s_%d" % (dataset, i)
        )
        semantic_spec = semantic_specs[0]
        semantic_spec["Name"]["Values"] = "learnware_%d" % (i)
        semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i)
        image_market.add_learnware(new_learnware_path, semantic_spec)
    logger.info("Total Item: %d" % (len(image_market)))
    curr_inds = image_market._get_ids()
    logger.info("Available ids: " + str(curr_inds))
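
# For each simulated user: generate a statistical specification, search the
# market, report the accuracy of the top-ranked learnwares, and evaluate
# reuse by averaging the top candidates' predictions.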
def test_search(gamma=0.1, load_market=True):
    if load_market:
        image_market = EasyMarket(market_id="cifar10")
    else:
        prepare_market()
        image_market = EasyMarket(market_id="cifar10")
    logger.info("Number of items in the market: %d" % len(image_market))

    select_list = []
    avg_list = []
    improve_list = []
    job_selector_score_list = []
    ensemble_score_list = []
    for i in tqdm(range(n_users), total=n_users, desc="Searching..."):
        user_data_path = os.path.join(user_save_root, "user_%d_X.npy" % (i))
        user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i))
        user_data = np.load(user_data_path)
        user_label = np.load(user_label_path)

        user_stat_spec = RKMEImageSpecification(cuda_idx=0)
        user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False)
        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec})
        logger.info("Searching market for user: %d" % i)
        (
            sorted_score_list,
            single_learnware_list,
            mixture_score,
            mixture_learnware_list,
        ) = image_market.search_learnware(user_info)

        acc_list = []
        for idx, (score, learnware) in enumerate(zip(sorted_score_list[:5], single_learnware_list[:5])):
            pred_y = learnware.predict(user_data)
            acc = eval_prediction(pred_y, user_label)
            acc_list.append(acc)
            logger.info("Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc))

        # Reuse via job selector (disabled here):
        # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
        # reuse_predict = reuse_baseline.predict(user_data=user_data)
        # reuse_score = eval_prediction(reuse_predict, user_label)
        # job_selector_score_list.append(reuse_score)
        # print(f"mixture reuse accuracy: {reuse_score}")

        # Reuse via ensemble: average the top-3 learnwares' predicted probabilities.
        reuse_ensemble = AveragingReuser(learnware_list=single_learnware_list[:3], mode="vote_by_prob")
        ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
        ensemble_score = eval_prediction(ensemble_predict_y, user_label)
        ensemble_score_list.append(ensemble_score)
        print(f"reuse accuracy (vote_by_prob): {ensemble_score}\n")

        select_list.append(acc_list[0])
        avg_list.append(np.mean(acc_list))
        improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list))

    logger.info(
        "Accuracy of selected learnware: %.3f +/- %.3f, average performance: %.3f +/- %.3f"
        % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
    )
    logger.info(
        "Ensemble reuse performance: %.3f +/- %.3f" % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
    )
if __name__ == "__main__":
    logger.info("=" * 40)
    logger.info(f"n_uploaders:\t{n_uploaders}")
    logger.info(f"n_users:\t{n_users}")
    logger.info("=" * 40)
    prepare_data()
    prepare_model()
    test_search(load_market=False)
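
# A minimal sketch of reusing a single searched learnware outside this script,
# assuming the market built above already exists on disk. `my_images` is a
# hypothetical user array; every call below mirrors one already made in
# test_search(), and no additional library APIs are assumed.
#
#   market = EasyMarket(market_id="cifar10")
#   spec = RKMEImageSpecification(cuda_idx=0)
#   spec.generate_stat_spec_from_data(X=my_images, resize=False)
#   info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": spec})
#   scores, learnwares, _, _ = market.search_learnware(info)
#   pred_y = learnwares[0].predict(my_images)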