diff --git a/learnware/market/heterogeneous/organizer/__init__.py b/learnware/market/heterogeneous/organizer/__init__.py index f5dfd4e..76879bb 100644 --- a/learnware/market/heterogeneous/organizer/__init__.py +++ b/learnware/market/heterogeneous/organizer/__init__.py @@ -45,7 +45,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)) else: logger.warning(f"No market mapping to reload!") - self.market_mapping = HeteroMap() + self.market_mapping = HeteroMap(cache_dir=hetero_folder_path) def reset(self, market_id=None, auto_update=False, auto_update_limit=100, **training_args): self.market_id = market_id diff --git a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py index e47d702..89105c0 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py @@ -65,12 +65,13 @@ class FeatureTokenizer: def __init__( self, disable_tokenizer_parallel=True, + cache_dir=None, **kwargs, ): """args: disable_tokenizer_parallel: true if use extractor for collator function in torch.DataLoader """ - self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir) self.tokenizer.__dict__["model_max_length"] = 512 if disable_tokenizer_parallel: # disable tokenizer parallel os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -96,9 +97,7 @@ class FeatureTokenizer: """ encoded_inputs = { "x_num": None, - "num_col_input_ids": None, - "x_cat_input_ids": None, - "x_bin_input_ids": None, + "num_col_input_ids": None } num_cols = x.columns.tolist() if not shuffle else np.random.shuffle(x.columns.tolist()) x_num = x[num_cols].fillna(0) diff --git a/tests/test_hetero_market/test_hetero.py b/tests/test_hetero_market/test_hetero.py index aa3c7a0..3a55ca7 100644 --- a/tests/test_hetero_market/test_hetero.py +++ b/tests/test_hetero_market/test_hetero.py @@ -164,7 +164,7 @@ class TestMarket(unittest.TestCase): hetero_market = self._init_learnware_market() self.test_prepare_learnware_randomly(learnware_num) self.learnware_num = learnware_num - hetero_market.learnware_organizer.reset(auto_update=True, auto_update_limit=learnware_num) + hetero_market.learnware_organizer.reset(auto_update=False, auto_update_limit=learnware_num) print("Total Item:", len(hetero_market)) assert len(hetero_market) == 0, f"The market should be empty!" @@ -303,7 +303,7 @@ class TestMarket(unittest.TestCase): semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2]) semantic_spec["Input"]["Dimension"] = user_dim - 2 semantic_spec["Input"]["Description"] = { - "key": semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) + str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) } user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})