Browse Source

[FIX] remove @unittest.skipIf

tags/v0.3.2
Gene 1 year ago
parent
commit
370b497594
3 changed files with 138 additions and 104 deletions
  1. +56
    -37
      tests/test_function/test_search.py
  2. +38
    -30
      tests/test_learnware_client/test_all_learnware.py
  3. +44
    -37
      tests/test_learnware_client/test_upload.py

+ 56
- 37
tests/test_function/test_search.py View File

@@ -4,6 +4,7 @@ import tempfile
import logging

import learnware

learnware.init(logging_level=logging.WARNING)

from learnware.learnware import Learnware
@@ -11,9 +12,10 @@ from learnware.client import LearnwareClient
from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker
from learnware.config import C


class TestSearch(unittest.TestCase):
client = LearnwareClient()
@classmethod
def setUpClass(cls):
cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True)
@@ -31,46 +33,62 @@ class TestSearch(unittest.TestCase):
learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip")
try:
cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath)
semantic_spec = cls.client.load_learnware(learnware_path=learnware_zippath).get_specification().get_semantic_spec()
semantic_spec = (
cls.client.load_learnware(learnware_path=learnware_zippath)
.get_specification()
.get_semantic_spec()
)
except Exception:
print("'learnware_id' is passed due to the network problem.")
cls.market.add_learnware(learnware_zippath, learnware_id=learnware_id, semantic_spec=semantic_spec, checker_names=["EasySemanticChecker"])
@unittest.skipIf(not client.is_connected(), "Client can not connect!")
cls.market.add_learnware(
learnware_zippath,
learnware_id=learnware_id,
semantic_spec=semantic_spec,
checker_names=["EasySemanticChecker"],
)

def _skip_test(self):
if not self.client.is_connected():
print("Client can not connect!")
return True
return False

def test_image_search(self):
learnware_id = "00000619"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_image_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())
@unittest.skipIf(not client.is_connected(), "Client can not connect!")
if not self._skip_test():
learnware_id = "00000619"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_image_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())
def test_text_search(self):
learnware_id = "00000653"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_text_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())
@unittest.skipIf(not client.is_connected(), "Client can not connect!")
if not self._skip_test():
learnware_id = "00000653"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_text_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())
def test_table_search(self):
learnware_id = "00001950"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_table_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())
if not self._skip_test():
learnware_id = "00001950"
try:
learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
except Exception:
print("'test_table_search' is passed due to the network problem.")
user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
search_result = self.market.search_learnware(user_info)
print("Single Search Results:", search_result.get_single_results())
print("Multiple Search Results:", search_result.get_multiple_results())


def suite():
_suite = unittest.TestSuite()
@@ -79,6 +97,7 @@ def suite():
_suite.addTest(TestSearch("test_table_search"))
return _suite


if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())
runner.run(suite())

+ 38
- 30
tests/test_learnware_client/test_all_learnware.py View File

@@ -9,13 +9,14 @@ from learnware.client import LearnwareClient
from learnware.specification import generate_semantic_spec
from learnware.market import BaseUserInfo


class TestAllLearnware(unittest.TestCase):
client = LearnwareClient()
@classmethod
def setUpClass(cls) -> None:
config_path = os.path.join(os.path.dirname(__file__), "config.json")
if not os.path.exists(config_path):
data = {"email": None, "token": None}
with open(config_path, "w") as file:
@@ -25,40 +26,46 @@ class TestAllLearnware(unittest.TestCase):
data = json.load(file)
email = data.get("email")
token = data.get("token")
if email is None or token is None:
print("Please set email and token in config.json.")
else:
cls.client.login(email, token)

@unittest.skipIf(not client.is_login(), "Client doest not login!")
def _skip_test(self):
if not self.client.is_login():
print("Client does not login!")
return True
return False

def test_all_learnware(self):
max_learnware_num = 2000
semantic_spec = generate_semantic_spec()
user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={})
result = self.client.search_learnware(user_info, page_size=max_learnware_num)
learnware_ids = result["single"]["learnware_ids"]
keys = [key for key in result["single"]["semantic_specifications"][0]]
print(f"result size: {len(learnware_ids)}")
print(f"key in result: {keys}")

