Browse Source

[MNT] unify variable names

pull/3/head
Gao Enhao 2 years ago
parent
commit
6ac0bb9378
3 changed files with 21 additions and 22 deletions
  1. +2
    -2
      abl/bridge/base_bridge.py
  2. +16
    -17
      abl/bridge/simple_bridge.py
  3. +3
    -3
      abl/reasoning/reasoner.py

+ 2
- 2
abl/bridge/base_bridge.py View File

@@ -26,12 +26,12 @@ class BaseBridge(metaclass=ABCMeta):
"""Placeholder for abduce pseudo labels."""

@abstractmethod
def label_to_pseudo_label(self, label: List[List[Any]]) -> List[List[Any]]:
def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]:
"""Placeholder for map label space to symbol space."""
pass

@abstractmethod
def pseudo_label_to_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
"""Placeholder for map symbol space to label space."""
pass


+ 16
- 17
abl/bridge/simple_bridge.py View File

@@ -22,28 +22,27 @@ class SimpleBridge(BaseBridge):

def predict(self, X) -> Tuple[List[List[Any]], ndarray]:
pred_res = self.model.predict(X)
pred_label, pred_prob = pred_res["label"], pred_res["prob"]
return pred_label, pred_prob
pred_idx, pred_prob = pred_res["label"], pred_res["prob"]
return pred_idx, pred_prob
def abduce_pseudo_label(
self,
pred_label: List[List[Any]],
pred_prob: ndarray,
pseudo_label: List[List[Any]],
Y: List[List[Any]],
pred_pseudo_label: List[List[Any]],
Y: List[Any],
max_revision: int = -1,
require_more_revision: int = 0,
) -> List[List[Any]]:
return self.abducer.batch_abduce(pred_prob, pseudo_label, Y, max_revision, require_more_revision)
return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision)

def label_to_pseudo_label(
self, label: List[List[Any]], mapping: Dict = None
def idx_to_pseudo_label(
self, idx: List[List[Any]], mapping: Dict = None
) -> List[List[Any]]:
if mapping is None:
mapping = self.abducer.mapping
return [[mapping[_label] for _label in sub_list] for sub_list in label]
return [[mapping[_idx] for _idx in sub_list] for sub_list in idx]

def pseudo_label_to_label(
def pseudo_label_to_idx(
self, pseudo_label: List[List[Any]], mapping: Dict = None
) -> List[List[Any]]:
if mapping is None:
@@ -69,12 +68,12 @@ class SimpleBridge(BaseBridge):

for epoch in range(epochs):
for seg_idx, (X, Z, Y) in enumerate(data_loader):
pred_label, pred_prob = self.predict(X)
pred_pseudo_label = self.label_to_pseudo_label(pred_label)
pred_idx, pred_prob = self.predict(X)
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
abduced_pseudo_label = self.abduce_pseudo_label(
pred_label, pred_prob, pred_pseudo_label, Y
pred_prob, pred_pseudo_label, Y
)
abduced_label = self.pseudo_label_to_label(abduced_pseudo_label)
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
min_loss = self.model.train(X, abduced_label)

print_log(
@@ -88,10 +87,10 @@ class SimpleBridge(BaseBridge):

def _valid(self, data_loader):
for X, Z, Y in data_loader:
pred_label, pred_prob = self.predict(X)
pred_pseudo_label = self.label_to_pseudo_label(pred_label)
pred_idx, pred_prob = self.predict(X)
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
data_samples = dict(
pred_label=pred_label,
pred_idx=pred_idx,
pred_prob=pred_prob,
pred_pseudo_label=pred_pseudo_label,
gt_pseudo_label=Z,


+ 3
- 3
abl/reasoning/reasoner.py View File

@@ -119,13 +119,13 @@ class ReasonerBase():
solution = Opt.min(objective, parameter).get_x()
return solution

def revise_by_idx(self, pseudo_label, y, revision_idx):
def revise_by_idx(self, pred_pseudo_label, y, revision_idx):
"""
Get the revisions corresponding to the given indices.

Parameters
----------
pseudo_label : list
pred_pseudo_label : list
List of predicted pseudo labels.
y : str
Ground truth for the predicted results.
@@ -137,7 +137,7 @@ class ReasonerBase():
list
The revisions corresponding to the given indices.
"""
return self.kb.revise_by_idx(pseudo_label, y, revision_idx)
return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx)

def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0):
"""


Loading…
Cancel
Save