You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

workflow.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import os
  2. import fire
  3. import time
  4. import torch
  5. import pickle
  6. import random
  7. import tempfile
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. from torch.utils.data import TensorDataset
  11. from learnware.utils import choose_device
  12. from learnware.client import LearnwareClient
  13. from learnware.logger import get_module_logger
  14. from learnware.specification import generate_stat_spec
  15. from learnware.tests.benchmarks import LearnwareBenchmark
  16. from learnware.market import instantiate_learnware_market, BaseUserInfo
  17. from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
  18. from model import ConvModel
  19. from utils import train_model, evaluate
  20. from config import image_benchmark_config
  21. logger = get_module_logger("image_workflow", level="INFO")
  22. class ImageDatasetWorkflow:
  23. def _plot_labeled_peformance_curves(self, all_user_curves_data):
  24. plt.figure(figsize=(10, 6))
  25. plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)
  26. styles = [
  27. {"color": "navy", "linestyle": "-", "marker": "o"},
  28. {"color": "magenta", "linestyle": "-.", "marker": "d"},
  29. ]
  30. labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
  31. user_array, pruning_array = all_user_curves_data
  32. for array, style, label in zip([user_array, pruning_array], styles, labels):
  33. mean_curve = np.array([item[0] for item in array])
  34. std_curve = np.array([item[1] for item in array])
  35. plt.plot(mean_curve, **style, label=label)
  36. plt.fill_between(
  37. range(len(mean_curve)),
  38. mean_curve - std_curve,
  39. mean_curve + std_curve,
  40. color=style["color"],
  41. alpha=0.2,
  42. )
  43. plt.xlabel("Amout of Labeled User Data", fontsize=14)
  44. plt.ylabel("1 - Accuracy", fontsize=14)
  45. plt.title(f"Results on Image Experimental Scenario", fontsize=16)
  46. plt.legend(fontsize=14)
  47. plt.tight_layout()
  48. plt.savefig(os.path.join(self.fig_path, "image_labeled_curves.svg"), bbox_inches="tight", dpi=700)
  49. def _prepare_market(self, rebuild=False):
  50. client = LearnwareClient()
  51. self.image_benchmark = LearnwareBenchmark().get_benchmark(image_benchmark_config)
  52. self.image_market = instantiate_learnware_market(market_id=self.image_benchmark.name, rebuild=rebuild)
  53. self.user_semantic = client.get_semantic_specification(self.image_benchmark.learnware_ids[0])
  54. self.user_semantic["Name"]["Values"] = ""
  55. if len(self.image_market) == 0 or rebuild == True:
  56. for learnware_id in self.image_benchmark.learnware_ids:
  57. with tempfile.TemporaryDirectory(prefix="image_benchmark_") as tempdir:
  58. zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
  59. for i in range(20):
  60. try:
  61. semantic_spec = client.get_semantic_specification(learnware_id)
  62. client.download_learnware(learnware_id, zip_path)
  63. self.image_market.add_learnware(zip_path, semantic_spec)
  64. break
  65. except:
  66. time.sleep(1)
  67. continue
  68. logger.info("Total Item: %d" % (len(self.image_market)))
    def image_example(self, rebuild=False):
        """Run the full image-benchmark workflow.

        For every benchmark user: search the market, evaluate the selected
        learnware against the market oracle (best/average), test job-selector and
        averaging-ensemble reuse, then train user models on growing amounts of
        labeled data and compare them against ensemble-pruning reuse. Per-user
        curves are pickled to disk, aggregated, and plotted at the end.

        Args:
            rebuild: Forwarded to _prepare_market; if True the market is rebuilt.
        """
        # Fix seeds so the labeled-data sampling below is reproducible.
        np.random.seed(1)
        random.seed(1)
        self._prepare_market(rebuild)
        # Labeled-data budgets, and how many repeated trials each budget gets.
        self.n_labeled_list = [100, 200, 500, 1000, 2000, 4000]
        self.repeated_list = [10, 10, 10, 3, 3, 3]
        device = choose_device(0)
        # Output directories live next to this file.
        self.root_path = os.path.dirname(os.path.abspath(__file__))
        self.fig_path = os.path.join(self.root_path, "figs")
        self.curve_path = os.path.join(self.root_path, "curves")
        self.model_path = os.path.join(self.root_path, "models")
        os.makedirs(self.fig_path, exist_ok=True)
        os.makedirs(self.curve_path, exist_ok=True)
        os.makedirs(self.model_path, exist_ok=True)
        # Per-user accuracy bookkeeping, summarized after the main loop.
        select_list = []
        avg_list = []
        best_list = []
        improve_list = []
        job_selector_score_list = []
        ensemble_score_list = []
        all_learnwares = self.image_market.get_learnwares()
        for i in range(self.image_benchmark.user_num):
            test_x, test_y = self.image_benchmark.get_test_data(user_ids=i)
            train_x, train_y = self.image_benchmark.get_train_data(user_ids=i)
            test_x = torch.from_numpy(test_x)
            test_y = torch.from_numpy(test_y)
            test_dataset = TensorDataset(test_x, test_y)
            # Build the user's statistical specification from (unlabeled) test
            # inputs and query the market with it plus the semantic template.
            user_stat_spec = generate_stat_spec(type="image", X=test_x, whitening=False)
            user_info = BaseUserInfo(semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec})
            logger.info("Searching Market for user: %d" % (i))
            search_result = self.image_market.search_learnware(user_info)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()
            print(f"search result of user{i}:")
            print(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )
            # Evaluate every learnware in the market on this user's test set to
            # obtain the oracle best and market-average accuracy.
            acc_list = []
            for idx in range(len(all_learnwares)):
                learnware = all_learnwares[idx]
                loss, acc = evaluate(learnware, test_dataset)
                acc_list.append(acc)
            # Accuracy of the top-ranked (search-selected) learnware.
            learnware = single_result[0].learnware
            best_loss, best_acc = evaluate(learnware, test_dataset)
            best_list.append(np.max(acc_list))
            select_list.append(best_acc)
            avg_list.append(np.mean(acc_list))
            # Relative improvement of the selected learnware over the market mean.
            improve_list.append((best_acc - np.mean(acc_list)) / np.mean(acc_list))
            print(f"market mean accuracy: {np.mean(acc_list)}, market best accuracy: {np.max(acc_list)}")
            print(
                f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, acc: {best_acc}"
            )
            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = multiple_result[0].learnwares
            else:
                # No mixture returned: fall back to the single best learnware.
                mixture_learnware_list = [single_result[0].learnware]
            # test reuse (job selector)
            reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list, use_herding=False)
            job_loss, job_acc = evaluate(reuse_job_selector, test_dataset)
            job_selector_score_list.append(job_acc)
            print(f"mixture reuse accuracy (job selector): {job_acc}")
            # test reuse (ensemble)
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob")
            ensemble_loss, ensemble_acc = evaluate(reuse_ensemble, test_dataset)
            ensemble_score_list.append(ensemble_acc)
            print(f"mixture reuse accuracy (ensemble): {ensemble_acc}\n")
            # Labeled-data experiment: for each budget, repeatedly sample labeled
            # training data, train a user model from scratch, and fit
            # ensemble-pruning reuse on the same sample.
            user_model_score_mat = []
            pruning_score_mat = []
            single_score_mat = []
            for n_label, repeated in zip(self.n_labeled_list, self.repeated_list):
                user_model_score_list, reuse_pruning_score_list = [], []
                # Cap the budget at the training data actually available.
                if n_label > len(train_x):
                    n_label = len(train_x)
                for _ in range(repeated):
                    # Draw n_label labeled examples without replacement.
                    x_train, y_train = zip(*random.sample(list(zip(train_x, train_y)), k=n_label))
                    x_train = np.array(list(x_train))
                    y_train = np.array(list(y_train))
                    x_train = torch.from_numpy(x_train)
                    y_train = torch.from_numpy(y_train)
                    sampled_dataset = TensorDataset(x_train, y_train)
                    mode_save_path = os.path.abspath(os.path.join(self.model_path, "model.pth"))
                    # NOTE(review): assumes x_train is (N, C, H, W) — channel/im_size
                    # are taken from axes 1-3; confirm against the benchmark data.
                    model = ConvModel(
                        channel=x_train.shape[1], im_size=(x_train.shape[2], x_train.shape[3]), n_random_features=10
                    ).to(device)
                    train_model(
                        model,
                        sampled_dataset,
                        sampled_dataset,
                        mode_save_path,
                        epochs=35,
                        batch_size=128,
                        device=device,
                        verbose=False,
                    )
                    # Reload the weights train_model saved to mode_save_path.
                    model.load_state_dict(torch.load(mode_save_path))
                    _, user_model_acc = evaluate(model, test_dataset, distribution=True)
                    user_model_score_list.append(user_model_acc)
                    # Fit ensemble pruning on the same labeled sample, then evaluate.
                    reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification")
                    reuse_pruning.fit(x_train, y_train)
                    _, pruning_acc = evaluate(reuse_pruning, test_dataset, distribution=False)
                    reuse_pruning_score_list.append(pruning_acc)
                # The selected-learnware accuracy is constant across trials.
                single_score_mat.append([best_acc] * repeated)
                user_model_score_mat.append(user_model_score_list)
                pruning_score_mat.append(reuse_pruning_score_list)
                print(
                    f"user_label_num: {n_label}, user_acc: {np.mean(user_model_score_mat[-1])}, pruning_acc: {np.mean(pruning_score_mat[-1])}"
                )
            logger.info(f"Saving Curves for User_{i}")
            # Persist this user's curves so the aggregation pass below can reload them.
            user_curves_data = (single_score_mat, user_model_score_mat, pruning_score_mat)
            with open(os.path.join(self.curve_path, f"curve{str(i)}.pkl"), "wb") as f:
                pickle.dump(user_curves_data, f)
        # Summary statistics over all users.
        logger.info(
            "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Best performance: %.3f +/- %.3f"
            % (
                np.mean(select_list),
                np.std(select_list),
                np.mean(avg_list),
                np.std(avg_list),
                np.mean(best_list),
                np.std(best_list),
            )
        )
        logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
        logger.info(
            "Average Job Selector Reuse Performance: %.3f +/- %.3f"
            % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
        )
        logger.info(
            "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
        )
        # Aggregate the pickled per-user curves into mean error curves
        # (scores are divided by 100, i.e. treated as percent accuracies).
        pruning_curves_data, user_model_curves_data = [], []
        total_user_model_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
        total_pruning_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
        for user_idx in range(self.image_benchmark.user_num):
            with open(os.path.join(self.curve_path, f"curve{str(user_idx)}.pkl"), "rb") as f:
                user_curves_data = pickle.load(f)
            (single_score_mat, user_model_score_mat, pruning_score_mat) = user_curves_data
            for i in range(len(self.n_labeled_list)):
                # Accumulate error rate (1 - accuracy/100) across users.
                total_user_model_score_mat[i] += 1 - np.array(user_model_score_mat[i]) / 100
                total_pruning_score_mat[i] += 1 - np.array(pruning_score_mat[i]) / 100
        for i in range(len(self.n_labeled_list)):
            # Average over users, then reduce each budget to (mean, std).
            total_user_model_score_mat[i] /= self.image_benchmark.user_num
            total_pruning_score_mat[i] /= self.image_benchmark.user_num
            user_model_curves_data.append(
                (np.mean(total_user_model_score_mat[i]), np.std(total_user_model_score_mat[i]))
            )
            pruning_curves_data.append((np.mean(total_pruning_score_mat[i]), np.std(total_pruning_score_mat[i])))
        self._plot_labeled_peformance_curves([user_model_curves_data, pruning_curves_data])
# Expose the workflow as a CLI via python-fire,
# e.g. `python workflow.py image_example --rebuild=True`.
if __name__ == "__main__":
    fire.Fire(ImageDatasetWorkflow)