Browse Source

[MNT] enable benchmarks pass the mypy test

tags/v0.3.2
Gene 1 year ago
parent
commit
cd92ac2922
2 changed files with 17 additions and 7 deletions
  1. +15
    -5
      learnware/tests/benchmarks/__init__.py
  2. +2
    -2
      learnware/tests/benchmarks/config.py

+ 15
- 5
learnware/tests/benchmarks/__init__.py View File

@@ -2,6 +2,7 @@ import os
import pickle import pickle
import tempfile import tempfile
import zipfile import zipfile
import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union


@@ -21,7 +22,9 @@ class Benchmark:
train_y_paths: Optional[List[str]] = None train_y_paths: Optional[List[str]] = None
extra_info_path: Optional[str] = None extra_info_path: Optional[str] = None


def get_test_data(self, user_ids: Union[int, List[int]]):
def get_test_data(
self, user_ids: Union[int, List[int]]
) -> Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]:
raw_user_ids = user_ids raw_user_ids = user_ids
if isinstance(user_ids, int): if isinstance(user_ids, int):
user_ids = [user_ids] user_ids = [user_ids]
@@ -41,7 +44,9 @@ class Benchmark:
else: else:
return ret return ret


def get_train_data(self, user_ids: Union[int, List[int]]):
def get_train_data(
self, user_ids: Union[int, List[int]]
) -> Optional[Union[Tuple[np.ndarray, np.ndarray], List[Tuple[np.ndarray, np.ndarray]]]]:
if self.train_X_paths is None or self.train_y_paths is None: if self.train_X_paths is None or self.train_y_paths is None:
return None return None


@@ -98,7 +103,7 @@ class LearnwareBenchmark:
else: else:
return False return False


def _download_data(self, download_path: str, save_path: str):
def _download_data(self, download_path: str, save_path: str) -> None:
"""Download data from backend """Download data from backend


Parameters Parameters
@@ -128,7 +133,7 @@ class LearnwareBenchmark:
""" """
cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data")
if not self._check_cache_data_valid(benchmark_config, data_type): if not self._check_cache_data_valid(benchmark_config, data_type):
download_path = getattr(benchmark_config, f"{data_type}_data_path", None)
download_path = getattr(benchmark_config, f"{data_type}_data_path")
self._download_data(download_path, cache_folder) self._download_data(download_path, cache_folder)


X_paths, y_paths = [], [] X_paths, y_paths = [], []
@@ -142,10 +147,15 @@ class LearnwareBenchmark:


return X_paths, y_paths return X_paths, y_paths


def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]):
def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]) -> Benchmark:
if isinstance(benchmark_config, str): if isinstance(benchmark_config, str):
benchmark_config = self.benchmark_configs[benchmark_config] benchmark_config = self.benchmark_configs[benchmark_config]


if not isinstance(benchmark_config, BenchmarkConfig):
raise ValueError(
"benchmark_config must be a BenchmarkConfig object or a string in benchmark_configs.keys()!"
)

# Load test data # Load test data
test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test") test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test")




+ 2
- 2
learnware/tests/benchmarks/config.py View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Dict




@dataclass @dataclass
@@ -12,4 +12,4 @@ class BenchmarkConfig:
extra_info_path: Optional[str] = None extra_info_path: Optional[str] = None




benchmark_configs = {}
benchmark_configs: Dict[str, BenchmarkConfig] = {}

Loading…
Cancel
Save