You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_table_rkme.py 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. import json
  2. import os
  3. import tempfile
  4. import unittest
  5. import numpy as np
  6. from learnware.specification import RKMETableSpecification, generate_stat_spec
  7. class TestTableRKME(unittest.TestCase):
  8. @staticmethod
  9. def _test_table_rkme(X):
  10. rkme = generate_stat_spec(type="table", X=X)
  11. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  12. rkme_path = os.path.join(tempdir, "rkme.json")
  13. rkme.save(rkme_path)
  14. with open(rkme_path, "r") as f:
  15. data = json.load(f)
  16. assert data["type"] == "RKMETableSpecification"
  17. rkme2 = RKMETableSpecification()
  18. rkme2.load(rkme_path)
  19. assert rkme2.type == "RKMETableSpecification"
  20. def test_table_rkme(self):
  21. self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200)))
  22. self._test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100)))
  23. self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5, 20)))
  24. self._test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50)))
  25. self._test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150)))
  26. if __name__ == "__main__":
  27. unittest.main()