|
|
|
@@ -2,6 +2,7 @@ import os |
|
|
|
import pickle |
|
|
|
import tempfile |
|
|
|
import zipfile |
|
|
|
import numpy as np |
|
|
|
from dataclasses import dataclass |
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
|
|
@@ -21,7 +22,9 @@ class Benchmark: |
|
|
|
train_y_paths: Optional[List[str]] = None |
|
|
|
extra_info_path: Optional[str] = None |
|
|
|
|
|
|
|
def get_test_data(self, user_ids: Union[int, List[int]]): |
|
|
|
def get_test_data( |
|
|
|
self, user_ids: Union[int, List[int]] |
|
|
|
) -> Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]: |
|
|
|
raw_user_ids = user_ids |
|
|
|
if isinstance(user_ids, int): |
|
|
|
user_ids = [user_ids] |
|
|
|
@@ -41,7 +44,9 @@ class Benchmark: |
|
|
|
else: |
|
|
|
return ret |
|
|
|
|
|
|
|
def get_train_data(self, user_ids: Union[int, List[int]]): |
|
|
|
def get_train_data( |
|
|
|
self, user_ids: Union[int, List[int]] |
|
|
|
) -> Optional[Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]]: |
|
|
|
if self.train_X_paths is None or self.train_y_paths is None: |
|
|
|
return None |
|
|
|
|
|
|
|
@@ -98,7 +103,7 @@ class LearnwareBenchmark: |
|
|
|
else: |
|
|
|
return False |
|
|
|
|
|
|
|
def _download_data(self, download_path: str, save_path: str): |
|
|
|
def _download_data(self, download_path: str, save_path: str) -> None: |
|
|
|
"""Download data from backend |
|
|
|
|
|
|
|
Parameters |
|
|
|
@@ -128,7 +133,7 @@ class LearnwareBenchmark: |
|
|
|
""" |
|
|
|
cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") |
|
|
|
if not self._check_cache_data_valid(benchmark_config, data_type): |
|
|
|
download_path = getattr(benchmark_config, f"{data_type}_data_path", None) |
|
|
|
download_path = getattr(benchmark_config, f"{data_type}_data_path") |
|
|
|
self._download_data(download_path, cache_folder) |
|
|
|
|
|
|
|
X_paths, y_paths = [], [] |
|
|
|
@@ -142,10 +147,15 @@ class LearnwareBenchmark: |
|
|
|
|
|
|
|
return X_paths, y_paths |
|
|
|
|
|
|
|
def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]): |
|
|
|
def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]) -> Benchmark: |
|
|
|
if isinstance(benchmark_config, str): |
|
|
|
benchmark_config = self.benchmark_configs[benchmark_config] |
|
|
|
|
|
|
|
if not isinstance(benchmark_config, BenchmarkConfig): |
|
|
|
raise ValueError( |
|
|
|
"benchmark_config must be a BenchmarkConfig object or a string in benchmark_configs.keys()!" |
|
|
|
) |
|
|
|
|
|
|
|
# Load test data |
|
|
|
test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test") |
|
|
|
|
|
|
|
|