failed_ids = []
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
for idx in learnware_ids:
zip_path = os.path.join(tempdir, f"test_{idx}.zip")
self.client.download_learnware(idx, zip_path)
with zipfile.ZipFile(zip_path, "r") as zip_file:
with zip_file.open("semantic_specification.json") as json_file:
semantic_spec = json.load(json_file)
try:
LearnwareClient.check_learnware(zip_path, semantic_spec)
print(f"check learnware {idx} succeed")
except:
failed_ids.append(idx)
print(f"check learnware {idx} failed!!!")

print(f"The currently failed learnware ids: {failed_ids}")
if not self._skip_test():
max_learnware_num = 2000
semantic_spec = generate_semantic_spec()
user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={})
result = self.client.search_learnware(user_info, page_size=max_learnware_num)

learnware_ids = result["single"]["learnware_ids"]
keys = [key for key in result["single"]["semantic_specifications"][0]]
print(f"result size: {len(learnware_ids)}")
print(f"key in result: {keys}")

failed_ids = []
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
for idx in learnware_ids:
zip_path = os.path.join(tempdir, f"test_{idx}.zip")
self.client.download_learnware(idx, zip_path)
with zipfile.ZipFile(zip_path, "r") as zip_file:
with zip_file.open("semantic_specification.json") as json_file:
semantic_spec = json.load(json_file)
try:
LearnwareClient.check_learnware(zip_path, semantic_spec)
print(f"check learnware {idx} succeed")
except:
failed_ids.append(idx)
print(f"check learnware {idx} failed!!!")

print(f"The currently failed learnware ids: {failed_ids}")


def suite():
@@ -66,6 +73,7 @@ def suite():
_suite.addTest(TestAllLearnware("test_all_learnware"))
return _suite


if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())

+ 44
- 37
tests/test_learnware_client/test_upload.py View File

@@ -6,13 +6,14 @@ import tempfile
from learnware.client import LearnwareClient
from learnware.specification import generate_semantic_spec


class TestUpload(unittest.TestCase):
client = LearnwareClient()
@classmethod
def setUpClass(cls) -> None:
config_path = os.path.join(os.path.dirname(__file__), "config.json")
if not os.path.exists(config_path):
data = {"email": None, "token": None}
with open(config_path, "w") as file:
@@ -22,50 +23,55 @@ class TestUpload(unittest.TestCase):
data = json.load(file)
email = data.get("email")
token = data.get("token")
if email is None or token is None:
print("Please set email and token in config.json.")
else:
cls.client.login(email, token)

@unittest.skipIf(not client.is_login(), "Client doest not login!")
def _skip_test(self):
if not self.client.is_login():
print("Client does not login!")
return True
return False

def test_upload(self):
input_description = {
"Dimension": 13,
"Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"},
}
output_description = {
"Dimension": 1,
"Description": {
"0": "the probability of being a cat",
},
}
semantic_spec = generate_semantic_spec(
name="learnware_example",
description="Just a example for uploading a learnware",
data_type="Table",
task_type="Classification",
library_type="Scikit-learn",
scenarios=["Business", "Financial"],
input_description=input_description,
output_description=output_description,
)
assert isinstance(semantic_spec, dict)

download_learnware_id = "00000084"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
zip_path = os.path.join(tempdir, f"test.zip")
self.client.download_learnware(download_learnware_id, zip_path)
learnware_id = self.client.upload_learnware(
learnware_zip_path=zip_path, semantic_specification=semantic_spec
if not self._skip_test():
input_description = {
"Dimension": 13,
"Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"},
}
output_description = {
"Dimension": 2,
"Description": {"0": "cat", "1": "not cat"},
}
semantic_spec = generate_semantic_spec(
name="learnware_example",
description="Just a example for uploading a learnware",
data_type="Table",
task_type="Classification",
library_type="Scikit-learn",
scenarios=["Business", "Financial"],
license="MIT",
input_description=input_description,
output_description=output_description,
)
assert isinstance(semantic_spec, dict)

download_learnware_id = "00000084"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
zip_path = os.path.join(tempdir, f"test.zip")
self.client.download_learnware(download_learnware_id, zip_path)
learnware_id = self.client.upload_learnware(
learnware_zip_path=zip_path, semantic_specification=semantic_spec
)

uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
assert learnware_id in uploaded_ids
uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
assert learnware_id in uploaded_ids

self.client.delete_learnware(learnware_id)
uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
assert learnware_id not in uploaded_ids
self.client.delete_learnware(learnware_id)
uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
assert learnware_id not in uploaded_ids


def suite():
@@ -73,6 +79,7 @@ def suite():
_suite.addTest(TestUpload("test_upload"))
return _suite


if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())

Loading…
Cancel
Save