Browse Source

Merge branch 'main' of https://github.com/Learnware-LAMDA/Learnware into search_result

tags/v0.3.2
bxdd 2 years ago
parent
commit
331aabfe58
4 changed files with 103 additions and 17 deletions
  1. +30
    -3
      learnware/client/learnware_client.py
  2. +40
    -7
      learnware/market/easy/checker.py
  3. +5
    -1
      tests/test_learnware_client/test_all_learnware.py
  4. +28
    -6
      tests/test_learnware_client/test_check_learnware.py

+ 30
- 3
learnware/client/learnware_client.py View File

@@ -140,6 +140,33 @@ class LearnwareClient:

return result["data"]["learnware_id"]

@require_login
def update_learnware(self, learnware_id, semantic_specification, learnware_zip_path=None):
assert self._check_semantic_specification(semantic_specification)[0], "Semantic specification check failed!"

url_update = f"{self.host}/user/update_learnware"
payload = {"learnware_id": learnware_id, "semantic_specification": json.dumps(semantic_specification)}

if learnware_zip_path is None:
response = requests.post(
url_update,
files={"learnware_file": None},
data=payload,
headers=self.headers,
)
else:
response = requests.post(
url_update,
files={"learnware_file": open(learnware_zip_path, "rb")},
data=payload,
headers=self.headers,
)

result = response.json()

if result["code"] != 0:
raise Exception("update failed: " + json.dumps(result))

def download_learnware(self, learnware_id, save_path):
url = f"{self.host}/engine/download_learnware"

@@ -275,8 +302,8 @@ class LearnwareClient:
"Type": "String",
"Values": description if description is not None else "",
}
semantic_specification["Input"] = input_description
semantic_specification["Output"] = output_description
semantic_specification["Input"] = {} if input_description is None else input_description
semantic_specification["Output"] = {} if output_description is None else output_description

return semantic_specification

@@ -351,7 +378,7 @@ class LearnwareClient:
semantic_specification = json.load(fin)

return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir)
learnware_list = []
if learnware_path is not None:
zip_paths = [learnware_path] if isinstance(learnware_path, str) else learnware_path


+ 40
- 7
learnware/market/easy/checker.py View File

@@ -47,6 +47,10 @@ class EasySemanticChecker(BaseChecker):
if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression"]:
assert semantic_spec["Output"] is not None, "Lack of output semantics"
dim = semantic_spec["Output"]["Dimension"]
assert (
dim > 1 or semantic_spec["Task"]["Values"][0] == "Regression"
), "Classification task must have dimension > 1"

for k, v in semantic_spec["Output"]["Description"].items():
assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})"
assert isinstance(v, str), "Description must be string"
@@ -110,6 +114,11 @@ class EasyStatChecker(BaseChecker):

if spec_type == "RKMETableSpecification":
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type)
if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape):
raise ValueError(
f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}"
)

if stat_spec.get_z().shape[1:] != input_shape:
message = f"The learnware [{learnware.id}] input dimension mismatch with stat specification."
logger.warning(message)
@@ -118,6 +127,10 @@ class EasyStatChecker(BaseChecker):
elif spec_type == "RKMETextSpecification":
inputs = EasyStatChecker._generate_random_text_list(10)
elif spec_type == "RKMEImageSpecification":
if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape):
raise ValueError(
f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}"
)
inputs = np.random.randint(0, 255, size=(10, *input_shape))
else:
raise ValueError(f"not supported spec type for spec_type = {spec_type}")
@@ -155,19 +168,39 @@ class EasyStatChecker(BaseChecker):

# Check output shape
if outputs[0].shape != learnware_model.output_shape:
message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}"
message = f"The learnware [{learnware.id}] output dimension mismatch, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

# Check output dimension
if semantic_spec["Task"]["Values"][0] in [
"Classification",
"Regression",
] and learnware_model.output_shape[0] != int(semantic_spec["Output"]["Dimension"]):
message = f"The learnware [{learnware.id}] output dimension mismatch!, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}"
# Check output dimension for regression
if semantic_spec["Task"]["Values"][0] == "Regression" and learnware_model.output_shape[0] != int(
semantic_spec["Output"]["Dimension"]
):
message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

