|
- import os
- import json
- import unittest
- import tempfile
- import numpy as np
-
- from learnware.specification import RKMETableSpecification, HeteroMapTableSpecification
- from learnware.specification import generate_stat_spec
- from learnware.market.heterogeneous.organizer import HeteroMap
-
-
class TestTableRKME(unittest.TestCase):
    """Round-trip tests for specifications produced by ``HeteroMap.hetero_mapping``.

    Each case builds an RKME table specification from random data and maps it
    through ``HeteroMap.hetero_mapping``. It then saves the result to a JSON file
    and checks that both the raw JSON and a reloaded
    ``HeteroMapTableSpecification`` report the type
    ``"HeteroMapTableSpecification"``.
    """

    # Fixed seed so that a failing case can be reproduced exactly.
    SEED = 0
    # (n_samples, n_features) shapes exercised by ``test_hetero_rkme``; the
    # (1, 50) entry covers the single-row edge case.
    SHAPES = ((5000, 200), (10000, 100), (5, 20), (1, 50), (100, 150))
    # Bounds of the uniform distribution the feature values are drawn from.
    LOW, HIGH = -10000, 10000

    def setUp(self):
        self.hetero_map = HeteroMap()
        # Per-test generator: each test starts from the same seeded state.
        self.rng = np.random.default_rng(self.SEED)

    def _test_hetero_spec(self, X):
        """Map ``X`` to a hetero spec, save it, and verify the saved/reloaded type.

        Args:
            X: Feature matrix passed to ``generate_stat_spec(type="table", X=X)``.
        """
        rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X)
        # No feature descriptions are supplied here; presumably HeteroMap accepts
        # an empty mapping -- TODO confirm against HeteroMap.hetero_mapping.
        hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features={})
        with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
            rkme_path = os.path.join(tempdir, "rkme.json")
            hetero_spec.save(rkme_path)

            # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
            with open(rkme_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            # unittest assertions (unlike bare ``assert``) survive ``python -O``
            # and report the mismatching values on failure.
            self.assertEqual(data.get("type"), "HeteroMapTableSpecification")

            rkme2 = HeteroMapTableSpecification()
            rkme2.load(rkme_path)
            self.assertEqual(rkme2.type, "HeteroMapTableSpecification")

    def test_hetero_rkme(self):
        """Run the save/load round trip for every shape in ``SHAPES``."""
        for shape in self.SHAPES:
            # subTest reports each shape separately instead of stopping at the first failure.
            with self.subTest(shape=shape):
                X = self.rng.uniform(self.LOW, self.HIGH, size=shape)
                self._test_hetero_spec(X)
-
-
if __name__ == "__main__":
    # Executed only when this file is run directly (not when imported by a
    # test runner): discover and run the TestCase classes in this module.
    unittest.main(module="__main__")
|