Browse Source

[MNT] now the test_search works well

tags/v0.3.2
bxdd 1 year ago
parent
commit
6e2fc53664
8 changed files with 110 additions and 14 deletions
  1. +9
    -2
      learnware/client/learnware_client.py
  2. +1
    -1
      learnware/market/base.py
  3. +1
    -1
      learnware/market/heterogeneous/organizer/__init__.py
  4. +9
    -6
      learnware/market/heterogeneous/utils.py
  5. +6
    -4
      learnware/market/module.py
  6. +84
    -0
      tests/test_procedure/test_search.py
  7. +0
    -0
      tests/test_search_learnware/test_search_image.py
  8. +0
    -0
      tests/test_search_learnware/test_search_text.py

+ 9
- 2
learnware/client/learnware_client.py View File

@@ -70,7 +70,14 @@ class LearnwareClient:
self.tempdir_list = []
self.login_status = False
atexit.register(self.cleanup)

def is_connected(self):
url = f"{self.host}/auth/login_by_token"
response = requests.post(url)
if response.status_code == 404:
return False
return True
def login(self, email, token):
url = f"{self.host}/auth/login_by_token"

@@ -172,7 +179,7 @@ class LearnwareClient:
if result["code"] != 0:
raise Exception("update failed: " + json.dumps(result))

def download_learnware(self, learnware_id, save_path):
def download_learnware(self, learnware_id: str, save_path: str):
url = f"{self.host}/engine/download_learnware"

response = requests.get(


+ 1
- 1
learnware/market/base.py View File

@@ -132,7 +132,7 @@ class LearnwareMarket:
def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool:
try:
final_status = BaseChecker.NONUSABLE_LEARNWARE
if len(checker_names):
if checker_names is not None and len(checker_names):
with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir:
with zipfile.ZipFile(zip_path, mode="r") as z_file:
z_file.extractall(tempdir)


+ 1
- 1
learnware/market/heterogeneous/organizer/__init__.py View File

@@ -245,7 +245,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
ret = []
for idx in ids:
spec = self.learnware_list[idx].get_specification()
if is_hetero(stat_specs=spec.get_stat_spec(), semantic_spec=spec.get_semantic_spec()):
if is_hetero(stat_specs=spec.get_stat_spec(), semantic_spec=spec.get_semantic_spec(), verbose=False):
ret.append(idx)
return ret



+ 9
- 6
learnware/market/heterogeneous/utils.py View File

@@ -1,9 +1,10 @@
import traceback
from ...logger import get_module_logger

logger = get_module_logger("hetero_utils")


def is_hetero(stat_specs: dict, semantic_spec: dict) -> bool:
def is_hetero(stat_specs: dict, semantic_spec: dict, verbose=True) -> bool:
"""Check if user_info satifies all the criteria required for enabling heterogeneous learnware search

Parameters
@@ -35,15 +36,17 @@ def is_hetero(stat_specs: dict, semantic_spec: dict) -> bool:
semantic_decription_feature_num = len(semantic_input_description["Description"])

if semantic_decription_feature_num <= 0:
logger.warning("At least one of Input.Description in semantic spec should be provides.")
if verbose:
logger.warning("At least one of Input.Description in semantic spec should be provides.")
return False

if table_input_shape != semantic_description_dim:
logger.warning("User data feature dimensions mismatch with semantic specification.")
if verbose:
logger.warning("User data feature dimensions mismatch with semantic specification.")
return False

return True
except Exception as e:
logger.warning(f"Invalid heterogeneous search information provided due to {e}. Use homogeneous search instead.")
except Exception as err:
if verbose:
logger.warning(f"Invalid heterogeneous search information provided.")
return False

+ 6
- 4
learnware/market/module.py View File

@@ -1,9 +1,10 @@
from .base import LearnwareMarket
from .classes import CondaChecker
from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker
from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher


def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None):
def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False):
organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs
searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs
checker_kwargs = {} if checker_kwargs is None else checker_kwargs
@@ -11,7 +12,7 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search
if name == "easy":
easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild)
easy_searcher = EasySearcher(organizer=easy_organizer)
easy_checker_list = [EasySemanticChecker(), EasyStatChecker()]
easy_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())]
market_component = {
"organizer": easy_organizer,
"searcher": easy_searcher,
@@ -20,7 +21,7 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search
elif name == "hetero":
hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs)
hetero_searcher = HeteroSearcher(organizer=hetero_organizer)
hetero_checker_list = [EasySemanticChecker(), EasyStatChecker()]
hetero_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())]

market_component = {
"organizer": hetero_organizer,
@@ -40,9 +41,10 @@ def instantiate_learnware_market(
organizer_kwargs: dict = None,
searcher_kwargs: dict = None,
checker_kwargs: dict = None,
conda_checker: bool = False,
**kwargs,
):
market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs)
market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker)
return LearnwareMarket(
organizer=market_componets["organizer"],
searcher=market_componets["searcher"],


+ 84
- 0
tests/test_procedure/test_search.py View File

@@ -0,0 +1,84 @@
import os
import unittest
import tempfile
import logging

import learnware
learnware.init(logging_level=logging.WARNING)

from learnware.learnware import Learnware
from learnware.client import LearnwareClient
from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker
from learnware.config import C

class TestSearch(unittest.TestCase):
client = LearnwareClient()
@classmethod
def setUpClass(cls):
cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True)
if cls.client.is_connected():
cls._build_learnware_market()

@classmethod
def _build_learnware_market(cls):
table_learnware_ids = ["00001951", "00001980", "00001987"]
image_learnware_ids = ["00000851", "00000858", "00000841"]
text_learnware_ids = ["00000652", "00000637"]
learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids
with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir:
for learnware_id in learnware_ids:
learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip")
try:
cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath)
semantic_spec = cls.client.load_learnware(learnware_path=learnware_zippath).get_specification().get_semantic_spec()
except Exception:
print("'learnware_id' is passed due to the network problem.")
cls.market.add_learnware(learnware_zippath, learnware_id=learnware_id, semantic_spec=semantic_spec, checker_names=["EasySemanticChecker"])
@unittest.skipIf(not client.is_connected(), "Client can not connect!")
def test_image_search(self):
learnware_id = "00000619"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_image_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())
@unittest.skipIf(not client.is_connected(), "Client can not connect!")
def test_text_search(self):
learnware_id = "00000653"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_text_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())
@unittest.skipIf(not client.is_connected(), "Client can not connect!")
def test_table_search(self):
learnware_id = "00001950"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_table_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())

def suite():
_suite = unittest.TestSuite()
_suite.addTest(TestSearch("test_image_search"))
_suite.addTest(TestSearch("test_text_search"))
_suite.addTest(TestSearch("test_table_search"))
return _suite

if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())

+ 0
- 0
tests/test_search_learnware/test_search_image.py View File


+ 0
- 0
tests/test_search_learnware/test_search_text.py View File


Loading…
Cancel
Save