You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import os
  2. import time
  3. import pandas
  4. import random
  5. import tempfile
  6. import numpy as np
  7. from learnware.client import LearnwareClient
  8. from learnware.logger import get_module_logger
  9. from learnware.market import instantiate_learnware_market
  10. from learnware.tests.benchmarks import LearnwareBenchmark
  11. from config import *
  12. from methods import *
  13. from utils import process_single_aug
  # Module-wide logger; TableWorkflow reports progress through it.
  logger = get_module_logger(
      "base_table",
      level="INFO",
  )
  class TableWorkflow:
      """Benchmark workflow for tabular learnware experiments.

      Construction creates ``results/`` and ``curves/`` directories in the
      directory containing this file and loads (or rebuilds) the learnware
      market for the configured benchmark.
      """

      def __init__(self, benchmark_config, name="easy", rebuild=False):
          """Create output directories and prepare the learnware market.

          Args:
              benchmark_config: Passed to ``LearnwareBenchmark().get_benchmark``.
              name: Market name passed to ``instantiate_learnware_market``.
              rebuild: If truthy, the market is rebuilt and every benchmark
                  learnware is downloaded and added again.
          """
          # abspath(join(file, "..")) resolves to the directory holding this file.
          self.root_path = os.path.abspath(os.path.join(__file__, ".."))
          self.result_path = os.path.join(self.root_path, "results")
          self.curves_result_path = os.path.join(self.root_path, "curves")
          os.makedirs(self.result_path, exist_ok=True)
          os.makedirs(self.curves_result_path, exist_ok=True)
          self._prepare_market(benchmark_config, name, rebuild)

      @staticmethod
      def _limited_data(method, test_info, loss_func):
          """Evaluate ``method`` on every sampled labeled training set.

          Args:
              method: Callable ``method(x_train, y_train, test_info)`` returning
                  an object with a ``predict`` method.
              test_info: Dict with ``"train_subsets"`` (as produced by
                  ``get_train_subsets``), ``"test_x"`` and ``"test_y"``.
              loss_func: Callable ``loss_func(predictions, targets)``.

          Returns:
              A list with one inner list per subset group, holding one loss
              value per sampled training set in that group.
          """
          all_scores = []
          for subset in test_info["train_subsets"]:
              subset_scores = []
              for sample in subset:
                  x_train, y_train = sample["x_train"], sample["y_train"]
                  model = method(x_train, y_train, test_info)
                  subset_scores.append(loss_func(model.predict(test_info["test_x"]), test_info["test_y"]))
              all_scores.append(subset_scores)
          return all_scores

      @staticmethod
      def get_train_subsets(idx, train_x, train_y, n_labeled=None, n_repeat=None):
          """Draw reproducible labeled training subsets of several sizes.

          For each ``(n_label, repeated)`` pair, ``repeated`` subsets of
          ``min(n_label, len(train_x))`` rows are sampled without replacement.

          Args:
              idx: Seed applied to both ``random`` and ``np.random``. The global
                  generators are reseeded, as in the original implementation.
              train_x: Row-indexable features (converted with ``np.asarray``).
              train_y: Labels aligned row-by-row with ``train_x``.
              n_labeled: Subset sizes; ``None`` uses ``n_labeled_list``.
              n_repeat: Repetitions per size; ``None`` uses ``n_repeat_list``.

          Returns:
              One list per subset size, each holding dicts with ``"x_train"``
              and ``"y_train"`` arrays. A size of 0 (e.g. empty training data)
              yields empty arrays instead of raising.
          """
          if n_labeled is None:
              n_labeled = n_labeled_list
          if n_repeat is None:
              n_repeat = n_repeat_list
          np.random.seed(idx)
          random.seed(idx)

          x_arr = np.asarray(train_x)
          y_arr = np.asarray(train_y)
          n_rows = len(train_x)

          train_subsets = []
          for n_label, repeated in zip(n_labeled, n_repeat):
              k = min(n_label, n_rows)
              group = []
              for _ in range(repeated):
                  # random.sample picks positions independently of the element
                  # values, so sampling range(n_rows) selects exactly the rows the
                  # old zip-of-pairs sampling did, without rebuilding that list.
                  picked = np.array(random.sample(range(n_rows), k=k), dtype=np.intp)
                  group.append({"x_train": x_arr[picked], "y_train": y_arr[picked]})
              train_subsets.append(group)
          return train_subsets

      def _prepare_market(self, benchmark_config, name, rebuild, max_retries=20, retry_delay=1):
          """Load the benchmark and its market, populating the market if needed.

          Sets ``self.benchmark``, ``self.market`` and ``self.user_semantic``
          (the semantic specification of the first benchmark learnware with
          its ``["Name"]["Values"]`` blanked). When the market is empty or
          ``rebuild`` is truthy, each benchmark learnware is downloaded and
          added, retried up to ``max_retries`` times with ``retry_delay``
          seconds between attempts. Learnwares that still fail are logged and
          skipped.
          """
          client = LearnwareClient()
          self.benchmark = LearnwareBenchmark().get_benchmark(benchmark_config)
          self.market = instantiate_learnware_market(market_id=self.benchmark.name, name=name, rebuild=rebuild)
          self.user_semantic = client.get_semantic_specification(self.benchmark.learnware_ids[0])
          self.user_semantic["Name"]["Values"] = ""

          if len(self.market) == 0 or rebuild:
              for learnware_id in self.benchmark.learnware_ids:
                  self._add_learnware_with_retry(client, learnware_id, max_retries, retry_delay)

      def _add_learnware_with_retry(self, client, learnware_id, max_retries, retry_delay):
          """Download one learnware into a temp dir and add it to ``self.market``.

          Returns:
              True if the learnware was added, False once all attempts failed.
          """
          with tempfile.TemporaryDirectory(prefix="table_benchmark_") as tempdir:
              zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
              for attempt in range(1, max_retries + 1):
                  try:
                      semantic_spec = client.get_semantic_specification(learnware_id)
                      client.download_learnware(learnware_id, zip_path)
                      self.market.add_learnware(zip_path, semantic_spec)
                      return True
                  except Exception as err:  # client/market errors are not typed here
                      logger.warning(
                          f"Attempt {attempt}/{max_retries} to add learnware {learnware_id} failed: {err}"
                      )
                      if attempt < max_retries:
                          time.sleep(retry_delay)
          logger.error(f"Giving up on learnware {learnware_id} after {max_retries} attempts")
          return False

      def test_method(self, test_info, recorders, loss_func=None):
          """Run one method on one user split, or reuse its recorded results.

          Args:
              test_info: Dict with at least ``"method_name"``, ``"user"``,
                  ``"idx"`` and ``"force"``, plus whatever ``_limited_data`` and
                  the method read (``"learnwares"`` for ``"hetero_single_aug"``).
                  For ``"hetero_single_aug"`` its ``"single_learnware"`` entry is
                  overwritten on each iteration.
              recorders: Mapping from full method name to a recorder object.
              loss_func: Loss passed to ``_limited_data``; ``None`` (default)
                  selects ``loss_func_rmse``, looked up at call time.
          """
          if loss_func is None:
              loss_func = loss_func_rmse
          method_name_full = test_info["method_name"]
          # Drop the leading prefix (e.g. "hetero_single_aug" -> "single_aug"),
          # except for the "user_model" baseline.
          method_name = method_name_full if method_name_full == "user_model" else "_".join(method_name_full.split("_")[1:])
          user, idx = test_info["user"], test_info["idx"]
          recorder = recorders[method_name_full]

          save_root_path = os.path.join(self.curves_result_path, user, f"{user}_{idx}")
          os.makedirs(save_root_path, exist_ok=True)
          save_path = os.path.join(save_root_path, f"{method_name}.json")

          if method_name_full == "hetero_single_aug":
              if test_info["force"] or recorder.should_test_method(user, idx, save_path):
                  for learnware in test_info["learnwares"]:
                      test_info["single_learnware"] = [learnware]
                      scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
                      recorder.record(user, scores)
                  # NOTE(review): only the last learnware's scores are passed here,
                  # while the cached branch below passes recorder.data[user];
                  # confirm which form process_single_aug expects.
                  process_single_aug(user, idx, scores, recorders, save_root_path)
                  recorder.save(save_path)
              else:
                  process_single_aug(user, idx, recorder.data[user], recorders, save_root_path)
          else:
              if test_info["force"] or recorder.should_test_method(user, idx, save_path):
                  scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
                  recorder.record(user, scores)
                  recorder.save(save_path)

          logger.info(f"Method {method_name} on {user}_{idx} finished")