
[MNT] modify details

tags/v0.3.2
Gene 1 year ago
commit 032b1e76ba
1 changed file with 57 additions and 72 deletions

examples/dataset_text_workflow/main.py  (+57 −72)

@@ -6,6 +6,7 @@ import pickle
 import tempfile
 import numpy as np
 import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
 from sklearn.naive_bayes import MultinomialNB
 from sklearn.feature_extraction.text import TfidfVectorizer
 
@@ -20,39 +21,60 @@ from config import text_benchmark_config
 logger = get_module_logger("text_workflow", level="INFO")
 
 
-def train(X, y):
-    # Train Uploaders' models
-    vectorizer = TfidfVectorizer(stop_words="english")
-    X_tfidf = vectorizer.fit_transform(X)
-
-    clf = MultinomialNB(alpha=0.1)
-    clf.fit(X_tfidf, y)
-
-    return vectorizer, clf
-
-
-def eval_prediction(pred_y, target_y):
-    if not isinstance(pred_y, np.ndarray):
-        pred_y = pred_y.detach().cpu().numpy()
-    if len(pred_y.shape) == 1:
-        predicted = np.array(pred_y)
-    else:
-        predicted = np.argmax(pred_y, 1)
-    annos = np.array(target_y)
-
-    total = predicted.shape[0]
-    correct = (predicted == annos).sum().item()
-
-    return correct / total
-
-
-class TextDatasetWorkflow:
-    def prepare_market(self, rebuild=False):
+class TextDatasetWorkflow:
+    @staticmethod
+    def _train_model(X, y):
+        vectorizer = TfidfVectorizer(stop_words="english")
+        X_tfidf = vectorizer.fit_transform(X)
+        clf = MultinomialNB(alpha=0.1)
+        clf.fit(X_tfidf, y)
+        return vectorizer, clf
+
+    @staticmethod
+    def _eval_prediction(pred_y, target_y):
+        if not isinstance(pred_y, np.ndarray):
+            pred_y = pred_y.detach().cpu().numpy()
+
+        pred_y = np.array(pred_y) if len(pred_y.shape) == 1 else np.argmax(pred_y, 1)
+        target_y = np.array(target_y)
+        return accuracy_score(target_y, pred_y)
+
+    def _plot_labeled_peformance_curves(self, all_user_curves_data):
+        plt.figure(figsize=(10, 6))
+        plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)
+
+        styles = [
+            {"color": "navy", "linestyle": "-", "marker": "o"},
+            {"color": "magenta", "linestyle": "-.", "marker": "d"},
+        ]
+        labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
+
+        user_mat, pruning_mat = all_user_curves_data
+        user_mat, pruning_mat = np.array(user_mat), np.array(pruning_mat)
+        for mat, style, label in zip([user_mat, pruning_mat], styles, labels):
+            mean_curve, std_curve = 1 - np.mean(mat, axis=0), np.std(mat, axis=0)
+            plt.plot(mean_curve, **style, label=label)
+            plt.fill_between(
+                range(len(mean_curve)),
+                mean_curve - 0.5 * std_curve,
+                mean_curve + 0.5 * std_curve,
+                color=style["color"],
+                alpha=0.2,
+            )
+
+        plt.xlabel("Labeled Data Size")
+        plt.ylabel("1 - Accuracy")
+        plt.title(f"Text Limited Labeled Data")
+        plt.legend()
+        plt.tight_layout()
+        plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.png"), bbox_inches="tight", dpi=700)
+
+    def _prepare_market(self, rebuild=False):
         client = LearnwareClient()
         self.text_benchmark = LearnwareBenchmark().get_benchmark(text_benchmark_config)
         self.text_market = instantiate_learnware_market(market_id=self.text_benchmark.name, rebuild=rebuild)
         self.user_semantic = client.get_semantic_specification(self.text_benchmark.learnware_ids[0])
-        self.user_semantic['Name']['Values'] = ''
+        self.user_semantic["Name"]["Values"] = ""
 
         if len(self.text_market) == 0 or rebuild == True:
             for learnware_id in self.text_benchmark.learnware_ids:
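
Note: the module-level train() and eval_prediction() helpers are folded into the class as the static methods _train_model() and _eval_prediction(), and the accuracy is now computed with sklearn's accuracy_score instead of a hand-rolled ratio. A minimal, self-contained sketch of that pipeline on toy data, illustrative only and not part of this commit (the train_and_eval helper and the toy strings are made up):

    import numpy as np
    from sklearn.metrics import accuracy_score
    from sklearn.naive_bayes import MultinomialNB
    from sklearn.feature_extraction.text import TfidfVectorizer

    def train_and_eval(train_texts, train_labels, test_texts, test_labels):
        # Same recipe as _train_model: TF-IDF features into MultinomialNB(alpha=0.1).
        vectorizer = TfidfVectorizer(stop_words="english")
        clf = MultinomialNB(alpha=0.1)
        clf.fit(vectorizer.fit_transform(train_texts), train_labels)

        # Same logic as _eval_prediction: 1-D outputs are hard labels,
        # 2-D score matrices would be argmax-ed along axis 1 first.
        pred_y = clf.predict(vectorizer.transform(test_texts))
        pred_y = np.array(pred_y) if len(pred_y.shape) == 1 else np.argmax(pred_y, 1)
        return accuracy_score(np.array(test_labels), pred_y)

    acc = train_and_eval(
        ["good movie", "bad movie", "great plot", "terrible plot"], [1, 0, 1, 0],
        ["good plot", "bad movie overall"], [1, 0],
    )
    print(acc)
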
@@ -71,7 +93,7 @@ class TextDatasetWorkflow:
         logger.info("Total Item: %d" % (len(self.text_market)))
 
     def test_unlabeled(self, rebuild=False):
-        self.prepare_market(rebuild)
+        self._prepare_market(rebuild)
 
         select_list = []
         avg_list = []
@@ -104,12 +126,12 @@ class TextDatasetWorkflow:
             for idx in range(len(all_learnwares)):
                 learnware = all_learnwares[idx]
                 pred_y = learnware.predict(user_data)
-                acc = eval_prediction(pred_y, user_label)
+                acc = self._eval_prediction(pred_y, user_label)
                 acc_list.append(acc)
 
             learnware = single_result[0].learnware
             pred_y = learnware.predict(user_data)
-            best_acc = eval_prediction(pred_y, user_label)
+            best_acc = self._eval_prediction(pred_y, user_label)
             best_list.append(np.max(acc_list))
             select_list.append(best_acc)
             avg_list.append(np.mean(acc_list))
@@ -129,18 +151,16 @@ class TextDatasetWorkflow:
             # test reuse (job selector)
             reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
             reuse_predict = reuse_baseline.predict(user_data=user_data)
-            reuse_score = eval_prediction(reuse_predict, user_label)
+            reuse_score = self._eval_prediction(reuse_predict, user_label)
             job_selector_score_list.append(reuse_score)
             print(f"mixture reuse accuracy (job selector): {reuse_score}")
 
             # test reuse (ensemble)
-            # be careful with the ensemble mode
             reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label")
             ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
-            ensemble_score = eval_prediction(ensemble_predict_y, user_label)
+            ensemble_score = self._eval_prediction(ensemble_predict_y, user_label)
             ensemble_score_list.append(ensemble_score)
-            print(f"mixture reuse accuracy (ensemble): {ensemble_score}")
-            print("\n")
+            print(f"mixture reuse accuracy (ensemble): {ensemble_score}\n")
 
         logger.info(
             "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Best performance: %.3f +/- %.3f"
@@ -171,7 +191,7 @@ class TextDatasetWorkflow:
         self.curve_path = os.path.join(self.root_path, "curves")
 
         if train_flag:
-            self.prepare_market(rebuild)
+            self._prepare_market(rebuild)
             os.makedirs(self.fig_path, exist_ok=True)
             os.makedirs(self.curve_path, exist_ok=True)
 
@@ -198,8 +218,7 @@ class TextDatasetWorkflow:
 
             learnware = single_result[0].learnware
             pred_y = learnware.predict(test_x)
-            best_acc = eval_prediction(pred_y, test_y)
-
+            best_acc = self._eval_prediction(pred_y, test_y)
             print(f"search result of user_{i}:")
             print(
                 f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}, single model acc: {best_acc}"
@@ -218,14 +237,13 @@ class TextDatasetWorkflow:
             if n_label > len(train_x):
                 n_label = len(train_x)
             for _ in range(repeated):
-                # x_train, y_train = train_x[:n_label], train_y[:n_label]
                 x_train, y_train = zip(*random.sample(list(zip(train_x, train_y)), k=n_label))
                 x_train = list(x_train)
                 y_train = np.array(list(y_train))
 
-                modelv, modell = train(x_train, y_train)
+                modelv, modell = self._train_model(x_train, y_train)
                 user_model_predict_y = modell.predict(modelv.transform(test_x))
-                user_model_score = eval_prediction(user_model_predict_y, test_y)
+                user_model_score = self._eval_prediction(user_model_predict_y, test_y)
                 user_model_score_list.append(user_model_score)
 
                 reuse_pruning = EnsemblePruningReuser(
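
Note: the zip(*random.sample(list(zip(...)), k=n_label)) idiom above draws n_label training examples without replacement while keeping texts and labels aligned; the list()/np.array() conversions are needed because zip(*...) yields tuples. A stand-alone illustration with made-up data, not part of this commit:

    import random

    train_x = ["doc a", "doc b", "doc c", "doc d", "doc e"]
    train_y = [0, 1, 0, 1, 0]

    # Sample 3 aligned (text, label) pairs, then split back into two sequences.
    x_sub, y_sub = zip(*random.sample(list(zip(train_x, train_y)), k=3))
    print(list(x_sub), list(y_sub))
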
@@ -233,7 +251,7 @@ class TextDatasetWorkflow:
                 )
                 reuse_pruning.fit(x_train, y_train)
                 reuse_pruning_predict_y = reuse_pruning.predict(user_data=test_x)
-                reuse_pruning_score = eval_prediction(reuse_pruning_predict_y, test_y)
+                reuse_pruning_score = self._eval_prediction(reuse_pruning_predict_y, test_y)
                 reuse_pruning_score_list.append(reuse_pruning_score)
 
             single_score_mat.append([best_acc] * repeated)
@@ -262,39 +280,6 @@ class TextDatasetWorkflow:
             pruning_curves_data.append(pruning_score_mat[:6])
         self._plot_labeled_peformance_curves([user_model_curves_data, pruning_curves_data])
 
-    def _plot_labeled_peformance_curves(self, all_user_curves_data):
-        plt.figure(figsize=(10, 6))
-        plt.xticks(range(len(self.n_labeled_list)), self.n_labeled_list)
-
-        styles = [
-            # {"color": "orange", "linestyle": "--", "marker": "s"},
-            {"color": "navy", "linestyle": "-", "marker": "o"},
-            {"color": "magenta", "linestyle": "-.", "marker": "d"},
-        ]
-
-        # labels = ["Single Learnware Reuse", "User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
-        labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
-
-        user_mat, pruning_mat = all_user_curves_data
-        user_mat, pruning_mat = np.array(user_mat), np.array(pruning_mat)
-        for mat, style, label in zip([user_mat, pruning_mat], styles, labels):
-            mean_curve, std_curve = 1 - np.mean(mat, axis=0), np.std(mat, axis=0)
-            plt.plot(mean_curve, **style, label=label)
-            plt.fill_between(
-                range(len(mean_curve)),
-                mean_curve - 0.5 * std_curve,
-                mean_curve + 0.5 * std_curve,
-                color=style["color"],
-                alpha=0.2,
-            )
-
-        plt.xlabel("Labeled Data Size")
-        plt.ylabel("1 - Accuracy")
-        plt.title(f"Text Limited Labeled Data")
-        plt.legend()
-        plt.tight_layout()
-        plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.png"), bbox_inches="tight", dpi=700)
-
 
 if __name__ == "__main__":
     fire.Fire(TextDatasetWorkflow)
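
Note: since the entry point wraps the class with fire.Fire, the public test_* methods double as command-line sub-commands (e.g. python examples/dataset_text_workflow/main.py test_unlabeled), while the helpers renamed here to a leading underscore (_prepare_market, _train_model, _eval_prediction) are marked as internal. The same workflow can also be driven programmatically; a minimal sketch, assuming main.py is importable from its own directory:

    from main import TextDatasetWorkflow  # assumes cwd == examples/dataset_text_workflow

    workflow = TextDatasetWorkflow()
    workflow.test_unlabeled(rebuild=False)  # prepares the market, then reports single/mixture reuse accuracy
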