# Check output dimension for classification
if semantic_spec["Task"]["Values"][0] == "Classification":
model_output_shape = learnware_model.output_shape[0]
semantic_output_shape = int(semantic_spec["Output"]["Dimension"])

if model_output_shape == 1:
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
if isinstance(outputs, list):
outputs = np.array(outputs)

if not np.all(np.logical_and(outputs >= 0, outputs < semantic_output_shape)):
message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}"
logger.warning(message)
return self.INVALID_LEARNWARE, message
else:
if model_output_shape != semantic_output_shape:
message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(semantic_output_shape, )}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

except Exception as e:
message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}."
logger.warning(message)


+ 5
- 1
tests/test_learnware_client/test_all_learnware.py View File

@@ -1,5 +1,6 @@
import os
import json
import zipfile
import unittest
import tempfile

@@ -48,8 +49,11 @@ class TestAllLearnware(unittest.TestCase):
for idx in learnware_ids:
zip_path = os.path.join(tempdir, f"test_{idx}.zip")
self.client.download_learnware(idx, zip_path)
with zipfile.ZipFile(zip_path, "r") as zip_file:
with zip_file.open("semantic_specification.json") as json_file:
semantic_spec = json.load(json_file)
try:
LearnwareClient.check_learnware(zip_path)
LearnwareClient.check_learnware(zip_path, semantic_spec)
print(f"check learnware {idx} succeed")
except:
failed_ids.append(idx)


+ 28
- 6
tests/test_learnware_client/test_check_learnware.py View File

@@ -1,4 +1,6 @@
import os
import json
import zipfile
import unittest
import tempfile

@@ -12,39 +14,59 @@ class TestCheckLearnware(unittest.TestCase):
self.client = LearnwareClient()

def test_check_learnware_pip(self):
learnware_id = "00000154"
learnware_id = "00000208"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
self.zip_path = os.path.join(tempdir, "test.zip")
self.client.download_learnware(learnware_id, self.zip_path)
LearnwareClient.check_learnware(self.zip_path)

with zipfile.ZipFile(self.zip_path, "r") as zip_file:
with zip_file.open("semantic_specification.json") as json_file:
semantic_spec = json.load(json_file)
LearnwareClient.check_learnware(self.zip_path, semantic_spec)

def test_check_learnware_conda(self):
learnware_id = "00000148"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
self.zip_path = os.path.join(tempdir, "test.zip")
self.client.download_learnware(learnware_id, self.zip_path)
LearnwareClient.check_learnware(self.zip_path)

with zipfile.ZipFile(self.zip_path, "r") as zip_file:
with zip_file.open("semantic_specification.json") as json_file:
semantic_spec = json.load(json_file)
LearnwareClient.check_learnware(self.zip_path, semantic_spec)

def test_check_learnware_dependency(self):
learnware_id = "00000147"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
self.zip_path = os.path.join(tempdir, "test.zip")
self.client.download_learnware(learnware_id, self.zip_path)
LearnwareClient.check_learnware(self.zip_path)

with zipfile.ZipFile(self.zip_path, "r") as zip_file:
with zip_file.open("semantic_specification.json") as json_file:
semantic_spec = json.load(json_file)
LearnwareClient.check_learnware(self.zip_path, semantic_spec)

def test_check_learnware_image(self):
learnware_id = "00000677"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
self.zip_path = os.path.join(tempdir, "test.zip")
self.client.download_learnware(learnware_id, self.zip_path)
LearnwareClient.check_learnware(self.zip_path)

with zipfile.ZipFile(self.zip_path, "r") as zip_file:
with zip_file.open("semantic_specification.json") as json_file:
semantic_spec = json.load(json_file)
LearnwareClient.check_learnware(self.zip_path, semantic_spec)

def test_check_learnware_text(self):
learnware_id = "00000662"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
self.zip_path = os.path.join(tempdir, "test.zip")
self.client.download_learnware(learnware_id, self.zip_path)
LearnwareClient.check_learnware(self.zip_path)

with zipfile.ZipFile(self.zip_path, "r") as zip_file:
with zip_file.open("semantic_specification.json") as json_file:
semantic_spec = json.load(json_file)
LearnwareClient.check_learnware(self.zip_path, semantic_spec)


if __name__ == "__main__":


Loading…
Cancel
Save