You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hetero_spec.py 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import os
  2. import json
  3. import unittest
  4. import tempfile
  5. import numpy as np
  6. from learnware.specification import RKMETableSpecification, HeteroMapTableSpecification
  7. from learnware.specification import generate_stat_spec
  8. from learnware.market.heterogeneous.organizer import HeteroMap
  9. class TestTableRKME(unittest.TestCase):
  10. def setUp(self):
  11. self.hetero_map = HeteroMap()
  12. def _test_hetero_spec(self, X):
  13. rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X)
  14. hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features=dict())
  15. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  16. rkme_path = os.path.join(tempdir, "rkme.json")
  17. hetero_spec.save(rkme_path)
  18. with open(rkme_path, "r") as f:
  19. data = json.load(f)
  20. assert data["type"] == "HeteroMapTableSpecification"
  21. rkme2 = HeteroMapTableSpecification()
  22. rkme2.load(rkme_path)
  23. assert rkme2.type == "HeteroMapTableSpecification"
  24. def test_hetero_rkme(self):
  25. self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5000, 200)))
  26. self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(10000, 100)))
  27. self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5, 20)))
  28. self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(1, 50)))
  29. self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(100, 150)))
  30. if __name__ == "__main__":
  31. unittest.main()