diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py index d09f1e0..03054f7 100644 --- a/abl/bridge/base_bridge.py +++ b/abl/bridge/base_bridge.py @@ -26,12 +26,12 @@ class BaseBridge(metaclass=ABCMeta): """Placeholder for abduce pseudo labels.""" @abstractmethod - def label_to_pseudo_label(self, label: List[List[Any]]) -> List[List[Any]]: + def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]: """Placeholder for map label space to symbol space.""" pass @abstractmethod - def pseudo_label_to_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: + def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]: """Placeholder for map symbol space to label space.""" pass diff --git a/abl/bridge/simple_bridge.py b/abl/bridge/simple_bridge.py index bc58699..7286d42 100644 --- a/abl/bridge/simple_bridge.py +++ b/abl/bridge/simple_bridge.py @@ -22,28 +22,27 @@ class SimpleBridge(BaseBridge): def predict(self, X) -> Tuple[List[List[Any]], ndarray]: pred_res = self.model.predict(X) - pred_label, pred_prob = pred_res["label"], pred_res["prob"] - return pred_label, pred_prob + pred_idx, pred_prob = pred_res["label"], pred_res["prob"] + return pred_idx, pred_prob def abduce_pseudo_label( self, - pred_label: List[List[Any]], pred_prob: ndarray, - pseudo_label: List[List[Any]], - Y: List[List[Any]], + pred_pseudo_label: List[List[Any]], + Y: List[Any], max_revision: int = -1, require_more_revision: int = 0, ) -> List[List[Any]]: - return self.abducer.batch_abduce(pred_prob, pseudo_label, Y, max_revision, require_more_revision) + return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) - def label_to_pseudo_label( - self, label: List[List[Any]], mapping: Dict = None + def idx_to_pseudo_label( + self, idx: List[List[Any]], mapping: Dict = None ) -> List[List[Any]]: if mapping is None: mapping = self.abducer.mapping - return [[mapping[_label] for _label in sub_list] for sub_list in label] + return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] - def pseudo_label_to_label( + def pseudo_label_to_idx( self, pseudo_label: List[List[Any]], mapping: Dict = None ) -> List[List[Any]]: if mapping is None: @@ -69,12 +68,12 @@ class SimpleBridge(BaseBridge): for epoch in range(epochs): for seg_idx, (X, Z, Y) in enumerate(data_loader): - pred_label, pred_prob = self.predict(X) - pred_pseudo_label = self.label_to_pseudo_label(pred_label) + pred_idx, pred_prob = self.predict(X) + pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) abduced_pseudo_label = self.abduce_pseudo_label( - pred_label, pred_prob, pred_pseudo_label, Y + pred_prob, pred_pseudo_label, Y ) - abduced_label = self.pseudo_label_to_label(abduced_pseudo_label) + abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) min_loss = self.model.train(X, abduced_label) print_log( @@ -88,10 +87,10 @@ class SimpleBridge(BaseBridge): def _valid(self, data_loader): for X, Z, Y in data_loader: - pred_label, pred_prob = self.predict(X) - pred_pseudo_label = self.label_to_pseudo_label(pred_label) + pred_idx, pred_prob = self.predict(X) + pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) data_samples = dict( - pred_label=pred_label, + pred_idx=pred_idx, pred_prob=pred_prob, pred_pseudo_label=pred_pseudo_label, gt_pseudo_label=Z, diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 93e005e..0a15781 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -119,13 +119,13 @@ class ReasonerBase(): solution = Opt.min(objective, parameter).get_x() return solution - def revise_by_idx(self, pseudo_label, y, revision_idx): + def revise_by_idx(self, pred_pseudo_label, y, revision_idx): """ Get the revisions corresponding to the given indices. Parameters ---------- - pseudo_label : list + pred_pseudo_label : list List of predicted pseudo labels. y : str Ground truth for the predicted results. @@ -137,7 +137,7 @@ class ReasonerBase(): list The revisions corresponding to the given indices. """ - return self.kb.revise_by_idx(pseudo_label, y, revision_idx) + return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx) def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0): """