Browse Source

[FIX] change name of mapping

pull/1/head
troyyyyy 1 year ago
parent
commit
4213fa0063
3 changed files with 24 additions and 24 deletions
  1. +2
    -2
      abl/bridge/simple_bridge.py
  2. +16
    -16
      abl/reasoning/reasoner.py
  3. +6
    -6
      examples/hed/hed_bridge.py

+ 2
- 2
abl/bridge/simple_bridge.py View File

@@ -95,7 +95,7 @@ class SimpleBridge(BaseBridge):
"""
pred_idx = data_samples.pred_idx
data_samples.pred_pseudo_label = [
[self.reasoner.mapping[_idx] for _idx in sub_list] for sub_list in pred_idx
[self.reasoner.idx_to_label[_idx] for _idx in sub_list] for sub_list in pred_idx
]
return data_samples.pred_pseudo_label

@@ -114,7 +114,7 @@ class SimpleBridge(BaseBridge):
A list of indices converted from pseudo labels.
"""
abduced_idx = [
[self.reasoner.remapping[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
[self.reasoner.label_to_idx[_abduced_pseudo_label] for _abduced_pseudo_label in sub_list]
for sub_list in data_samples.abduced_pseudo_label
]
data_samples.abduced_idx = abduced_idx


+ 16
- 16
abl/reasoning/reasoner.py View File

@@ -32,9 +32,9 @@ class Reasoner:
in this cost list should be a numerical value representing the cost for each
candidate, and the list should have the same length as candidates.
Defaults to 'confidence'.
mapping : Optional[dict], optional
idx_to_label : Optional[dict], optional
A mapping from index in the base model to label. If not provided, a default
order-based mapping is created. Defaults to None.
order-based index to label mapping is created. Defaults to None.
max_revision : Union[int, float], optional
The upper limit on the number of revisions for each data sample when
performing abductive reasoning. If float, denotes the fraction of the total
@@ -51,7 +51,7 @@ class Reasoner:
self,
kb: KBBase,
dist_func: Union[str, Callable] = "confidence",
mapping: Optional[dict] = None,
idx_to_label: Optional[dict] = None,
max_revision: Union[int, float] = -1,
require_more_revision: int = 0,
use_zoopt: bool = False,
@@ -63,12 +63,12 @@ class Reasoner:
self.max_revision = max_revision
self.require_more_revision = require_more_revision

if mapping is None:
self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
if idx_to_label is None:
self.idx_to_label = {index: label for index, label in enumerate(self.kb.pseudo_label_list)}
else:
self._check_valid_mapping(mapping)
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))
self._check_valid_idx_to_label(idx_to_label)
self.idx_to_label = idx_to_label
self.label_to_idx = dict(zip(self.idx_to_label.values(), self.idx_to_label.keys()))

def _check_valid_dist(self, dist_func):
if isinstance(dist_func, str):
@@ -87,15 +87,15 @@ class Reasoner:
f"dist_func must be a string or a callable function, but got {type(dist_func)}."
)

def _check_valid_mapping(self, mapping):
if not isinstance(mapping, dict):
raise TypeError(f"mapping should be dict, but got {type(mapping)}.")
for key, value in mapping.items():
def _check_valid_idx_to_label(self, idx_to_label):
if not isinstance(idx_to_label, dict):
raise TypeError(f"idx_to_label should be dict, but got {type(idx_to_label)}.")
for key, value in idx_to_label.items():
if not isinstance(key, int):
raise ValueError(f"All keys in the mapping must be integers, but got {key}.")
raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.")
if value not in self.kb.pseudo_label_list:
raise ValueError(
f"All values in the mapping must be in the pseudo_label_list, but got {value}."
f"All values in the idx_to_label must be in the pseudo_label_list, but got {value}."
)

def _get_one_candidate(
@@ -158,10 +158,10 @@ class Reasoner:
if self.dist_func == "hamming":
return hamming_dist(data_sample.pred_pseudo_label, candidates)
elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
candidates = [[self.label_to_idx[x] for x in c] for c in candidates]
return confidence_dist(data_sample.pred_prob, candidates)
else:
candidate_idxs = [[self.remapping[x] for x in c] for c in candidates]
candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
cost_list = self.dist_func(data_sample, candidates, candidate_idxs, reasoning_results)
if len(cost_list) != len(candidates):
raise ValueError(


+ 6
- 6
examples/hed/hed_bridge.py View File

@@ -67,8 +67,8 @@ class HEDBridge(SimpleBridge):
mapping_score = []
abduced_pseudo_label_list = []
for _mapping in candidate_mappings:
self.reasoner.mapping = _mapping
self.reasoner.remapping = dict(zip(_mapping.values(), _mapping.keys()))
self.reasoner.idx_to_label = _mapping
self.reasoner.label_to_idx = dict(zip(_mapping.values(), _mapping.keys()))
self.idx_to_pseudo_label(data_samples)
abduced_pseudo_label = self.reasoner.abduce(data_samples)
mapping_score.append(len(abduced_pseudo_label) - abduced_pseudo_label.count([]))
@@ -76,9 +76,9 @@ class HEDBridge(SimpleBridge):

max_revisible_instances = max(mapping_score)
return_idx = mapping_score.index(max_revisible_instances)
self.reasoner.mapping = candidate_mappings[return_idx]
self.reasoner.remapping = dict(
zip(self.reasoner.mapping.values(), self.reasoner.mapping.keys())
self.reasoner.idx_to_label = candidate_mappings[return_idx]
self.reasoner.label_to_idx = dict(
zip(self.reasoner.idx_to_label.values(), self.reasoner.idx_to_label.keys())
)
self.idx_to_pseudo_label(data_samples)
data_samples.abduced_pseudo_label = abduced_pseudo_label_list[return_idx]
@@ -236,7 +236,7 @@ class HEDBridge(SimpleBridge):
else:
if equation_len == min_len:
print_log(
"Learned mapping is: " + str(self.reasoner.mapping),
"Learned mapping is: " + str(self.reasoner.idx_to_label),
logger="current",
)
self.model.load(load_path="./weights/pretrain_weights.pth")


Loading…
Cancel
Save