[MNT] unify figure format

liuht · 1 year ago · tags/v0.3.2 · parent · commit cf2b64d1b7
4 changed files with 48 additions and 43 deletions:

  1. .gitignore (+2, -1)
  2. examples/dataset_table_workflow/base.py (+10, -13)
  3. examples/dataset_table_workflow/homo.py (+3, -3)
  4. examples/dataset_table_workflow/utils.py (+33, -26)

.gitignore (+2, -1)

@@ -44,4 +44,5 @@ tmp/
 learnware_pool/
 PFS/
 data/
-examples/results/
+examples/results/
+examples/*/results/

examples/dataset_table_workflow/base.py (+10, -13)

@@ -34,13 +34,13 @@ class TableWorkflow:
                 x_train, y_train = sample["x_train"], sample["y_train"]
                 model = method(x_train, y_train, test_info)
                 subset_scores.append(loss_func(model.predict(test_info["test_x"]), test_info["test_y"]))
-            all_scores.append(np.mean(subset_scores))
+            all_scores.append(subset_scores)
         return all_scores
 
     @staticmethod
-    def get_train_subsets(train_x, train_y):
-        np.random.seed(1)
-        random.seed(1)
+    def get_train_subsets(idx, train_x, train_y):
+        np.random.seed(idx)
+        random.seed(idx)
         train_subsets = []
         for n_label, repeated in zip(n_labeled_list, n_repeat_list):
             train_subsets.append([])
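
Aside (not part of the diff): seeding with the user index instead of a constant keeps each run reproducible while giving every user a distinct subsample draw. A minimal sketch of that property, using a hypothetical helper name:

    import numpy as np

    def sample_indices(idx, n_rows, n_label):
        # per-user seed, mirroring the new get_train_subsets(idx, ...) signature
        np.random.seed(idx)
        return np.random.choice(n_rows, n_label, replace=False)

    print(sample_indices(0, 100, 5))  # identical on every run
    print(sample_indices(1, 100, 5))  # also stable, but differs from idx=0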
@@ -82,24 +82,21 @@ class TableWorkflow:
         os.makedirs(save_root_path, exist_ok=True)
         save_path = os.path.join(save_root_path, f"{method_name}.json")
-        if method_name == "single_aug":
+        if method_name_full == "hetero_single_aug":
             if test_info["force"] or recorder.should_test_method(user, idx, save_path):
                 for learnware in test_info["learnwares"]:
                     test_info["single_learnware"] = [learnware]
                     scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
-                    recorder.record(user, idx, scores)
+                    recorder.record(user, scores)
 
                 process_single_aug(user, idx, scores, recorders, save_root_path)
                 recorder.save(save_path)
-                logger.info(f"Method {method_name} on {user}_{idx} finished")
             else:
-                process_single_aug(user, idx, recorder.data[user][str(idx)], recorders, save_root_path)
-                logger.info(f"Method {method_name} on {user}_{idx} already exists")
+                process_single_aug(user, idx, recorder.data[user], recorders, save_root_path)
         else:
             if test_info["force"] or recorder.should_test_method(user, idx, save_path):
                 scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
-                recorder.record(user, idx, scores)
+                recorder.record(user, scores)
                 recorder.save(save_path)
-                logger.info(f"Method {method_name} on {user}_{idx} finished")
-            else:
-                logger.info(f"Method {method_name} on {user}_{idx} already exists")
+
+        logger.info(f"Method {method_name} on {user}_{idx} finished")
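
Aside: with all_scores.append(subset_scores) replacing the pre-averaged np.mean(subset_scores), _limited_data now returns the raw per-subset scores for each labeled-data setting, and the old means remain recoverable downstream. A sketch with invented RMSE values:

    import numpy as np

    all_scores = [[0.92, 0.88, 0.90], [0.81, 0.79, 0.83]]  # hypothetical per-subset RMSEs
    old_style_means = [float(np.mean(s)) for s in all_scores]
    print(old_style_means)  # what the old code stored directly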

examples/dataset_table_workflow/homo.py (+3, -3)

@@ -117,8 +117,8 @@ class CorporacionDatasetWorkflow(TableWorkflow):
     def labeled_homo_table_example(self):
         logger.info("Total Item: %d" % (len(self.market)))
         methods = ["user_model", "homo_single_aug", "homo_multiple_aug", "homo_multiple_avg", "homo_ensemble_pruning"]
-        recorders = {method: Recorder() for method in methods}
         methods_to_retest = []
+        recorders = {method: Recorder() for method in methods}
 
         user = self.benchmark.name
         for idx in range(self.benchmark.user_num):
@@ -127,7 +127,7 @@ class CorporacionDatasetWorkflow(TableWorkflow):
             train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
             train_x, train_y = train_x.values, train_y.values
-            train_subsets = self.get_train_subsets(train_x, train_y)
+            train_subsets = self.get_train_subsets(idx, train_x, train_y)
 
             user_stat_spec = generate_stat_spec(type="table", X=test_x)
             user_info = BaseUserInfo(
@@ -155,7 +155,7 @@ class CorporacionDatasetWorkflow(TableWorkflow):
             common_config = {"learnwares": mixture_learnware_list}
             method_configs = {
                 "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
-                "homo_single_aug": {"learnwares": [single_result[0].learnware]},
+                "homo_single_aug": {"single_learnware": [single_result[0].learnware]},
                 "homo_multiple_aug": common_config,
                 "homo_multiple_avg": common_config,
                 "homo_ensemble_pruning": common_config
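
Aside: the "single_learnware" key now matches what base.py reads, where the hetero branch sets test_info["single_learnware"] itself. A toy sketch of that hand-off, assuming the per-method config dict is merged into test_info (placeholder values):

    test_info = {"force": False}
    method_config = {"single_learnware": ["learnware_0"]}  # placeholder object
    test_info.update(method_config)                        # assumed merge step
    print(test_info["single_learnware"])                   # read by the test method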


examples/dataset_table_workflow/utils.py (+33, -26)

@@ -14,18 +14,15 @@ logger = get_module_logger("base_table", level="INFO")
 class Recorder:
     def __init__(self, headers=["Mean", "Std Dev"], formats=["{:.2f}", "{:.2f}"]):
         assert len(headers) == len(formats), "Headers and formats length must match."
-        self.data = defaultdict(lambda: defaultdict(list))
+        self.data = defaultdict(list)
         self.headers = headers
         self.formats = formats
 
-    def record(self, user, idx, scores):
-        self.data[user][idx].append(scores)
+    def record(self, user, scores):
+        self.data[user].append(scores)
 
     def get_performance_data(self, user):
-        if user in self.data:
-            return [idx_scores for idx_scores in self.data[user].values()]
-        else:
-            return []
+        return self.data.get(user, [])
 
     def save(self, path):
         with open(path, "w") as f:
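
Aside: the flattened Recorder stores one entry per recorded run, in call order, instead of a per-idx mapping. A minimal usage sketch against the class above (score values invented):

    r = Recorder()
    r.record("corporacion", [[0.92, 0.88], [0.81, 0.79]])  # run for user idx 0
    r.record("corporacion", [[0.90, 0.87], [0.80, 0.78]])  # run for user idx 1
    print(len(r.get_performance_data("corporacion")))      # -> 2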
@@ -38,7 +35,7 @@ class Recorder:
 
     def should_test_method(self, user, idx, path):
         if os.path.exists(path):
             self.load(path)
-            return user not in self.data or str(idx) not in self.data[user]
+            return user not in self.data or idx > len(self.data[user]) - 1
         return True
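
Aside: the rewritten check treats self.data[user] as an ordered list of completed runs, so idx still needs testing exactly when fewer than idx + 1 entries exist. The same predicate on plain data (values are placeholders):

    data = {"corporacion": [[0.9], [0.8]]}  # two completed runs (idx 0 and 1)
    idx = 2
    needs_test = "corporacion" not in data or idx > len(data["corporacion"]) - 1
    print(needs_test)  # True: run idx 2 has not been recorded yet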


@@ -53,7 +50,7 @@ def process_single_aug(user, idx, scores, recorders, root_path):
 
         for method_name, scores in zip(["select_score", "mean_score", "oracle_score"],
                                        [select_scores, mean_scores, oracle_scores]):
-            recorders[method_name].record(user, idx, scores)
+            recorders[method_name].record(user, scores)
             save_path = os.path.join(root_path, f"{method_name}_performance.json")
             recorders[method_name].save(save_path)
     except Exception as e:
@@ -82,33 +79,43 @@ def analyze_performance(user, recorders):
 
 def plot_performance_curves(user, recorders, task, n_labeled_list):
     plt.figure(figsize=(10, 6))
+    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
     for method, recorder in recorders.items():
         if method == "hetero_single_aug":
             continue
-        user_data = recorder.get_performance_data(user)
-        if user_data:
-            scores_array = np.array([np.array(lst) for lst in user_data])
-            mean_scores = np.squeeze(np.mean(scores_array, axis=0))
-            std_scores = np.squeeze(np.std(scores_array, axis=0))
+        scores_array = recorder.get_performance_data(user)
+        if scores_array:
+            mean_curve, std_curve = [], []
+            for i in range(len(n_labeled_list)):
+                sub_scores_array = np.vstack([lst[i] for lst in scores_array])
+                sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
+                mean_curve.append(np.mean(sub_scores_mean))
+                std_curve.append(np.std(sub_scores_mean))
+            mean_curve = np.array(mean_curve)
+            std_curve = np.array(std_curve)
 
             method_plot = '_'.join(method.split('_')[1:]) if method not in ['user_model', 'oracle_score', 'select_score', 'mean_score'] else method
             style = styles.get(method_plot, {"color": "black", "linestyle": "-"})
-            plt.plot(range(len(n_labeled_list)), mean_scores, label=labels.get(method_plot), **style)
-
-            std_scale = 0.2 if task == "Hetero" else 0.5
-            plt.fill_between(range(len(n_labeled_list)), mean_scores - std_scale * std_scores, mean_scores + std_scale * std_scores, color=style["color"], alpha=0.2)
-
-    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
-    plt.xlabel('Sample Size')
-    plt.ylabel('RMSE')
-    plt.title(f'Table {task} Limited Labeled Data')
-    plt.legend()
+            plt.plot(mean_curve, label=labels.get(method_plot), **style)
+
+            plt.fill_between(
+                range(len(mean_curve)),
+                mean_curve - std_curve,
+                mean_curve + std_curve,
+                color=style["color"],
+                alpha=0.2
+            )
 
+    plt.xlabel("Amount of Labeled User Data", fontsize=14)
+    plt.ylabel("RMSE", fontsize=14)
+    plt.title(f"Results on Homo Table Experimental Scenario", fontsize=16)
+    plt.legend(fontsize=14)
+    plt.tight_layout()
     root_path = os.path.abspath(os.path.join(__file__, ".."))
     fig_path = os.path.join(root_path, "results", "figs")
     os.makedirs(fig_path, exist_ok=True)
-    plt.savefig(os.path.join(fig_path, f"{user}_labeled_{list(recorders.keys())}.svg"), bbox_inches="tight", dpi=700)
+    plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)
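
Aside: the new aggregation loops over sample sizes first, stacking the i-th score list of every recorded run before averaging, so the fill_between band shows the spread of run-averaged scores rather than a scaled global std. A self-contained sketch with assumed shapes (two runs, two sample sizes, two repeats each; numbers invented):

    import numpy as np

    scores_array = [
        [[0.92, 0.88], [0.81, 0.79]],  # run 0: one score list per n_labeled setting
        [[0.90, 0.87], [0.80, 0.78]],  # run 1
    ]
    for i in range(2):                                     # iterate sample sizes
        sub = np.vstack([run[i] for run in scores_array])  # shape: runs x repeats
        repeat_means = np.squeeze(np.mean(sub, axis=0))    # average over runs
        print(np.mean(repeat_means), np.std(repeat_means)) # curve point, band width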
