You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rkme.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import os
  2. import json
  3. import torch
  4. import unittest
  5. import tempfile
  6. import numpy as np
  7. from learnware.specification import RKMETableSpecification, RKMEImageSpecification
  8. from learnware.specification import generate_rkme_image_spec, generate_rkme_spec
  class TestRKME(unittest.TestCase):
      """Save/load round-trip tests for RKME table and image specifications."""

      def _assert_round_trip(self, spec, spec_cls, expected_type):
          """Save *spec* to a temporary file and verify it can be read back.

          The spec is written to ``rkme.json`` inside a fresh temporary
          directory.  The file is then parsed as JSON and its ``"type"`` entry
          compared with *expected_type*; finally a new ``spec_cls()`` instance
          loads the same file and its ``type`` attribute is compared as well.

          Args:
              spec: Object exposing ``save(path)``.
              spec_cls: Zero-argument callable returning an object that exposes
                  ``load(path)`` and, after loading, a ``type`` attribute.
              expected_type: Value expected both in the saved JSON and on the
                  reloaded object.

          Raises:
              AssertionError: If either type check does not match.
          """
          with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
              rkme_path = os.path.join(tempdir, "rkme.json")
              spec.save(rkme_path)

              with open(rkme_path, "r") as f:
                  data = json.load(f)
              # unittest assertions are not stripped by ``python -O`` and
              # report both values when they differ.
              self.assertEqual(data["type"], expected_type)

              restored = spec_cls()
              restored.load(rkme_path)
              self.assertEqual(restored.type, expected_type)

      def test_rkme(self):
          """Round-trip a table specification built from random uniform data."""
          X = np.random.uniform(-10000, 10000, size=(5000, 200))
          rkme = generate_rkme_spec(X)
          # Also invoke the generation method explicitly on the returned object.
          # NOTE(review): this may repeat work already done by
          # generate_rkme_spec -- confirm whether both calls are needed.
          rkme.generate_stat_spec_from_data(X)
          self._assert_round_trip(rkme, RKMETableSpecification, "RKMETableSpecification")

      def test_image_rkme(self):
          """Round-trip image specifications built from several input kinds."""
          # Inputs are created through zero-argument factories so that only one
          # of these large arrays/tensors exists at any time.
          cases = (
              ("numpy-int-32x32", lambda: np.random.randint(0, 255, size=(2000, 3, 32, 32))),
              ("numpy-int-128x128", lambda: np.random.randint(0, 255, size=(2000, 3, 128, 128))),
              ("numpy-scaled-128x128", lambda: np.random.randint(0, 255, size=(2000, 3, 128, 128)) / 255),
              ("torch-int-32x32", lambda: torch.randint(0, 255, (2000, 3, 32, 32))),
              ("torch-int-128x128", lambda: torch.randint(0, 255, (2000, 3, 128, 128))),
              ("torch-scaled-128x128", lambda: torch.randint(0, 255, (2000, 3, 128, 128)) / 255),
          )
          for label, make_input in cases:
              # subTest names the failing input and lets the remaining inputs run.
              with self.subTest(input=label):
                  image_rkme = generate_rkme_image_spec(make_input(), steps=10)
                  self._assert_round_trip(image_rkme, RKMEImageSpecification, "RKMEImageSpecification")
  # Script entry point: hand control to the unittest command-line runner when
  # this file is executed directly rather than imported.
  if "__main__" == __name__:
      unittest.main()