|
|
@@ -22,28 +22,27 @@ class SimpleBridge(BaseBridge): |
|
|
|
|
|
|
|
def predict(self, X) -> Tuple[List[List[Any]], ndarray]: |
|
|
|
pred_res = self.model.predict(X) |
|
|
|
pred_label, pred_prob = pred_res["label"], pred_res["prob"] |
|
|
|
return pred_label, pred_prob |
|
|
|
pred_idx, pred_prob = pred_res["label"], pred_res["prob"] |
|
|
|
return pred_idx, pred_prob |
|
|
|
|
|
|
|
def abduce_pseudo_label( |
|
|
|
self, |
|
|
|
pred_label: List[List[Any]], |
|
|
|
pred_prob: ndarray, |
|
|
|
pseudo_label: List[List[Any]], |
|
|
|
Y: List[List[Any]], |
|
|
|
pred_pseudo_label: List[List[Any]], |
|
|
|
Y: List[Any], |
|
|
|
max_revision: int = -1, |
|
|
|
require_more_revision: int = 0, |
|
|
|
) -> List[List[Any]]: |
|
|
|
return self.abducer.batch_abduce(pred_prob, pseudo_label, Y, max_revision, require_more_revision) |
|
|
|
return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) |
|
|
|
|
|
|
|
def label_to_pseudo_label( |
|
|
|
self, label: List[List[Any]], mapping: Dict = None |
|
|
|
def idx_to_pseudo_label( |
|
|
|
self, idx: List[List[Any]], mapping: Dict = None |
|
|
|
) -> List[List[Any]]: |
|
|
|
if mapping is None: |
|
|
|
mapping = self.abducer.mapping |
|
|
|
return [[mapping[_label] for _label in sub_list] for sub_list in label] |
|
|
|
return [[mapping[_idx] for _idx in sub_list] for sub_list in idx] |
|
|
|
|
|
|
|
def pseudo_label_to_label( |
|
|
|
def pseudo_label_to_idx( |
|
|
|
self, pseudo_label: List[List[Any]], mapping: Dict = None |
|
|
|
) -> List[List[Any]]: |
|
|
|
if mapping is None: |
|
|
@@ -69,12 +68,12 @@ class SimpleBridge(BaseBridge): |
|
|
|
|
|
|
|
for epoch in range(epochs): |
|
|
|
for seg_idx, (X, Z, Y) in enumerate(data_loader): |
|
|
|
pred_label, pred_prob = self.predict(X) |
|
|
|
pred_pseudo_label = self.label_to_pseudo_label(pred_label) |
|
|
|
pred_idx, pred_prob = self.predict(X) |
|
|
|
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) |
|
|
|
abduced_pseudo_label = self.abduce_pseudo_label( |
|
|
|
pred_label, pred_prob, pred_pseudo_label, Y |
|
|
|
pred_prob, pred_pseudo_label, Y |
|
|
|
) |
|
|
|
abduced_label = self.pseudo_label_to_label(abduced_pseudo_label) |
|
|
|
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label) |
|
|
|
min_loss = self.model.train(X, abduced_label) |
|
|
|
|
|
|
|
print_log( |
|
|
@@ -88,10 +87,10 @@ class SimpleBridge(BaseBridge): |
|
|
|
|
|
|
|
def _valid(self, data_loader): |
|
|
|
for X, Z, Y in data_loader: |
|
|
|
pred_label, pred_prob = self.predict(X) |
|
|
|
pred_pseudo_label = self.label_to_pseudo_label(pred_label) |
|
|
|
pred_idx, pred_prob = self.predict(X) |
|
|
|
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx) |
|
|
|
data_samples = dict( |
|
|
|
pred_label=pred_label, |
|
|
|
pred_idx=pred_idx, |
|
|
|
pred_prob=pred_prob, |
|
|
|
pred_pseudo_label=pred_pseudo_label, |
|
|
|
gt_pseudo_label=Z, |
|
|
|