Browse Source

[MNT] add two arguments to abduce_pseudo_label in SimpleBridge

pull/3/head
Gao Enhao 2 years ago
parent
commit
11ed5953e4
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      abl/bridge/simple_bridge.py

+ 4
- 2
abl/bridge/simple_bridge.py View File

@@ -24,15 +24,17 @@ class SimpleBridge(BaseBridge):
pred_res = self.model.predict(X)
pred_label, pred_prob = pred_res["label"], pred_res["prob"]
return pred_label, pred_prob
def abduce_pseudo_label(
self,
pred_label: List[List[Any]],
pred_prob: ndarray,
pseudo_label: List[List[Any]],
Y: List[List[Any]],
max_revision: int = -1,
require_more_revision: int = 0,
) -> List[List[Any]]:
return self.abducer.batch_abduce(pred_label, pred_prob, pseudo_label, Y)
return self.abducer.batch_abduce(pred_label, pred_prob, pseudo_label, Y, max_revision, require_more_revision)

def label_to_pseudo_label(
self, label: List[List[Any]], mapping: Dict = None


Loading…
Cancel
Save