You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

workflow.py 13 kB

(Line-number gutter from the source viewer, lines 1–265, omitted.)
import os
import pickle
import random
import tempfile
import time

import fire
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import TensorDataset

from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
from learnware.specification import generate_stat_spec
from learnware.tests.benchmarks import LearnwareBenchmark
from learnware.utils import choose_device

from config import image_benchmark_config
from model import ConvModel
from utils import evaluate, train_model
# Module-wide logger for this workflow (project wrapper around ``logging``).
logger = get_module_logger("image_workflow", level="INFO")
  22. class ImageDatasetWorkflow:
  23. def _plot_labeled_peformance_curves(self, all_user_curves_data):
  24. plt.figure(figsize=(10, 6))
  25. plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)
  26. styles = [
  27. {"color": "navy", "linestyle": "-", "marker": "o"},
  28. {"color": "magenta", "linestyle": "-.", "marker": "d"},
  29. ]
  30. labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
  31. user_array, pruning_array = all_user_curves_data
  32. for array, style, label in zip([user_array, pruning_array], styles, labels):
  33. mean_curve = np.array([item[0] for item in array])
  34. std_curve = np.array([item[1] for item in array])
  35. plt.plot(mean_curve, **style, label=label)
  36. plt.fill_between(
  37. range(len(mean_curve)),
  38. mean_curve - std_curve,
  39. mean_curve + std_curve,
  40. color=style["color"],
  41. alpha=0.2,
  42. )
  43. plt.xlabel("Amout of Labeled User Data", fontsize=14)
  44. plt.ylabel("1 - Accuracy", fontsize=14)
  45. plt.title("Results on Image Experimental Scenario", fontsize=16)
  46. plt.legend(fontsize=14)
  47. plt.tight_layout()
  48. plt.savefig(os.path.join(self.fig_path, "image_labeled_curves.svg"), bbox_inches="tight", dpi=700)
  49. def _prepare_market(self, rebuild=False):
  50. client = LearnwareClient()
  51. self.image_benchmark = LearnwareBenchmark().get_benchmark(image_benchmark_config)
  52. self.image_market = instantiate_learnware_market(market_id=self.image_benchmark.name, rebuild=rebuild)
  53. self.user_semantic = client.get_semantic_specification(self.image_benchmark.learnware_ids[0])
  54. self.user_semantic["Name"]["Values"] = ""
  55. if len(self.image_market) == 0 or rebuild is True:
  56. for learnware_id in self.image_benchmark.learnware_ids:
  57. with tempfile.TemporaryDirectory(prefix="image_benchmark_") as tempdir:
  58. zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
  59. for i in range(20):
  60. try:
  61. semantic_spec = client.get_semantic_specification(learnware_id)
  62. client.download_learnware(learnware_id, zip_path)
  63. self.image_market.add_learnware(zip_path, semantic_spec)
  64. break
  65. except Exception:
  66. time.sleep(1)
  67. continue
  68. logger.info("Total Item: %d" % (len(self.image_market)))
  69. def image_example(self, rebuild=False, skip_test=False):
  70. np.random.seed(1)
  71. random.seed(1)
  72. self.n_labeled_list = [100, 200, 500, 1000, 2000, 4000]
  73. self.repeated_list = [10, 10, 10, 3, 3, 3]
  74. device = choose_device(0)
  75. self.root_path = os.path.dirname(os.path.abspath(__file__))
  76. self.fig_path = os.path.join(self.root_path, "figs")
  77. self.curve_path = os.path.join(self.root_path, "curves")
  78. self.model_path = os.path.join(self.root_path, "models")
  79. os.makedirs(self.fig_path, exist_ok=True)
  80. os.makedirs(self.curve_path, exist_ok=True)
  81. os.makedirs(self.model_path, exist_ok=True)
  82. select_list = []
  83. avg_list = []
  84. best_list = []
  85. improve_list = []
  86. job_selector_score_list = []
  87. ensemble_score_list = []
  88. if not skip_test:
  89. self._prepare_market(rebuild)
  90. all_learnwares = self.image_market.get_learnwares()
  91. for i in range(image_benchmark_config.user_num):
  92. test_x, test_y = self.image_benchmark.get_test_data(user_ids=i)
  93. train_x, train_y = self.image_benchmark.get_train_data(user_ids=i)
  94. test_x = torch.from_numpy(test_x)
  95. test_y = torch.from_numpy(test_y)
  96. test_dataset = TensorDataset(test_x, test_y)
  97. user_stat_spec = generate_stat_spec(type="image", X=test_x, whitening=False)
  98. user_info = BaseUserInfo(
  99. semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec}
  100. )
  101. logger.info("Searching Market for user: %d" % (i))
  102. search_result = self.image_market.search_learnware(user_info)
  103. single_result = search_result.get_single_results()
  104. multiple_result = search_result.get_multiple_results()
  105. print(f"search result of user{i}:")
  106. print(
  107. f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
  108. )
  109. acc_list = []
  110. for idx in range(len(all_learnwares)):
  111. learnware = all_learnwares[idx]
  112. loss, acc = evaluate(learnware, test_dataset)
  113. acc_list.append(acc)
  114. learnware = single_result[0].learnware
  115. best_loss, best_acc = evaluate(learnware, test_dataset)
  116. best_list.append(np.max(acc_list))
  117. select_list.append(best_acc)
  118. avg_list.append(np.mean(acc_list))
  119. improve_list.append((best_acc - np.mean(acc_list)) / np.mean(acc_list))
  120. print(f"market mean accuracy: {np.mean(acc_list)}, market best accuracy: {np.max(acc_list)}")
  121. print(
  122. f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, acc: {best_acc}"
  123. )
  124. if len(multiple_result) > 0:
  125. mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
  126. print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
  127. mixture_learnware_list = multiple_result[0].learnwares
  128. else:
  129. mixture_learnware_list = [single_result[0].learnware]
  130. # test reuse (job selector)
  131. reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list, use_herding=False)
  132. job_loss, job_acc = evaluate(reuse_job_selector, test_dataset)
  133. job_selector_score_list.append(job_acc)
  134. print(f"mixture reuse accuracy (job selector): {job_acc}")
  135. # test reuse (ensemble)
  136. reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob")
  137. ensemble_loss, ensemble_acc = evaluate(reuse_ensemble, test_dataset)
  138. ensemble_score_list.append(ensemble_acc)
  139. print(f"mixture reuse accuracy (ensemble): {ensemble_acc}\n")
  140. user_model_score_mat = []
  141. pruning_score_mat = []
  142. single_score_mat = []
  143. for n_label, repeated in zip(self.n_labeled_list, self.repeated_list):
  144. user_model_score_list, reuse_pruning_score_list = [], []
  145. if n_label > len(train_x):
  146. n_label = len(train_x)
  147. for _ in range(repeated):
  148. x_train, y_train = zip(*random.sample(list(zip(train_x, train_y)), k=n_label))
  149. x_train = np.array(list(x_train))
  150. y_train = np.array(list(y_train))
  151. x_train = torch.from_numpy(x_train)
  152. y_train = torch.from_numpy(y_train)
  153. sampled_dataset = TensorDataset(x_train, y_train)
  154. mode_save_path = os.path.abspath(os.path.join(self.model_path, "model.pth"))
  155. model = ConvModel(
  156. channel=x_train.shape[1], im_size=(x_train.shape[2], x_train.shape[3]), n_random_features=10
  157. ).to(device)
  158. train_model(
  159. model,
  160. sampled_dataset,
  161. sampled_dataset,
  162. mode_save_path,
  163. epochs=35,
  164. batch_size=128,
  165. device=device,
  166. verbose=False,
  167. )
  168. model.load_state_dict(torch.load(mode_save_path))
  169. _, user_model_acc = evaluate(model, test_dataset, distribution=True)
  170. user_model_score_list.append(user_model_acc)
  171. reuse_pruning = EnsemblePruningReuser(
  172. learnware_list=mixture_learnware_list, mode="classification"
  173. )
  174. reuse_pruning.fit(x_train, y_train)
  175. _, pruning_acc = evaluate(reuse_pruning, test_dataset, distribution=False)
  176. reuse_pruning_score_list.append(pruning_acc)
  177. single_score_mat.append([best_acc] * repeated)
  178. user_model_score_mat.append(user_model_score_list)
  179. pruning_score_mat.append(reuse_pruning_score_list)
  180. print(
  181. f"user_label_num: {n_label}, user_acc: {np.mean(user_model_score_mat[-1])}, pruning_acc: {np.mean(pruning_score_mat[-1])}"
  182. )
  183. logger.info(f"Saving Curves for User_{i}")
  184. user_curves_data = (single_score_mat, user_model_score_mat, pruning_score_mat)
  185. with open(os.path.join(self.curve_path, f"curve{str(i)}.pkl"), "wb") as f:
  186. pickle.dump(user_curves_data, f)
  187. logger.info(
  188. "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Best performance: %.3f +/- %.3f"
  189. % (
  190. np.mean(select_list),
  191. np.std(select_list),
  192. np.mean(avg_list),
  193. np.std(avg_list),
  194. np.mean(best_list),
  195. np.std(best_list),
  196. )
  197. )
  198. logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
  199. logger.info(
  200. "Average Job Selector Reuse Performance: %.3f +/- %.3f"
  201. % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
  202. )
  203. logger.info(
  204. "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
  205. % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
  206. )
  207. pruning_curves_data, user_model_curves_data = [], []
  208. total_user_model_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
  209. total_pruning_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
  210. for user_idx in range(image_benchmark_config.user_num):
  211. with open(os.path.join(self.curve_path, f"curve{str(user_idx)}.pkl"), "rb") as f:
  212. user_curves_data = pickle.load(f)
  213. (single_score_mat, user_model_score_mat, pruning_score_mat) = user_curves_data
  214. for i in range(len(self.n_labeled_list)):
  215. total_user_model_score_mat[i] += 1 - np.array(user_model_score_mat[i]) / 100
  216. total_pruning_score_mat[i] += 1 - np.array(pruning_score_mat[i]) / 100
  217. for i in range(len(self.n_labeled_list)):
  218. total_user_model_score_mat[i] /= image_benchmark_config.user_num
  219. total_pruning_score_mat[i] /= image_benchmark_config.user_num
  220. user_model_curves_data.append(
  221. (np.mean(total_user_model_score_mat[i]), np.std(total_user_model_score_mat[i]))
  222. )
  223. pruning_curves_data.append((np.mean(total_pruning_score_mat[i]), np.std(total_pruning_score_mat[i])))
  224. self._plot_labeled_peformance_curves([user_model_curves_data, pruning_curves_data])
  225. if __name__ == "__main__":
  226. fire.Fire(ImageDatasetWorkflow)