import os
import pickle
import random
import tempfile
import time

import fire
import matplotlib.pyplot as plt
import numpy as np
import torch
from config import image_benchmark_config
from model import ConvModel
from torch.utils.data import TensorDataset
from utils import evaluate, train_model

from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser
from learnware.specification import generate_stat_spec
from learnware.tests.benchmarks import LearnwareBenchmark
from learnware.utils import choose_device

logger = get_module_logger("image_workflow", level="INFO")


class ImageDatasetWorkflow:
    """Benchmark workflow on the image scenario: build a learnware market,
    search and reuse learnwares for each user, and plot labeled-data curves."""

    def _plot_labeled_performance_curves(self, all_user_curves_data):
        plt.figure(figsize=(10, 6))
        plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)

        styles = [
            {"color": "navy", "linestyle": "-", "marker": "o"},
            {"color": "magenta", "linestyle": "-.", "marker": "d"},
        ]
        labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]

        user_array, pruning_array = all_user_curves_data
        for array, style, label in zip([user_array, pruning_array], styles, labels):
            mean_curve = np.array([item[0] for item in array])
            std_curve = np.array([item[1] for item in array])
            plt.plot(mean_curve, **style, label=label)
            plt.fill_between(
                range(len(mean_curve)),
                mean_curve - std_curve,
                mean_curve + std_curve,
                color=style["color"],
                alpha=0.2,
            )

        plt.xlabel("Amount of Labeled User Data", fontsize=14)
        plt.ylabel("1 - Accuracy", fontsize=14)
        plt.title("Results on Image Experimental Scenario", fontsize=16)
        plt.legend(fontsize=14)
        plt.tight_layout()
        plt.savefig(os.path.join(self.fig_path, "image_labeled_curves.svg"), bbox_inches="tight", dpi=700)

    def _prepare_market(self, rebuild=False):
        client = LearnwareClient()
        self.image_benchmark = LearnwareBenchmark().get_benchmark(image_benchmark_config)
        self.image_market = instantiate_learnware_market(market_id=self.image_benchmark.name, rebuild=rebuild)
        self.user_semantic = client.get_semantic_specification(self.image_benchmark.learnware_ids[0])
        # Blank out the name so the user's semantic spec does not favor one specific learnware.
        self.user_semantic["Name"]["Values"] = ""

        if len(self.image_market) == 0 or rebuild is True:
            for learnware_id in self.image_benchmark.learnware_ids:
                with tempfile.TemporaryDirectory(prefix="image_benchmark_") as tempdir:
                    zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
                    # Retry up to 20 times to tolerate transient network failures.
                    for i in range(20):
                        try:
                            semantic_spec = client.get_semantic_specification(learnware_id)
                            client.download_learnware(learnware_id, zip_path)
                            self.image_market.add_learnware(zip_path, semantic_spec)
                            break
                        except Exception:
                            time.sleep(1)
                            continue

        logger.info("Total Item: %d" % (len(self.image_market)))

    def image_example(self, rebuild=False, skip_test=False):
        np.random.seed(1)
        random.seed(1)

        # Amounts of labeled user data, and how many repeated trials to run for each amount.
        self.n_labeled_list = [100, 200, 500, 1000, 2000, 4000]
        self.repeated_list = [10, 10, 10, 3, 3, 3]
        device = choose_device(0)

        self.root_path = os.path.dirname(os.path.abspath(__file__))
        self.fig_path = os.path.join(self.root_path, "figs")
        self.curve_path = os.path.join(self.root_path, "curves")
        self.model_path = os.path.join(self.root_path, "models")
        os.makedirs(self.fig_path, exist_ok=True)
        os.makedirs(self.curve_path, exist_ok=True)
        os.makedirs(self.model_path, exist_ok=True)

        select_list = []
        avg_list = []
        best_list = []
        improve_list = []
        job_selector_score_list = []
        ensemble_score_list = []

        if not skip_test:
            self._prepare_market(rebuild)
            all_learnwares = self.image_market.get_learnwares()
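            # For each benchmark user: search the market with the user's statistical
            # specification, evaluate single-learnware selection and multi-learnware
            # reuse, then compare against user models trained on increasing amounts
            # of labeled data.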
            for i in range(image_benchmark_config.user_num):
                test_x, test_y = self.image_benchmark.get_test_data(user_ids=i)
                train_x, train_y = self.image_benchmark.get_train_data(user_ids=i)
                test_x = torch.from_numpy(test_x)
                test_y = torch.from_numpy(test_y)
                test_dataset = TensorDataset(test_x, test_y)

                user_stat_spec = generate_stat_spec(type="image", X=test_x, whitening=False)
                user_info = BaseUserInfo(
                    semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec}
                )
                logger.info("Searching Market for user: %d" % (i))

                search_result = self.image_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                print(f"search result of user {i}:")
                print(
                    f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
                )

                # Evaluate every learnware in the market on this user's test data.
                acc_list = []
                for idx in range(len(all_learnwares)):
                    learnware = all_learnwares[idx]
                    loss, acc = evaluate(learnware, test_dataset)
                    acc_list.append(acc)

                learnware = single_result[0].learnware
                best_loss, best_acc = evaluate(learnware, test_dataset)
                best_list.append(np.max(acc_list))
                select_list.append(best_acc)
                avg_list.append(np.mean(acc_list))
                improve_list.append((best_acc - np.mean(acc_list)) / np.mean(acc_list))
                print(f"market mean accuracy: {np.mean(acc_list)}, market best accuracy: {np.max(acc_list)}")
                print(
                    f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, acc: {best_acc}"
                )

                if len(multiple_result) > 0:
                    mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                    print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                    mixture_learnware_list = multiple_result[0].learnwares
                else:
                    # Fall back to the best single learnware when no mixture is recommended.
                    mixture_learnware_list = [single_result[0].learnware]

                # test reuse (job selector)
                reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list, use_herding=False)
                job_loss, job_acc = evaluate(reuse_job_selector, test_dataset)
                job_selector_score_list.append(job_acc)
                print(f"mixture reuse accuracy (job selector): {job_acc}")

                # test reuse (ensemble)
                reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob")
                ensemble_loss, ensemble_acc = evaluate(reuse_ensemble, test_dataset)
                ensemble_score_list.append(ensemble_acc)
                print(f"mixture reuse accuracy (ensemble): {ensemble_acc}\n")

                user_model_score_mat = []
                pruning_score_mat = []
                single_score_mat = []
                for n_label, repeated in zip(self.n_labeled_list, self.repeated_list):
                    user_model_score_list, reuse_pruning_score_list = [], []
                    if n_label > len(train_x):
                        n_label = len(train_x)
                    for _ in range(repeated):
                        # Sample n_label labeled examples and train a user model from scratch.
                        x_train, y_train = zip(*random.sample(list(zip(train_x, train_y)), k=n_label))
                        x_train = np.array(list(x_train))
                        y_train = np.array(list(y_train))
                        x_train = torch.from_numpy(x_train)
                        y_train = torch.from_numpy(y_train)
                        sampled_dataset = TensorDataset(x_train, y_train)
                        model_save_path = os.path.abspath(os.path.join(self.model_path, "model.pth"))
                        model = ConvModel(
                            channel=x_train.shape[1], im_size=(x_train.shape[2], x_train.shape[3]), n_random_features=10
                        ).to(device)
                        train_model(
                            model,
                            sampled_dataset,
                            sampled_dataset,
                            model_save_path,
                            epochs=35,
                            batch_size=128,
                            device=device,
                            verbose=False,
                        )
                        model.load_state_dict(torch.load(model_save_path))
                        _, user_model_acc = evaluate(model, test_dataset, distribution=True)
                        user_model_score_list.append(user_model_acc)

                        # Fit ensemble pruning on the same labeled sample the user model used.
                        reuse_pruning = EnsemblePruningReuser(
                            learnware_list=mixture_learnware_list, mode="classification"
                        )
                        reuse_pruning.fit(x_train, y_train)
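                        # Evaluate the pruned multi-learnware ensemble on the same
                        # held-out test set used for the user's own model.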
                        _, pruning_acc = evaluate(reuse_pruning, test_dataset, distribution=False)
                        reuse_pruning_score_list.append(pruning_acc)

                    single_score_mat.append([best_acc] * repeated)
                    user_model_score_mat.append(user_model_score_list)
                    pruning_score_mat.append(reuse_pruning_score_list)
                    print(
                        f"user_label_num: {n_label}, user_acc: {np.mean(user_model_score_mat[-1])}, pruning_acc: {np.mean(pruning_score_mat[-1])}"
                    )

                logger.info(f"Saving Curves for User_{i}")
                user_curves_data = (single_score_mat, user_model_score_mat, pruning_score_mat)
                with open(os.path.join(self.curve_path, f"curve{str(i)}.pkl"), "wb") as f:
                    pickle.dump(user_curves_data, f)

            logger.info(
                "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Best performance: %.3f +/- %.3f"
                % (
                    np.mean(select_list),
                    np.std(select_list),
                    np.mean(avg_list),
                    np.std(avg_list),
                    np.mean(best_list),
                    np.std(best_list),
                )
            )
            logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
            logger.info(
                "Average Job Selector Reuse Performance: %.3f +/- %.3f"
                % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
            )
            logger.info(
                "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
                % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
            )

        # Aggregate saved per-user curves: convert accuracy (percent) to error rate
        # and average across users.
        pruning_curves_data, user_model_curves_data = [], []
        total_user_model_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
        total_pruning_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
        for user_idx in range(image_benchmark_config.user_num):
            with open(os.path.join(self.curve_path, f"curve{str(user_idx)}.pkl"), "rb") as f:
                user_curves_data = pickle.load(f)
            (single_score_mat, user_model_score_mat, pruning_score_mat) = user_curves_data
            for i in range(len(self.n_labeled_list)):
                total_user_model_score_mat[i] += 1 - np.array(user_model_score_mat[i]) / 100
                total_pruning_score_mat[i] += 1 - np.array(pruning_score_mat[i]) / 100

        for i in range(len(self.n_labeled_list)):
            total_user_model_score_mat[i] /= image_benchmark_config.user_num
            total_pruning_score_mat[i] /= image_benchmark_config.user_num
            user_model_curves_data.append(
                (np.mean(total_user_model_score_mat[i]), np.std(total_user_model_score_mat[i]))
            )
            pruning_curves_data.append((np.mean(total_pruning_score_mat[i]), np.std(total_pruning_score_mat[i])))

        self._plot_labeled_performance_curves([user_model_curves_data, pruning_curves_data])


if __name__ == "__main__":
    fire.Fire(ImageDatasetWorkflow)
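# Typical invocations via python-fire (the script filename below is illustrative):
#   python workflow.py image_example                    # build the market if empty, then run the full benchmark
#   python workflow.py image_example --rebuild=True     # re-download all learnwares and rebuild the market
#   python workflow.py image_example --skip_test=True   # skip testing; re-aggregate saved curves and re-plot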