Browse Source

[MNT] Pass cache_dir through to HeteroMap and the BERT feature tokenizer

tags/v0.3.2
liuht 2 years ago
parent
commit
614f8edcf5
3 changed files with 6 additions and 7 deletions
  1. learnware/market/heterogeneous/organizer/__init__.py (+1, -1)
  2. learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py (+3, -4)
  3. tests/test_hetero_market/test_hetero.py (+2, -2)

+ 1
- 1
learnware/market/heterogeneous/organizer/__init__.py View File

@@ -45,7 +45,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE))
else:
logger.warning(f"No market mapping to reload!")
self.market_mapping = HeteroMap()
self.market_mapping = HeteroMap(cache_dir=hetero_folder_path)

def reset(self, market_id=None, auto_update=False, auto_update_limit=100, **training_args):
self.market_id = market_id


+ 3
- 4
learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py View File

@@ -65,12 +65,13 @@ class FeatureTokenizer:
def __init__(
self,
disable_tokenizer_parallel=True,
cache_dir=None,
**kwargs,
):
"""args:
disable_tokenizer_parallel: true if use extractor for collator function in torch.DataLoader
"""
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir)
self.tokenizer.__dict__["model_max_length"] = 512
if disable_tokenizer_parallel: # disable tokenizer parallel
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -96,9 +97,7 @@ class FeatureTokenizer:
"""
encoded_inputs = {
"x_num": None,
"num_col_input_ids": None,
"x_cat_input_ids": None,
"x_bin_input_ids": None,
"num_col_input_ids": None
}
num_cols = x.columns.tolist() if not shuffle else np.random.shuffle(x.columns.tolist())
x_num = x[num_cols].fillna(0)


+ 2
- 2
tests/test_hetero_market/test_hetero.py View File

@@ -164,7 +164,7 @@ class TestMarket(unittest.TestCase):
hetero_market = self._init_learnware_market()
self.test_prepare_learnware_randomly(learnware_num)
self.learnware_num = learnware_num
hetero_market.learnware_organizer.reset(auto_update=True, auto_update_limit=learnware_num)
hetero_market.learnware_organizer.reset(auto_update=False, auto_update_limit=learnware_num)

print("Total Item:", len(hetero_market))
assert len(hetero_market) == 0, f"The market should be empty!"
@@ -303,7 +303,7 @@ class TestMarket(unittest.TestCase):
semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2])
semantic_spec["Input"]["Dimension"] = user_dim - 2
semantic_spec["Input"]["Description"] = {
"key": semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim)
str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim)
}

user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})


Loading…
Cancel
Save