
base.py 5.1 kB

import os
import random
import tempfile
import time
import traceback

import numpy as np
import requests
from config import market_mapping_params
from methods import loss_func_rmse, test_methods
from utils import set_seed

from learnware.client import LearnwareClient
from learnware.logger import get_module_logger
from learnware.market import instantiate_learnware_market
from learnware.reuse.utils import fill_data_with_mean
from learnware.tests.benchmarks import LearnwareBenchmark

logger = get_module_logger("base_table", level="INFO")


class TableWorkflow:
    def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False):
        self.root_path = os.path.abspath(os.path.join(__file__, ".."))
        self.result_path = os.path.join(self.root_path, "results")
        self.curves_result_path = os.path.join(self.result_path, "curves")
        os.makedirs(self.result_path, exist_ok=True)
        os.makedirs(self.curves_result_path, exist_ok=True)
        self._prepare_market(benchmark_config, name, rebuild, retrain)

    @staticmethod
    def _limited_data(method, test_info, loss_func):
        def subset_generator():
            for subset in test_info["train_subsets"]:
                yield subset

        all_scores = []
        for subset in subset_generator():
            subset_scores = []
            for sample in subset:
                x_train, y_train = sample["x_train"], sample["y_train"]
                model = method(x_train, y_train, test_info)
                subset_scores.append(loss_func(model.predict(test_info["test_x"]), test_info["test_y"]))
            all_scores.append(subset_scores)
        return all_scores

    @staticmethod
    def get_train_subsets(n_labeled_list, n_repeat_list, train_x, train_y):
        np.random.seed(1)
        random.seed(1)
        train_x = fill_data_with_mean(train_x)
        train_subsets = []
        for n_label, repeated in zip(n_labeled_list, n_repeat_list):
            train_subsets.append([])
            if n_label > len(train_x):
                n_label = len(train_x)
            for _ in range(repeated):
                subset_idxs = np.random.choice(len(train_x), n_label, replace=False)
                train_subsets[-1].append(
                    {"x_train": np.array(train_x[subset_idxs]), "y_train": np.array(train_y[subset_idxs])}
                )
        return train_subsets

    def _prepare_market(self, benchmark_config, name, rebuild, retrain):
        client = LearnwareClient()
        self.benchmark = LearnwareBenchmark().get_benchmark(benchmark_config)
        self.market = instantiate_learnware_market(
            market_id=self.benchmark.name,
            name=name,
            rebuild=rebuild,
            organizer_kwargs={
                "auto_update": True,
                "auto_update_limit": len(self.benchmark.learnware_ids),
                **market_mapping_params,
            }
            if retrain
            else None,
        )
        self.user_semantic = client.get_semantic_specification(self.benchmark.learnware_ids[0])
        self.user_semantic["Name"]["Values"] = ""

        if len(self.market) == 0 or rebuild is True:
            if retrain:
                set_seed(0)
            for learnware_id in self.benchmark.learnware_ids:
                with tempfile.TemporaryDirectory(prefix="table_benchmark_") as tempdir:
                    zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
                    for i in range(20):
                        try:
                            semantic_spec = client.get_semantic_specification(learnware_id)
                            client.download_learnware(learnware_id, zip_path)
                            self.market.add_learnware(zip_path, semantic_spec)
                            break
                        except (requests.exceptions.RequestException, IOError, Exception) as e:
                            logger.info(
                                f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying..."
                            )
                            time.sleep(1)
                            continue

    def test_method(self, test_info, recorders, loss_func=loss_func_rmse):
        method_name_full = test_info["method_name"]
        method_name = (
            method_name_full if method_name_full == "user_model" else "_".join(method_name_full.split("_")[1:])
        )
        method = test_methods[method_name_full]
        user, idx = test_info["user"], test_info["idx"]
        recorder = recorders[method_name_full]
        save_root_path = os.path.join(self.curves_result_path, f"{user}/{user}_{idx}")
        os.makedirs(save_root_path, exist_ok=True)
        save_path = os.path.join(save_root_path, f"{method_name}.json")

        if recorder.should_test_method(user, idx, save_path):
            scores = self._limited_data(method, test_info, loss_func)
            recorder.record(user, scores)
            recorder.save(save_path)
            logger.info(f"Method {method_name} on {user}_{idx} finished")
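
For reference, a minimal usage sketch of the data-splitting helper above (not part of base.py). It assumes base.py is importable together with the repository's local config, methods, and utils modules, and it uses made-up synthetic arrays; only the static get_train_subsets method is exercised, so no learnware market is built and no network access is needed.

import numpy as np
from base import TableWorkflow

# Hypothetical tabular training data: 100 rows, 8 features.
train_x = np.random.rand(100, 8)
train_y = np.random.rand(100)

# Two labeled-subset sizes (10 and 50 rows), each drawn 3 times.
train_subsets = TableWorkflow.get_train_subsets(
    n_labeled_list=[10, 50],
    n_repeat_list=[3, 3],
    train_x=train_x,
    train_y=train_y,
)
print(len(train_subsets), len(train_subsets[0]))  # 2 size groups, 3 repeated draws each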