Browse Source

[MNT] enable benchmarks pass the mypy test

tags/v0.3.2
Gene 1 year ago
parent
commit
cd92ac2922
2 changed files with 17 additions and 7 deletions
  1. +15
    -5
      learnware/tests/benchmarks/__init__.py
  2. +2
    -2
      learnware/tests/benchmarks/config.py

+ 15
- 5
learnware/tests/benchmarks/__init__.py View File

@@ -2,6 +2,7 @@ import os
import pickle
import tempfile
import zipfile
import numpy as np
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

@@ -21,7 +22,9 @@ class Benchmark:
train_y_paths: Optional[List[str]] = None
extra_info_path: Optional[str] = None

def get_test_data(self, user_ids: Union[int, List[int]]):
def get_test_data(
self, user_ids: Union[int, List[int]]
) -> Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]:
raw_user_ids = user_ids
if isinstance(user_ids, int):
user_ids = [user_ids]
@@ -41,7 +44,9 @@ class Benchmark:
else:
return ret

def get_train_data(self, user_ids: Union[int, List[int]]):
def get_train_data(
self, user_ids: Union[int, List[int]]
) -> Optional[Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]]:
if self.train_X_paths is None or self.train_y_paths is None:
return None

@@ -98,7 +103,7 @@ class LearnwareBenchmark:
else:
return False

def _download_data(self, download_path: str, save_path: str):
def _download_data(self, download_path: str, save_path: str) -> None:
"""Download data from backend

Parameters
@@ -128,7 +133,7 @@ class LearnwareBenchmark:
"""
cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data")
if not self._check_cache_data_valid(benchmark_config, data_type):
download_path = getattr(benchmark_config, f"{data_type}_data_path", None)
download_path = getattr(benchmark_config, f"{data_type}_data_path")
self._download_data(download_path, cache_folder)

X_paths, y_paths = [], []
@@ -142,10 +147,15 @@ class LearnwareBenchmark:

return X_paths, y_paths

def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]):
def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]) -> Benchmark:
if isinstance(benchmark_config, str):
benchmark_config = self.benchmark_configs[benchmark_config]

if not isinstance(benchmark_config, BenchmarkConfig):
raise ValueError(
"benchmark_config must be a BenchmarkConfig object or a string in benchmark_configs.keys()!"
)

# Load test data
test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test")



+ 2
- 2
learnware/tests/benchmarks/config.py View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Dict


@dataclass
@@ -12,4 +12,4 @@ class BenchmarkConfig:
extra_info_path: Optional[str] = None


benchmark_configs = {}
benchmark_configs: Dict[str, BenchmarkConfig] = {}

Loading…
Cancel
Save