Browse Source

[FIX | MNT] fix bugs in easy seacher, update test to fit with current interface

tags/v0.3.2
bxdd 2 years ago
parent
commit
4d417669c5
2 changed files with 15 additions and 12 deletions
  1. +2
    -2
      learnware/market/easy/searcher.py
  2. +13
    -10
      tests/test_hetero_market/test_hetero.py

+ 2
- 2
learnware/market/easy/searcher.py View File

@@ -591,10 +591,10 @@ class EasyStatSearcher(BaseSearcher):


class EasySearcher(BaseSearcher):
def __init__(self, organizer: EasyOrganizer = None):
super(EasySearcher, self).__init__(organizer)
def __init__(self, organizer: EasyOrganizer):
self.semantic_searcher = EasyFuzzSemanticSearcher(organizer)
self.stat_searcher = EasyStatSearcher(organizer)
super(EasySearcher, self).__init__(organizer)

def reset(self, organizer):
self.learnware_organizer = organizer


+ 13
- 10
tests/test_hetero_market/test_hetero.py View File

@@ -57,9 +57,11 @@ class TestMarket(unittest.TestCase):
np.random.seed(2023)
learnware.init()

def _init_learnware_market(self):
def _init_learnware_market(self, organizer_kwargs=None):
"""initialize learnware market"""
hetero_market = instantiate_learnware_market(market_id="hetero_toy", name="hetero", rebuild=True)
hetero_market = instantiate_learnware_market(
market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs
)
return hetero_market

def test_prepare_learnware_randomly(self, learnware_num=5):
@@ -161,10 +163,11 @@ class TestMarket(unittest.TestCase):
return hetero_market

def test_train_market_model(self, learnware_num=5):
hetero_market = self._init_learnware_market()
hetero_market = self._init_learnware_market(
organizer_kwargs={"auto_update": False, "auto_update_limit": learnware_num}
)
self.test_prepare_learnware_randomly(learnware_num)
self.learnware_num = learnware_num
hetero_market.learnware_organizer.reset(auto_update=False, auto_update_limit=learnware_num)

print("Total Item:", len(hetero_market))
assert len(hetero_market) == 0, f"The market should be empty!"
@@ -407,13 +410,13 @@ class TestMarket(unittest.TestCase):

def suite():
_suite = unittest.TestSuite()
# _suite.addTest(TestMarket("test_prepare_learnware_randomly"))
# _suite.addTest(TestMarket("test_generated_learnwares"))
# _suite.addTest(TestMarket("test_upload_delete_learnware"))
# _suite.addTest(TestMarket("test_train_market_model"))
# _suite.addTest(TestMarket("test_search_semantics"))
_suite.addTest(TestMarket("test_prepare_learnware_randomly"))
_suite.addTest(TestMarket("test_generated_learnwares"))
_suite.addTest(TestMarket("test_upload_delete_learnware"))
_suite.addTest(TestMarket("test_train_market_model"))
_suite.addTest(TestMarket("test_search_semantics"))
_suite.addTest(TestMarket("test_stat_search"))
# _suite.addTest(TestMarket("test_model_reuse"))
_suite.addTest(TestMarket("test_model_reuse"))
return _suite




Loading…
Cancel
Save