Browse Source

[MNT] add the corner case in the test

tags/v0.3.2
Gene 2 years ago
parent
commit
6eed14ca03
1 changed files with 28 additions and 22 deletions
  1. +28
    -22
      tests/test_specification/test_rkme.py

+ 28
- 22
tests/test_specification/test_rkme.py View File

@@ -8,30 +8,35 @@ import tempfile
import numpy as np

from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification
from learnware.specification import generate_rkme_image_spec, generate_rkme_table_spec, generate_rkme_text_spec
from learnware.specification import generate_stat_spec


class TestRKME(unittest.TestCase):
def test_rkme(self):
X = np.random.uniform(-10000, 10000, size=(5000, 200))
rkme = generate_rkme_table_spec(X)
rkme.generate_stat_spec_from_data(X)
def _test_table_rkme(X):
rkme = generate_stat_spec(type="table", X=X)

with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
rkme_path = os.path.join(tempdir, "rkme.json")
rkme.save(rkme_path)
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
rkme_path = os.path.join(tempdir, "rkme.json")
rkme.save(rkme_path)

with open(rkme_path, "r") as f:
data = json.load(f)
assert data["type"] == "RKMETableSpecification"
with open(rkme_path, "r") as f:
data = json.load(f)
assert data["type"] == "RKMETableSpecification"

rkme2 = RKMETableSpecification()
rkme2.load(rkme_path)
assert rkme2.type == "RKMETableSpecification"

rkme2 = RKMETableSpecification()
rkme2.load(rkme_path)
assert rkme2.type == "RKMETableSpecification"
_test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200)))
_test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100)))
_test_table_rkme(np.random.uniform(-10000, 10000, size=(5, 20)))
_test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50)))
_test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150)))

def test_image_rkme(self):
def _test_image_rkme(X):
image_rkme = generate_rkme_image_spec(X, steps=10)
image_rkme = generate_stat_spec(type="image", X=X, steps=10)

with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
rkme_path = os.path.join(tempdir, "rkme.json")
@@ -46,12 +51,12 @@ class TestRKME(unittest.TestCase):
assert rkme2.type == "RKMEImageSpecification"

_test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32)))
_test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128)))
_test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128)) / 255)
_test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128)))
_test_image_rkme(np.random.randint(0, 255, size=(50, 3, 128, 128)) / 255)

_test_image_rkme(torch.randint(0, 255, (2000, 3, 32, 32)))
_test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128)))
_test_image_rkme(torch.randint(0, 255, (2000, 3, 128, 128)) / 255)
_test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128)))
_test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255)

def test_text_rkme(self):
def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
@@ -70,7 +75,7 @@ class TestRKME(unittest.TestCase):
return text_list

def _test_text_rkme(X):
rkme = generate_rkme_text_spec(X)
rkme = generate_stat_spec(type="text", X=X)

with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
rkme_path = os.path.join(tempdir, "rkme.json")
@@ -87,11 +92,12 @@ class TestRKME(unittest.TestCase):
return rkme2.get_z().shape[1]

dim1 = _test_text_rkme(generate_random_text_list(3000, "en"))
dim2 = _test_text_rkme(generate_random_text_list(4000, "en"))
dim3 = _test_text_rkme(generate_random_text_list(2000, "zh"))
dim2 = _test_text_rkme(generate_random_text_list(100, "en"))
dim3 = _test_text_rkme(generate_random_text_list(50, "zh"))
dim4 = _test_text_rkme(generate_random_text_list(5000, "zh"))
dim5 = _test_text_rkme(generate_random_text_list(1, "zh"))

assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4
assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4 and dim4 == dim5


if __name__ == "__main__":


Loading…
Cancel
Save