diff --git a/learnware/market/easy/searcher.py b/learnware/market/easy/searcher.py index 8de00af..fc04297 100644 --- a/learnware/market/easy/searcher.py +++ b/learnware/market/easy/searcher.py @@ -591,10 +591,10 @@ class EasyStatSearcher(BaseSearcher): class EasySearcher(BaseSearcher): - def __init__(self, organizer: EasyOrganizer = None): - super(EasySearcher, self).__init__(organizer) + def __init__(self, organizer: EasyOrganizer): self.semantic_searcher = EasyFuzzSemanticSearcher(organizer) self.stat_searcher = EasyStatSearcher(organizer) + super(EasySearcher, self).__init__(organizer) def reset(self, organizer): self.learnware_organizer = organizer diff --git a/tests/test_hetero_market/test_hetero.py b/tests/test_hetero_market/test_hetero.py index 3a55ca7..58c285e 100644 --- a/tests/test_hetero_market/test_hetero.py +++ b/tests/test_hetero_market/test_hetero.py @@ -57,9 +57,11 @@ class TestMarket(unittest.TestCase): np.random.seed(2023) learnware.init() - def _init_learnware_market(self): + def _init_learnware_market(self, organizer_kwargs=None): """initialize learnware market""" - hetero_market = instantiate_learnware_market(market_id="hetero_toy", name="hetero", rebuild=True) + hetero_market = instantiate_learnware_market( + market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs + ) return hetero_market def test_prepare_learnware_randomly(self, learnware_num=5): @@ -161,10 +163,11 @@ class TestMarket(unittest.TestCase): return hetero_market def test_train_market_model(self, learnware_num=5): - hetero_market = self._init_learnware_market() + hetero_market = self._init_learnware_market( + organizer_kwargs={"auto_update": False, "auto_update_limit": learnware_num} + ) self.test_prepare_learnware_randomly(learnware_num) self.learnware_num = learnware_num - hetero_market.learnware_organizer.reset(auto_update=False, auto_update_limit=learnware_num) print("Total Item:", len(hetero_market)) assert len(hetero_market) == 0, f"The market should be empty!" @@ -407,13 +410,13 @@ class TestMarket(unittest.TestCase): def suite(): _suite = unittest.TestSuite() - # _suite.addTest(TestMarket("test_prepare_learnware_randomly")) - # _suite.addTest(TestMarket("test_generated_learnwares")) - # _suite.addTest(TestMarket("test_upload_delete_learnware")) - # _suite.addTest(TestMarket("test_train_market_model")) - # _suite.addTest(TestMarket("test_search_semantics")) + _suite.addTest(TestMarket("test_prepare_learnware_randomly")) + _suite.addTest(TestMarket("test_generated_learnwares")) + _suite.addTest(TestMarket("test_upload_delete_learnware")) + _suite.addTest(TestMarket("test_train_market_model")) + _suite.addTest(TestMarket("test_search_semantics")) _suite.addTest(TestMarket("test_stat_search")) - # _suite.addTest(TestMarket("test_model_reuse")) + _suite.addTest(TestMarket("test_model_reuse")) return _suite