You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_image_rkme.py 1.3 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import json
  2. import os
  3. import tempfile
  4. import unittest
  5. import numpy as np
  6. import torch
  7. from learnware.specification import RKMEImageSpecification, generate_stat_spec
  class TestImageRKME(unittest.TestCase):
      """Save/load round-trip tests for image RKME specifications.

      Each case builds a spec from a batch of images with ``generate_stat_spec``,
      saves it to a JSON file, and reloads it into a fresh
      ``RKMEImageSpecification``, checking the recorded type both times.
      """

      # Fixed seed for the generated input images, so a failing case sees the
      # same data on every run and can be reproduced.
      _SEED = 0

      @staticmethod
      def _test_image_rkme(X):
          """Check that an image RKME spec built from ``X`` survives save/load.

          Args:
              X: A batch of images as a NumPy array or torch tensor. The cases
                  in this class use a 4-D ``(N, C, H, W)`` layout with values
                  either in ``[0, 255]`` or scaled by ``1/255``.

          Raises:
              AssertionError: If the saved JSON or the reloaded spec does not
                  report the type ``"RKMEImageSpecification"``.
          """
          image_rkme = generate_stat_spec(type="image", X=X, steps=10)
          with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
              rkme_path = os.path.join(tempdir, "rkme.json")
              image_rkme.save(rkme_path)

              # The saved file must be JSON that records the spec type.
              with open(rkme_path, "r") as f:
                  data = json.load(f)
              assert data["type"] == "RKMEImageSpecification", (
                  f"unexpected saved type: {data['type']!r}"
              )

              # Reload inside the ``with`` block: the temporary directory and
              # the saved file are removed as soon as it exits.
              rkme2 = RKMEImageSpecification()
              rkme2.load(rkme_path)
              assert rkme2.type == "RKMEImageSpecification", (
                  f"unexpected loaded type: {rkme2.type!r}"
              )

      def test_image_rkme(self):
          """Round-trip specs for several image shapes, backends and value ranges."""
          rng = np.random.default_rng(self._SEED)
          gen = torch.Generator().manual_seed(self._SEED)

          # ``high`` is exclusive for both ``Generator.integers`` and
          # ``torch.randint``; 256 yields the full 8-bit pixel range 0..255.
          # Each input is built lazily, so only one batch is alive at a time.
          cases = [
              ("numpy int 2000x3x32x32",
               lambda: rng.integers(0, 256, size=(2000, 3, 32, 32))),
              ("numpy int 100x1x128x128",
               lambda: rng.integers(0, 256, size=(100, 1, 128, 128))),
              ("numpy float 50x3x128x128 scaled",
               lambda: rng.integers(0, 256, size=(50, 3, 128, 128)) / 255),
              ("torch int 2000x3x32x32",
               lambda: torch.randint(0, 256, (2000, 3, 32, 32), generator=gen)),
              ("torch int 20x3x128x128",
               lambda: torch.randint(0, 256, (20, 3, 128, 128), generator=gen)),
              ("torch float 1x1x128x128 scaled",
               lambda: torch.randint(0, 256, (1, 1, 128, 128), generator=gen) / 255),
          ]
          # subTest keeps going after a failing case and reports each one separately.
          for name, make_x in cases:
              with self.subTest(name):
                  self._test_image_rkme(make_x())
  if __name__ == "__main__":
      # Only when run directly as a script (not when imported by a test
      # runner): load and run the tests from this module. "__main__" is
      # unittest.main's default module, spelled out here for clarity.
      unittest.main(module="__main__")