From cd92ac2922db800014734caf1d13379b0aac4705 Mon Sep 17 00:00:00 2001 From: Gene Date: Fri, 12 Jan 2024 14:53:31 +0800 Subject: [PATCH] [MNT] enable benchmarks pass the mypy test --- learnware/tests/benchmarks/__init__.py | 20 +++++++++++++++----- learnware/tests/benchmarks/config.py | 4 ++-- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/learnware/tests/benchmarks/__init__.py b/learnware/tests/benchmarks/__init__.py index 523a910..05609f9 100644 --- a/learnware/tests/benchmarks/__init__.py +++ b/learnware/tests/benchmarks/__init__.py @@ -2,6 +2,7 @@ import os import pickle import tempfile import zipfile +import numpy as np from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -21,7 +22,9 @@ class Benchmark: train_y_paths: Optional[List[str]] = None extra_info_path: Optional[str] = None - def get_test_data(self, user_ids: Union[int, List[int]]): + def get_test_data( + self, user_ids: Union[int, List[int]] + ) -> Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]: raw_user_ids = user_ids if isinstance(user_ids, int): user_ids = [user_ids] @@ -41,7 +44,9 @@ class Benchmark: else: return ret - def get_train_data(self, user_ids: Union[int, List[int]]): + def get_train_data( + self, user_ids: Union[int, List[int]] + ) -> Optional[Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]]: if self.train_X_paths is None or self.train_y_paths is None: return None @@ -98,7 +103,7 @@ class LearnwareBenchmark: else: return False - def _download_data(self, download_path: str, save_path: str): + def _download_data(self, download_path: str, save_path: str) -> None: """Download data from backend Parameters @@ -128,7 +133,7 @@ class LearnwareBenchmark: """ cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") if not self._check_cache_data_valid(benchmark_config, data_type): - download_path = getattr(benchmark_config, f"{data_type}_data_path", None) + download_path = getattr(benchmark_config, f"{data_type}_data_path") self._download_data(download_path, cache_folder) X_paths, y_paths = [], [] @@ -142,10 +147,15 @@ class LearnwareBenchmark: return X_paths, y_paths - def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]): + def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]) -> Benchmark: if isinstance(benchmark_config, str): benchmark_config = self.benchmark_configs[benchmark_config] + if not isinstance(benchmark_config, BenchmarkConfig): + raise ValueError( + "benchmark_config must be a BenchmarkConfig object or a string in benchmark_configs.keys()!" + ) + # Load test data test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test") diff --git a/learnware/tests/benchmarks/config.py b/learnware/tests/benchmarks/config.py index e595fd3..24a8dfd 100644 --- a/learnware/tests/benchmarks/config.py +++ b/learnware/tests/benchmarks/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Dict @dataclass @@ -12,4 +12,4 @@ class BenchmarkConfig: extra_info_path: Optional[str] = None -benchmark_configs = {} +benchmark_configs: Dict[str, BenchmarkConfig] = {}