| @@ -8,30 +8,35 @@ import tempfile | |||
| import numpy as np | |||
| from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification | |||
| from learnware.specification import generate_rkme_image_spec, generate_rkme_table_spec, generate_rkme_text_spec | |||
| from learnware.specification import generate_stat_spec | |||
| class TestRKME(unittest.TestCase): | |||
| def test_rkme(self): | |||
| X = np.random.uniform(-10000, 10000, size=(5000, 200)) | |||
| rkme = generate_rkme_table_spec(X) | |||
| rkme.generate_stat_spec_from_data(X) | |||
| def _test_table_rkme(X): | |||
| rkme = generate_stat_spec(type="table", X=X) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| rkme.save(rkme_path) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| rkme.save(rkme_path) | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "RKMETableSpecification" | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "RKMETableSpecification" | |||
| rkme2 = RKMETableSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMETableSpecification" | |||
| rkme2 = RKMETableSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMETableSpecification" | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200))) | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100))) | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(5, 20))) | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50))) | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150))) | |||
| def test_image_rkme(self): | |||
| def _test_image_rkme(X): | |||
| image_rkme = generate_rkme_image_spec(X, steps=10) | |||
| image_rkme = generate_stat_spec(type="image", X=X, steps=10) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| @@ -46,12 +51,12 @@ class TestRKME(unittest.TestCase): | |||
| assert rkme2.type == "RKMEImageSpecification" | |||
| _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) | |||
| _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128))) | |||
| _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128)) / 255) | |||
| _test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128))) | |||
| _test_image_rkme(np.random.randint(0, 255, size=(50, 3, 128, 128)) / 255) | |||
| _test_image_rkme(torch.randint(0, 255, (2000, 3, 32, 32))) | |||
| _test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128))) | |||
| _test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128)) / 255) | |||
| _test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128))) | |||
| _test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255) | |||
| def test_text_rkme(self): | |||
| def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): | |||
| @@ -70,7 +75,7 @@ class TestRKME(unittest.TestCase): | |||
| return text_list | |||
| def _test_text_rkme(X): | |||
| rkme = generate_rkme_text_spec(X) | |||
| rkme = generate_stat_spec(type="text", X=X) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| @@ -87,11 +92,12 @@ class TestRKME(unittest.TestCase): | |||
| return rkme2.get_z().shape[1] | |||
| dim1 = _test_text_rkme(generate_random_text_list(3000, "en")) | |||
| dim2 = _test_text_rkme(generate_random_text_list(4000, "en")) | |||
| dim3 = _test_text_rkme(generate_random_text_list(2000, "zh")) | |||
| dim2 = _test_text_rkme(generate_random_text_list(100, "en")) | |||
| dim3 = _test_text_rkme(generate_random_text_list(50, "zh")) | |||
| dim4 = _test_text_rkme(generate_random_text_list(5000, "zh")) | |||
| dim5 = _test_text_rkme(generate_random_text_list(1, "zh")) | |||
| assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4 | |||
| assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4 and dim4 == dim5 | |||
| if __name__ == "__main__": | |||