|
|
@@ -24,15 +24,17 @@ class SimpleBridge(BaseBridge): |
|
|
|
pred_res = self.model.predict(X) |
|
|
|
pred_label, pred_prob = pred_res["label"], pred_res["prob"] |
|
|
|
return pred_label, pred_prob |
|
|
|
|
|
|
|
|
|
|
|
def abduce_pseudo_label( |
|
|
|
self, |
|
|
|
pred_label: List[List[Any]], |
|
|
|
pred_prob: ndarray, |
|
|
|
pseudo_label: List[List[Any]], |
|
|
|
Y: List[List[Any]], |
|
|
|
max_revision: int = -1, |
|
|
|
require_more_revision: int = 0, |
|
|
|
) -> List[List[Any]]: |
|
|
|
return self.abducer.batch_abduce(pred_label, pred_prob, pseudo_label, Y) |
|
|
|
return self.abducer.batch_abduce(pred_label, pred_prob, pseudo_label, Y, max_revision, require_more_revision) |
|
|
|
|
|
|
|
def label_to_pseudo_label( |
|
|
|
self, label: List[List[Any]], mapping: Dict = None |
|
|
|