| @@ -45,7 +45,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): | |||||
| self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)) | self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)) | ||||
| else: | else: | ||||
| logger.warning(f"No market mapping to reload!") | logger.warning(f"No market mapping to reload!") | ||||
| self.market_mapping = HeteroMap() | |||||
| self.market_mapping = HeteroMap(cache_dir=hetero_folder_path) | |||||
| def reset(self, market_id=None, auto_update=False, auto_update_limit=100, **training_args): | def reset(self, market_id=None, auto_update=False, auto_update_limit=100, **training_args): | ||||
| self.market_id = market_id | self.market_id = market_id | ||||
| @@ -65,12 +65,13 @@ class FeatureTokenizer: | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| disable_tokenizer_parallel=True, | disable_tokenizer_parallel=True, | ||||
| cache_dir=None, | |||||
| **kwargs, | **kwargs, | ||||
| ): | ): | ||||
| """args: | """args: | ||||
| disable_tokenizer_parallel: true if use extractor for collator function in torch.DataLoader | disable_tokenizer_parallel: true if use extractor for collator function in torch.DataLoader | ||||
| """ | """ | ||||
| self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") | |||||
| self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir) | |||||
| self.tokenizer.__dict__["model_max_length"] = 512 | self.tokenizer.__dict__["model_max_length"] = 512 | ||||
| if disable_tokenizer_parallel: # disable tokenizer parallel | if disable_tokenizer_parallel: # disable tokenizer parallel | ||||
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||||
| @@ -96,9 +97,7 @@ class FeatureTokenizer: | |||||
| """ | """ | ||||
| encoded_inputs = { | encoded_inputs = { | ||||
| "x_num": None, | "x_num": None, | ||||
| "num_col_input_ids": None, | |||||
| "x_cat_input_ids": None, | |||||
| "x_bin_input_ids": None, | |||||
| "num_col_input_ids": None | |||||
| } | } | ||||
| num_cols = x.columns.tolist() if not shuffle else np.random.shuffle(x.columns.tolist()) | num_cols = x.columns.tolist() if not shuffle else np.random.shuffle(x.columns.tolist()) | ||||
| x_num = x[num_cols].fillna(0) | x_num = x[num_cols].fillna(0) | ||||
| @@ -164,7 +164,7 @@ class TestMarket(unittest.TestCase): | |||||
| hetero_market = self._init_learnware_market() | hetero_market = self._init_learnware_market() | ||||
| self.test_prepare_learnware_randomly(learnware_num) | self.test_prepare_learnware_randomly(learnware_num) | ||||
| self.learnware_num = learnware_num | self.learnware_num = learnware_num | ||||
| hetero_market.learnware_organizer.reset(auto_update=True, auto_update_limit=learnware_num) | |||||
| hetero_market.learnware_organizer.reset(auto_update=False, auto_update_limit=learnware_num) | |||||
| print("Total Item:", len(hetero_market)) | print("Total Item:", len(hetero_market)) | ||||
| assert len(hetero_market) == 0, f"The market should be empty!" | assert len(hetero_market) == 0, f"The market should be empty!" | ||||
| @@ -303,7 +303,7 @@ class TestMarket(unittest.TestCase): | |||||
| semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2]) | semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2]) | ||||
| semantic_spec["Input"]["Dimension"] = user_dim - 2 | semantic_spec["Input"]["Dimension"] = user_dim - 2 | ||||
| semantic_spec["Input"]["Description"] = { | semantic_spec["Input"]["Description"] = { | ||||
| "key": semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) | |||||
| str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) | |||||
| } | } | ||||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | ||||