You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_rkme.py 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import os
  2. import json
  3. import torch
  4. import unittest
  5. import tempfile
  6. import numpy as np
  7. from learnware.specification import RKMEImageSpecification
  8. from learnware.specification import generate_stat_spec
  9. class TestImageRKME(unittest.TestCase):
  10. @staticmethod
  11. def _test_image_rkme(X):
  12. image_rkme = generate_stat_spec(type="image", X=X, steps=10)
  13. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  14. rkme_path = os.path.join(tempdir, "rkme.json")
  15. image_rkme.save(rkme_path)
  16. with open(rkme_path, "r") as f:
  17. data = json.load(f)
  18. assert data["type"] == "RKMEImageSpecification"
  19. rkme2 = RKMEImageSpecification()
  20. rkme2.load(rkme_path)
  21. assert rkme2.type == "RKMEImageSpecification"
  22. def test_image_rkme(self):
  23. self._test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32)))
  24. self._test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128)))
  25. self._test_image_rkme(np.random.randint(0, 255, size=(50, 3, 128, 128)) / 255)
  26. self._test_image_rkme(torch.randint(0, 255, (2000, 3, 32, 32)))
  27. self._test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128)))
  28. self._test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255)
  29. if __name__ == "__main__":
  30. unittest.main()