| @@ -140,6 +140,33 @@ class LearnwareClient: | |||
| return result["data"]["learnware_id"] | |||
| @require_login | |||
| def update_learnware(self, learnware_id, semantic_specification, learnware_zip_path=None): | |||
| assert self._check_semantic_specification(semantic_specification)[0], "Semantic specification check failed!" | |||
| url_update = f"{self.host}/user/update_learnware" | |||
| payload = {"learnware_id": learnware_id, "semantic_specification": json.dumps(semantic_specification)} | |||
| if learnware_zip_path is None: | |||
| response = requests.post( | |||
| url_update, | |||
| files={"learnware_file": None}, | |||
| data=payload, | |||
| headers=self.headers, | |||
| ) | |||
| else: | |||
| response = requests.post( | |||
| url_update, | |||
| files={"learnware_file": open(learnware_zip_path, "rb")}, | |||
| data=payload, | |||
| headers=self.headers, | |||
| ) | |||
| result = response.json() | |||
| if result["code"] != 0: | |||
| raise Exception("update failed: " + json.dumps(result)) | |||
| def download_learnware(self, learnware_id, save_path): | |||
| url = f"{self.host}/engine/download_learnware" | |||
| @@ -275,8 +302,8 @@ class LearnwareClient: | |||
| "Type": "String", | |||
| "Values": description if description is not None else "", | |||
| } | |||
| semantic_specification["Input"] = input_description | |||
| semantic_specification["Output"] = output_description | |||
| semantic_specification["Input"] = {} if input_description is None else input_description | |||
| semantic_specification["Output"] = {} if output_description is None else output_description | |||
| return semantic_specification | |||
| @@ -351,7 +378,7 @@ class LearnwareClient: | |||
| semantic_specification = json.load(fin) | |||
| return learnware.get_learnware_from_dirpath(learnware_id, semantic_specification, tempdir) | |||
| learnware_list = [] | |||
| if learnware_path is not None: | |||
| zip_paths = [learnware_path] if isinstance(learnware_path, str) else learnware_path | |||
| @@ -47,6 +47,10 @@ class EasySemanticChecker(BaseChecker): | |||
| if semantic_spec["Task"]["Values"][0] in ["Classification", "Regression"]: | |||
| assert semantic_spec["Output"] is not None, "Lack of output semantics" | |||
| dim = semantic_spec["Output"]["Dimension"] | |||
| assert ( | |||
| dim > 1 or semantic_spec["Task"]["Values"][0] == "Regression" | |||
| ), "Classification task must have dimension > 1" | |||
| for k, v in semantic_spec["Output"]["Description"].items(): | |||
| assert int(k) >= 0 and int(k) < dim, f"Dimension number in [0, {dim})" | |||
| assert isinstance(v, str), "Description must be string" | |||
| @@ -110,6 +114,11 @@ class EasyStatChecker(BaseChecker): | |||
| if spec_type == "RKMETableSpecification": | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||
| if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | |||
| raise ValueError( | |||
| f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}" | |||
| ) | |||
| if stat_spec.get_z().shape[1:] != input_shape: | |||
| message = f"The learnware [{learnware.id}] input dimension mismatch with stat specification." | |||
| logger.warning(message) | |||
| @@ -118,6 +127,10 @@ class EasyStatChecker(BaseChecker): | |||
| elif spec_type == "RKMETextSpecification": | |||
| inputs = EasyStatChecker._generate_random_text_list(10) | |||
| elif spec_type == "RKMEImageSpecification": | |||
| if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | |||
| raise ValueError( | |||
| f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}" | |||
| ) | |||
| inputs = np.random.randint(0, 255, size=(10, *input_shape)) | |||
| else: | |||
| raise ValueError(f"not supported spec type for spec_type = {spec_type}") | |||
| @@ -155,19 +168,39 @@ class EasyStatChecker(BaseChecker): | |||
| # Check output shape | |||
| if outputs[0].shape != learnware_model.output_shape: | |||
| message = f"The learnware [{learnware.id}] output dimension mismatch!, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" | |||
| message = f"The learnware [{learnware.id}] output dimension mismatch, where pred_shape={outputs[0].shape}, model_shape={learnware_model.output_shape}" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| # Check output dimension | |||
| if semantic_spec["Task"]["Values"][0] in [ | |||
| "Classification", | |||
| "Regression", | |||
| ] and learnware_model.output_shape[0] != int(semantic_spec["Output"]["Dimension"]): | |||
| message = f"The learnware [{learnware.id}] output dimension mismatch!, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" | |||
| # Check output dimension for regression | |||
| if semantic_spec["Task"]["Values"][0] == "Regression" and learnware_model.output_shape[0] != int( | |||
| semantic_spec["Output"]["Dimension"] | |||
| ): | |||
| message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(int(semantic_spec['Output']['Dimension']), )}" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| # Check output dimension for classification | |||
| if semantic_spec["Task"]["Values"][0] == "Classification": | |||
| model_output_shape = learnware_model.output_shape[0] | |||
| semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) | |||
| if model_output_shape == 1: | |||
| if isinstance(outputs, torch.Tensor): | |||
| outputs = outputs.detach().cpu().numpy() | |||
| if isinstance(outputs, list): | |||
| outputs = np.array(outputs) | |||
| if not np.all(np.logical_and(outputs >= 0, outputs < semantic_output_shape)): | |||
| message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| else: | |||
| if model_output_shape != semantic_output_shape: | |||
| message = f"The learnware [{learnware.id}] output dimension mismatch, where model_shape={learnware_model.output_shape}, semantic_shape={(semantic_output_shape, )}" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| except Exception as e: | |||
| message = f"The learnware [{learnware.id}] is not valid! Due to {repr(e)}." | |||
| logger.warning(message) | |||
| @@ -1,5 +1,6 @@ | |||
| import os | |||
| import json | |||
| import zipfile | |||
| import unittest | |||
| import tempfile | |||
| @@ -48,8 +49,11 @@ class TestAllLearnware(unittest.TestCase): | |||
| for idx in learnware_ids: | |||
| zip_path = os.path.join(tempdir, f"test_{idx}.zip") | |||
| self.client.download_learnware(idx, zip_path) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| try: | |||
| LearnwareClient.check_learnware(zip_path) | |||
| LearnwareClient.check_learnware(zip_path, semantic_spec) | |||
| print(f"check learnware {idx} succeed") | |||
| except: | |||
| failed_ids.append(idx) | |||
| @@ -1,4 +1,6 @@ | |||
| import os | |||
| import json | |||
| import zipfile | |||
| import unittest | |||
| import tempfile | |||
| @@ -12,39 +14,59 @@ class TestCheckLearnware(unittest.TestCase): | |||
| self.client = LearnwareClient() | |||
| def test_check_learnware_pip(self): | |||
| learnware_id = "00000154" | |||
| learnware_id = "00000208" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| self.zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(learnware_id, self.zip_path) | |||
| LearnwareClient.check_learnware(self.zip_path) | |||
| with zipfile.ZipFile(self.zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| LearnwareClient.check_learnware(self.zip_path, semantic_spec) | |||
| def test_check_learnware_conda(self): | |||
| learnware_id = "00000148" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| self.zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(learnware_id, self.zip_path) | |||
| LearnwareClient.check_learnware(self.zip_path) | |||
| with zipfile.ZipFile(self.zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| LearnwareClient.check_learnware(self.zip_path, semantic_spec) | |||
| def test_check_learnware_dependency(self): | |||
| learnware_id = "00000147" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| self.zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(learnware_id, self.zip_path) | |||
| LearnwareClient.check_learnware(self.zip_path) | |||
| with zipfile.ZipFile(self.zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| LearnwareClient.check_learnware(self.zip_path, semantic_spec) | |||
| def test_check_learnware_image(self): | |||
| learnware_id = "00000677" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| self.zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(learnware_id, self.zip_path) | |||
| LearnwareClient.check_learnware(self.zip_path) | |||
| with zipfile.ZipFile(self.zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| LearnwareClient.check_learnware(self.zip_path, semantic_spec) | |||
| def test_check_learnware_text(self): | |||
| learnware_id = "00000662" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| self.zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(learnware_id, self.zip_path) | |||
| LearnwareClient.check_learnware(self.zip_path) | |||
| with zipfile.ZipFile(self.zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| LearnwareClient.check_learnware(self.zip_path, semantic_spec) | |||
| if __name__ == "__main__": | |||