diff --git a/examples/dataset_text_workflow/main.py b/examples/dataset_text_workflow/main.py
index 4b13730..603f96f 100644
--- a/examples/dataset_text_workflow/main.py
+++ b/examples/dataset_text_workflow/main.py
@@ -6,6 +6,7 @@
 import pickle
 import tempfile
 import numpy as np
 import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
 from sklearn.naive_bayes import MultinomialNB
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -20,39 +21,63 @@
 from config import text_benchmark_config
 
 logger = get_module_logger("text_workflow", level="INFO")
 
-def train(X, y):
-    # Train Uploaders' models
-    vectorizer = TfidfVectorizer(stop_words="english")
-    X_tfidf = vectorizer.fit_transform(X)
-
-    clf = MultinomialNB(alpha=0.1)
-    clf.fit(X_tfidf, y)
-
-    return vectorizer, clf
-
+class TextDatasetWorkflow:
+    @staticmethod
+    def _train_model(X, y):
+        # Fit a TF-IDF + multinomial naive Bayes pipeline on the labeled text
+        vectorizer = TfidfVectorizer(stop_words="english")
+        X_tfidf = vectorizer.fit_transform(X)
+        clf = MultinomialNB(alpha=0.1)
+        clf.fit(X_tfidf, y)
+        return vectorizer, clf
+
+    @staticmethod
+    def _eval_prediction(pred_y, target_y):
+        # Accept torch tensors or numpy arrays; argmax maps score matrices to labels
+        if not isinstance(pred_y, np.ndarray):
+            pred_y = pred_y.detach().cpu().numpy()
+
+        pred_y = np.array(pred_y) if len(pred_y.shape) == 1 else np.argmax(pred_y, 1)
+        target_y = np.array(target_y)
+        return accuracy_score(target_y, pred_y)
 
-def eval_prediction(pred_y, target_y):
-    if not isinstance(pred_y, np.ndarray):
-        pred_y = pred_y.detach().cpu().numpy()
-    if len(pred_y.shape) == 1:
-        predicted = np.array(pred_y)
-    else:
-        predicted = np.argmax(pred_y, 1)
-    annos = np.array(target_y)
+    def _plot_labeled_performance_curves(self, all_user_curves_data):
+        # Plot mean error curves (1 - accuracy) with +/- 0.5 std bands
+        plt.figure(figsize=(10, 6))
+        plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)
 
-    total = predicted.shape[0]
-    correct = (predicted == annos).sum().item()
+        styles = [
+            {"color": "navy", "linestyle": "-", "marker": "o"},
+            {"color": "magenta", "linestyle": "-.", "marker": "d"},
+        ]
+        labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
 
-    return correct / total
+        user_mat, pruning_mat = all_user_curves_data
+        user_mat, pruning_mat = np.array(user_mat), np.array(pruning_mat)
+        for mat, style, label in zip([user_mat, pruning_mat], styles, labels):
+            mean_curve, std_curve = 1 - np.mean(mat, axis=0), np.std(mat, axis=0)
+            plt.plot(mean_curve, **style, label=label)
+            plt.fill_between(
+                range(len(mean_curve)),
+                mean_curve - 0.5 * std_curve,
+                mean_curve + 0.5 * std_curve,
+                color=style["color"],
+                alpha=0.2,
+            )
+        plt.xlabel("Labeled Data Size")
+        plt.ylabel("1 - Accuracy")
+        plt.title("Text Limited Labeled Data")
+        plt.legend()
+        plt.tight_layout()
+        plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.png"), bbox_inches="tight", dpi=700)
 
-class TextDatasetWorkflow:
-    def prepare_market(self, rebuild=False):
+    def _prepare_market(self, rebuild=False):
         client = LearnwareClient()
         self.text_benchmark = LearnwareBenchmark().get_benchmark(text_benchmark_config)
         self.text_market = instantiate_learnware_market(market_id=self.text_benchmark.name, rebuild=rebuild)
         self.user_semantic = client.get_semantic_specification(self.text_benchmark.learnware_ids[0])
-        self.user_semantic['Name']['Values'] = ''
+        self.user_semantic["Name"]["Values"] = ""
 
         if len(self.text_market) == 0 or rebuild == True:
             for learnware_id in self.text_benchmark.learnware_ids:
@@ -71,7 +93,7 @@ class TextDatasetWorkflow:
         logger.info("Total Item: %d" % (len(self.text_market)))
 
     def test_unlabeled(self, rebuild=False):
-        self.prepare_market(rebuild)
+        self._prepare_market(rebuild)
 
         select_list = []
         avg_list = []
@@ -104,12 +126,12 @@ class TextDatasetWorkflow:
             for idx in range(len(all_learnwares)):
                 learnware = all_learnwares[idx]
                 pred_y = learnware.predict(user_data)
-                acc = eval_prediction(pred_y, user_label)
+                acc = self._eval_prediction(pred_y, user_label)
                 acc_list.append(acc)
 
             learnware = single_result[0].learnware
             pred_y = learnware.predict(user_data)
-            best_acc = eval_prediction(pred_y, user_label)
+            best_acc = self._eval_prediction(pred_y, user_label)
             best_list.append(np.max(acc_list))
             select_list.append(best_acc)
             avg_list.append(np.mean(acc_list))
@@ -129,18 +151,16 @@ class TextDatasetWorkflow:
             # test reuse (job selector)
             reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
             reuse_predict = reuse_baseline.predict(user_data=user_data)
-            reuse_score = eval_prediction(reuse_predict, user_label)
+            reuse_score = self._eval_prediction(reuse_predict, user_label)
             job_selector_score_list.append(reuse_score)
             print(f"mixture reuse accuracy (job selector): {reuse_score}")
 
             # test reuse (ensemble)
-            # be careful with the ensemble mode
             reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label")
             ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
-            ensemble_score = eval_prediction(ensemble_predict_y, user_label)
+            ensemble_score = self._eval_prediction(ensemble_predict_y, user_label)
             ensemble_score_list.append(ensemble_score)
-            print(f"mixture reuse accuracy (ensemble): {ensemble_score}")
-            print("\n")
+            print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n")
 
         logger.info(
             "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Best performance: %.3f +/- %.3f"
@@ -171,7 +191,7 @@ class TextDatasetWorkflow:
         self.curve_path = os.path.join(self.root_path, "curves")
 
         if train_flag:
-            self.prepare_market(rebuild)
+            self._prepare_market(rebuild)
 
         os.makedirs(self.fig_path, exist_ok=True)
         os.makedirs(self.curve_path, exist_ok=True)
@@ -198,8 +218,7 @@ class TextDatasetWorkflow:
 
             learnware = single_result[0].learnware
             pred_y = learnware.predict(test_x)
-            best_acc = eval_prediction(pred_y, test_y)
-
+            best_acc = self._eval_prediction(pred_y, test_y)
             print(f"search result of user_{i}:")
             print(
                 f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}, single model acc: {best_acc}"
@@ -218,14 +237,13 @@ class TextDatasetWorkflow:
                if n_label > len(train_x):
                    n_label = len(train_x)
                for _ in range(repeated):
-                    # x_train, y_train = train_x[:n_label], train_y[:n_label]
                    x_train, y_train = zip(*random.sample(list(zip(train_x, train_y)), k=n_label))
                    x_train = list(x_train)
                    y_train = np.array(list(y_train))
 
-                    modelv, modell = train(x_train, y_train)
+                    modelv, modell = self._train_model(x_train, y_train)
                    user_model_predict_y = modell.predict(modelv.transform(test_x))
-                    user_model_score = eval_prediction(user_model_predict_y, test_y)
+                    user_model_score = self._eval_prediction(user_model_predict_y, test_y)
                    user_model_score_list.append(user_model_score)
 
                    reuse_pruning = EnsemblePruningReuser(
@@ -233,7 +251,7 @@ class TextDatasetWorkflow:
                    )
                    reuse_pruning.fit(x_train, y_train)
                    reuse_pruning_predict_y = reuse_pruning.predict(user_data=test_x)
-                    reuse_pruning_score = eval_prediction(reuse_pruning_predict_y, test_y)
+                    reuse_pruning_score = self._eval_prediction(reuse_pruning_predict_y, test_y)
                    reuse_pruning_score_list.append(reuse_pruning_score)
 
            single_score_mat.append([best_acc] * repeated)
@@ -262,39 +280,6 @@ class TextDatasetWorkflow:
         pruning_curves_data.append(pruning_score_mat[:6])
 
-        self._plot_labeled_peformance_curves([user_model_curves_data, pruning_curves_data])
+        self._plot_labeled_performance_curves([user_model_curves_data, pruning_curves_data])
 
-    def _plot_labeled_peformance_curves(self, all_user_curves_data):
-        plt.figure(figsize=(10, 6))
-        plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)
-
-        styles = [
-            # {"color": "orange", "linestyle": "--", "marker": "s"},
-            {"color": "navy", "linestyle": "-", "marker": "o"},
-            {"color": "magenta", "linestyle": "-.", "marker": "d"},
-        ]
-
-        # labels = ["Single Learnware Reuse", "User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
-        labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
-
-        user_mat, pruning_mat = all_user_curves_data
-        user_mat, pruning_mat = np.array(user_mat), np.array(pruning_mat)
-        for mat, style, label in zip([user_mat, pruning_mat], styles, labels):
-            mean_curve, std_curve = 1 - np.mean(mat, axis=0), np.std(mat, axis=0)
-            plt.plot(mean_curve, **style, label=label)
-            plt.fill_between(
-                range(len(mean_curve)),
-                mean_curve - 0.5 * std_curve,
-                mean_curve + 0.5 * std_curve,
-                color=style["color"],
-                alpha=0.2,
-            )
-
-        plt.xlabel("Labeled Data Size")
-        plt.ylabel("1 - Accuracy")
-        plt.title(f"Text Limited Labeled Data")
-        plt.legend()
-        plt.tight_layout()
-        plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.png"), bbox_inches="tight", dpi=700)
-
 if __name__ == "__main__":
     fire.Fire(TextDatasetWorkflow)
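Usage sketch (an assumption, not part of the patch): because the entry point is `fire.Fire(TextDatasetWorkflow)`, Python Fire exposes the class's public methods as subcommands, so the unlabeled benchmark defined above could be run from `examples/dataset_text_workflow/` with

    python main.py test_unlabeled --rebuild=True

A side effect of this refactor worth noting: Fire skips underscore-prefixed members, so renaming `prepare_market` to `_prepare_market` (and moving `train`/`eval_prediction` to `_train_model`/`_eval_prediction`) also removes the helpers from the CLI surface, leaving only the `test_*` workflows callable.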