
workflow.py

import os
import pickle
import random
import tempfile
import time

import fire
import matplotlib.pyplot as plt
import numpy as np
from config import text_benchmark_config
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
from sklearn.naive_bayes import MultinomialNB

from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
from learnware.specification import RKMETextSpecification
from learnware.tests.benchmarks import LearnwareBenchmark

logger = get_module_logger("text_workflow", level="INFO")
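

# Workflow for the text experimental scenario: build a learnware market from
# the text benchmark, then compare (a) searching and reusing learnwares for
# users with only unlabeled data and (b) reusing learnwares with limited
# labeled user data against models trained from scratch on that data.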
class TextDatasetWorkflow:
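    # Train a simple user baseline on labeled text:
    # TF-IDF features + multinomial naive Bayes.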
    @staticmethod
    def _train_model(X, y):
        vectorizer = TfidfVectorizer(stop_words="english")
        X_tfidf = vectorizer.fit_transform(X)
        clf = MultinomialNB(alpha=0.1)
        clf.fit(X_tfidf, y)
        return vectorizer, clf
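
    # Convert predictions (numpy arrays or torch tensors, label vectors or
    # per-class scores) to label vectors and return accuracy against the
    # ground truth.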
    @staticmethod
    def _eval_prediction(pred_y, target_y):
        if not isinstance(pred_y, np.ndarray):
            pred_y = pred_y.detach().cpu().numpy()
        pred_y = np.array(pred_y) if len(pred_y.shape) == 1 else np.argmax(pred_y, 1)
        target_y = np.array(target_y)
        return accuracy_score(target_y, pred_y)
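
    # Plot mean +/- std curves of 1 - accuracy for the user model and for
    # multiple-learnware reuse, over increasing amounts of labeled user data.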
    def _plot_labeled_performance_curves(self, all_user_curves_data):
        plt.figure(figsize=(10, 6))
        plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)
        styles = [
            {"color": "navy", "linestyle": "-", "marker": "o"},
            {"color": "magenta", "linestyle": "-.", "marker": "d"},
        ]
        labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
        user_array, pruning_array = all_user_curves_data

        for array, style, label in zip([user_array, pruning_array], styles, labels):
            mean_curve = np.array([item[0] for item in array])
            std_curve = np.array([item[1] for item in array])
            plt.plot(mean_curve, **style, label=label)
            plt.fill_between(
                range(len(mean_curve)),
                mean_curve - std_curve,
                mean_curve + std_curve,
                color=style["color"],
                alpha=0.2,
            )

        plt.xlabel("Amount of Labeled User Data", fontsize=14)
        plt.ylabel("1 - Accuracy", fontsize=14)
        plt.title("Results on Text Experimental Scenario", fontsize=16)
        plt.legend(fontsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.svg"), bbox_inches="tight", dpi=700)
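
    # Build (or load) the learnware market for the text benchmark, downloading
    # each learnware with up to 20 retries on transient failures.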
    def _prepare_market(self, rebuild=False):
        client = LearnwareClient()
        self.text_benchmark = LearnwareBenchmark().get_benchmark(text_benchmark_config)
        self.text_market = instantiate_learnware_market(market_id=self.text_benchmark.name, rebuild=rebuild)
        self.user_semantic = client.get_semantic_specification(self.text_benchmark.learnware_ids[0])
        self.user_semantic["Name"]["Values"] = ""

        if len(self.text_market) == 0 or rebuild is True:
            for learnware_id in self.text_benchmark.learnware_ids:
                with tempfile.TemporaryDirectory(prefix="text_benchmark_") as tempdir:
                    zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
                    for _ in range(20):
                        try:
                            semantic_spec = client.get_semantic_specification(learnware_id)
                            client.download_learnware(learnware_id, zip_path)
                            self.text_market.add_learnware(zip_path, semantic_spec)
                            break
                        except Exception:
                            time.sleep(1)
                            continue

        logger.info("Total Item: %d" % (len(self.text_market)))
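
    # Experiment 1: users have only unlabeled data. Search the market with an
    # RKME text specification, then evaluate the top-ranked learnware plus
    # job-selector and vote-by-label ensemble reuse of the mixture.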
    def unlabeled_text_example(self, rebuild=False):
        self._prepare_market(rebuild)
        select_list = []
        avg_list = []
        best_list = []
        improve_list = []
        job_selector_score_list = []
        ensemble_score_list = []
        all_learnwares = self.text_market.get_learnwares()

        for i in range(text_benchmark_config.user_num):
            user_data, user_label = self.text_benchmark.get_test_data(user_ids=i)
            user_stat_spec = RKMETextSpecification()
            user_stat_spec.generate_stat_spec_from_data(X=user_data)
            user_info = BaseUserInfo(
                semantic_spec=self.user_semantic, stat_info={"RKMETextSpecification": user_stat_spec}
            )

            logger.info("Searching Market for user: %d" % (i))
            search_result = self.text_market.search_learnware(user_info)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            print(f"search result of user_{i}:")
            print(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )

            acc_list = []
            for idx in range(len(all_learnwares)):
                learnware = all_learnwares[idx]
                pred_y = learnware.predict(user_data)
                acc = self._eval_prediction(pred_y, user_label)
                acc_list.append(acc)

            learnware = single_result[0].learnware
            pred_y = learnware.predict(user_data)
            best_acc = self._eval_prediction(pred_y, user_label)
            best_list.append(np.max(acc_list))
            select_list.append(best_acc)
            avg_list.append(np.mean(acc_list))
            improve_list.append((best_acc - np.mean(acc_list)) / np.mean(acc_list))
            print(f"market mean accuracy: {np.mean(acc_list)}, market best accuracy: {np.max(acc_list)}")
            print(
                f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, acc: {best_acc}"
            )

            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = multiple_result[0].learnwares
            else:
                mixture_learnware_list = [single_result[0].learnware]

            # test reuse (job selector)
            reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
            reuse_predict = reuse_baseline.predict(user_data=user_data)
            reuse_score = self._eval_prediction(reuse_predict, user_label)
            job_selector_score_list.append(reuse_score)
            print(f"mixture reuse accuracy (job selector): {reuse_score}")

            # test reuse (ensemble)
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label")
            ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
            ensemble_score = self._eval_prediction(ensemble_predict_y, user_label)
            ensemble_score_list.append(ensemble_score)
            print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n")

        logger.info(
            "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Best performance: %.3f +/- %.3f"
            % (
                np.mean(select_list),
                np.std(select_list),
                np.mean(avg_list),
                np.std(avg_list),
                np.mean(best_list),
                np.std(best_list),
            )
        )
        logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
        logger.info(
            "Average Job Selector Reuse Performance: %.3f +/- %.3f"
            % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
        )
        logger.info(
            "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
        )
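
    # Experiment 2: users also have limited labeled data. Compare a model
    # trained on n labeled samples against EnsemblePruningReuser fitted on the
    # same samples, then save per-user curves and plot the averaged results.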
    def labeled_text_example(self, rebuild=False, skip_test=False):
        self.n_labeled_list = [100, 200, 500, 1000, 2000, 4000]
        self.repeated_list = [10, 10, 10, 3, 3, 3]
        self.root_path = os.path.dirname(os.path.abspath(__file__))
        self.fig_path = os.path.join(self.root_path, "figs")
        self.curve_path = os.path.join(self.root_path, "curves")

        if not skip_test:
            self._prepare_market(rebuild)
            os.makedirs(self.fig_path, exist_ok=True)
            os.makedirs(self.curve_path, exist_ok=True)

            for i in range(text_benchmark_config.user_num):
                user_model_score_mat = []
                pruning_score_mat = []
                single_score_mat = []
                test_x, test_y = self.text_benchmark.get_test_data(user_ids=i)
                test_y = np.array(test_y)
                train_x, train_y = self.text_benchmark.get_train_data(user_ids=i)
                train_y = np.array(train_y)

                user_stat_spec = RKMETextSpecification()
                user_stat_spec.generate_stat_spec_from_data(X=test_x)
                user_info = BaseUserInfo(
                    semantic_spec=self.user_semantic, stat_info={"RKMETextSpecification": user_stat_spec}
                )
                logger.info(f"Searching Market for user_{i}")

                search_result = self.text_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                learnware = single_result[0].learnware
                pred_y = learnware.predict(test_x)
                best_acc = self._eval_prediction(pred_y, test_y)
                print(f"search result of user_{i}:")
                print(
                    f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}, single model acc: {best_acc}"
                )

                if len(multiple_result) > 0:
                    mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                    print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                    mixture_learnware_list = multiple_result[0].learnwares
                else:
                    mixture_learnware_list = [single_result[0].learnware]

                for n_label, repeated in zip(self.n_labeled_list, self.repeated_list):
                    user_model_score_list, reuse_pruning_score_list = [], []
                    if n_label > len(train_x):
                        n_label = len(train_x)

                    for _ in range(repeated):
                        x_train, y_train = zip(*random.sample(list(zip(train_x, train_y)), k=n_label))
                        x_train = list(x_train)
                        y_train = np.array(list(y_train))

                        vectorizer, clf = self._train_model(x_train, y_train)
                        user_model_predict_y = clf.predict(vectorizer.transform(test_x))
                        user_model_score = self._eval_prediction(user_model_predict_y, test_y)
                        user_model_score_list.append(user_model_score)

                        reuse_pruning = EnsemblePruningReuser(
                            learnware_list=mixture_learnware_list, mode="classification"
                        )
                        reuse_pruning.fit(x_train, y_train)
                        reuse_pruning_predict_y = reuse_pruning.predict(user_data=test_x)
                        reuse_pruning_score = self._eval_prediction(reuse_pruning_predict_y, test_y)
                        reuse_pruning_score_list.append(reuse_pruning_score)

                    single_score_mat.append([best_acc] * repeated)
                    user_model_score_mat.append(user_model_score_list)
                    pruning_score_mat.append(reuse_pruning_score_list)
                    print(
                        f"user_label_num: {n_label}, user_acc: {np.mean(user_model_score_mat[-1])}, pruning_acc: {np.mean(pruning_score_mat[-1])}"
                    )

                logger.info(f"Saving Curves for User_{i}")
                user_curves_data = (single_score_mat, user_model_score_mat, pruning_score_mat)
                with open(os.path.join(self.curve_path, f"curve{str(i)}.pkl"), "wb") as f:
                    pickle.dump(user_curves_data, f)

        pruning_curves_data, user_model_curves_data = [], []
        total_user_model_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
        total_pruning_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]

        for user_idx in range(text_benchmark_config.user_num):
            with open(os.path.join(self.curve_path, f"curve{str(user_idx)}.pkl"), "rb") as f:
                user_curves_data = pickle.load(f)
            (single_score_mat, user_model_score_mat, pruning_score_mat) = user_curves_data
            for i in range(len(self.n_labeled_list)):
                total_user_model_score_mat[i] += 1 - np.array(user_model_score_mat[i])
                total_pruning_score_mat[i] += 1 - np.array(pruning_score_mat[i])

        for i in range(len(self.n_labeled_list)):
            total_user_model_score_mat[i] /= text_benchmark_config.user_num
            total_pruning_score_mat[i] /= text_benchmark_config.user_num
            user_model_curves_data.append(
                (np.mean(total_user_model_score_mat[i]), np.std(total_user_model_score_mat[i]))
            )
            pruning_curves_data.append((np.mean(total_pruning_score_mat[i]), np.std(total_pruning_score_mat[i])))

        self._plot_labeled_performance_curves([user_model_curves_data, pruning_curves_data])

if __name__ == "__main__":
    fire.Fire(TextDatasetWorkflow)
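
# Example invocations via the fire CLI (assuming the `learnware` package and
# the local benchmark config are available in this directory):
#   python workflow.py unlabeled_text_example
#   python workflow.py labeled_text_example --rebuild=True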